Info
- Title: InGAN: Capturing and Remapping the “DNA” of a Natural Image
- Task: Image Generation
- Author: Assaf Shocher, Shai Bagon, Phillip Isola, Michal Irani
- Date: Dec. 2018
- Arxiv: 1812.00231
- Published: ICCV 2019
Abstract
Generative Adversarial Networks (GANs) typically learn a distribution of images in a large image dataset, and are then able to generate new images from this distribution. However, each natural image has its own internal statistics, captured by its unique distribution of patches. In this paper we propose an “Internal GAN” (InGAN) - an image-specific GAN - which trains on a single input image and learns its internal distribution of patches. It is then able to synthesize a plethora of new natural images of significantly different sizes, shapes and aspect-ratios - all with the same internal patch-distribution (same “DNA”) as the input image. In particular, despite large changes in global size/shape of the image, all elements inside the image maintain their local size/shape. InGAN is fully unsupervised, requiring no additional data other than the input image itself. Once trained on the input image, it can remap the input to any size or shape in a single feedforward pass, while preserving the same internal patch distribution. InGAN provides a unified framework for a variety of tasks, bridging the gap between textures and natural images.
Motivation & Design
Overview
InGAN consists of a Generator G that retargets the input x to an output y whose size/shape is determined by a geometric transformation T. A multiscale discriminator D learns to discriminate the patch statistics of the fake output y from the true patch statistics of the input image. Additionally, InGAN takes advantage of G's automorphism to reconstruct the input back from y, using G together with the inverse transformation T⁻¹.
The formulation aims to achieve two properties (a sketch of the combined objective follows the list):
- Matching distributions: the distribution of patches, across scales, in the synthesized image should match that distribution in the original input image. This property is a generalization of both the Coherence and Completeness objectives.
- Localization: the elements' locations in the generated image should generally match their relative locations in the original input image.
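Putting the two properties together, the training loss couples a multiscale adversarial term with a reconstruction term that exploits G's invertibility. A rough sketch of the combined objective, as we read it from the paper (λ weights the reconstruction term; the exact reconstruction distance is configurable in the code below):

```latex
\min_G \max_D \;
\mathcal{L}_{GAN}\!\big(G(x;T),\, x\big)
\;+\; \lambda \,\big\lVert\, G\big(G(x;T);\,T^{-1}\big) - x \,\big\rVert
```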
Shape-flexible Generator
The desired geometric transformation T, which determines the output shape, is treated as an additional input that is fed to G on every forward pass. A parameter-free transformation layer geometrically warps the feature map to the desired output shape. Keeping this layer parameter-free allows G to be trained once and then map x to any size, shape or aspect ratio at test time.
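The idea of the parameter-free layer amounts to a differentiable warp of the feature map to the requested shape. Below is a minimal sketch (not the repo's GeoTransform implementation; `geo_warp` and `theta` are names we made up) using torch's affine_grid/grid_sample:

```python
import torch
import torch.nn.functional as F

def geo_warp(feature_map, output_size, theta=None):
    """Resample `feature_map` to spatial size `output_size` (H, W), optionally under a 2x3 affine `theta`.
    No learnable parameters, so a trained G can emit any output shape at test time."""
    n, c = feature_map.shape[:2]
    if theta is None:
        # Identity affine: a pure resize to the requested shape
        theta = torch.tensor([[1., 0., 0.], [0., 1., 0.]], device=feature_map.device)
    theta = theta.expand(n, 2, 3)
    grid = F.affine_grid(theta, [n, c, output_size[0], output_size[1]], align_corners=False)
    return F.grid_sample(feature_map, grid, mode='bilinear', align_corners=False)
```

In the repo's Generator below, this role is played by GeoTransform when a random affine is requested, or by plain f.interpolate when only a resize is needed.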
Multi-scale Patch Discriminator
InGAN uses a multi-scale D. This is significant: a single-scale discriminator can only capture patch statistics of one specific size. A multi-scale D matches the patch distribution over a range of patch sizes, capturing both fine-grained details and coarse structures in the image. At each scale the discriminator is rather simple: it consists of just four conv-layers, the first of which is strided. Weights are not shared between the discriminators of different scales.
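The number of scales is capped so that even the most downscaled copy of the real crop is still at least 16 px on its smaller side (see max_n_scales in the discriminator code below). A small worked example, assuming a 256-px real crop and a scale factor of 2 (example numbers, not necessarily the repo's defaults):

```python
import math

real_crop, min_size, scale_factor = 256, 16, 2
n_scales = math.ceil(math.log(real_crop / min_size) / math.log(scale_factor))
print(n_scales)  # 4 -> discriminators operating at 256, 128, 64 and 32 px
```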
Experiments & Ablation Study
Applications to image retargeting and texture synthesis can be found on the project page.
Code
Model
```python
class InGAN:
    def __init__(self, conf):
        # Acquire configuration
        self.conf = conf
        self.cur_iter = 0
        self.max_iters = conf.max_iters

        # Define input tensors
        self.input_tensor = torch.FloatTensor(1, 3, conf.input_crop_size, conf.input_crop_size).cuda()
        self.real_example = torch.FloatTensor(1, 3, conf.output_crop_size, conf.output_crop_size).cuda()

        # Define networks
        self.G = networks.Generator(conf.G_base_channels, conf.G_num_resblocks, conf.G_num_downscales,
                                    conf.G_use_bias, conf.G_skip)
        self.D = networks.MultiScaleDiscriminator(conf.output_crop_size, self.conf.D_max_num_scales,
                                                  self.conf.D_scale_factor, self.conf.D_base_channels)
        self.GAN_loss_layer = networks.GANLoss()
        self.Reconstruct_loss = networks.WeightedMSELoss(use_L1=conf.use_L1)
        self.RandCrop = networks.RandomCrop([conf.input_crop_size, conf.input_crop_size], must_divide=conf.must_divide)
        self.SwapCrops = networks.SwapCrops(conf.crop_swap_min_size, conf.crop_swap_max_size)

        # Define loss functions
        self.criterionGAN = self.GAN_loss_layer.forward
        self.criterionReconstruction = self.Reconstruct_loss.forward

        # Keeping track of losses - prepare tensors
        self.losses_G_gan = torch.FloatTensor(conf.print_freq).cuda()
        self.losses_D_real = torch.FloatTensor(conf.print_freq).cuda()
        self.losses_D_fake = torch.FloatTensor(conf.print_freq).cuda()
        self.losses_G_reconstruct = torch.FloatTensor(conf.print_freq).cuda()
        if self.conf.reconstruct_loss_stop_iter > 0:
            self.losses_D_reconstruct = torch.FloatTensor(conf.print_freq).cuda()

        # Initialize networks
        self.G.apply(networks.weights_init)
        self.D.apply(networks.weights_init)

        # Initialize optimizers
        self.optimizer_G = torch.optim.Adam(self.G.parameters(), lr=conf.g_lr, betas=(conf.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(self.D.parameters(), lr=conf.d_lr, betas=(conf.beta1, 0.999))

        # Learning rate scheduler
        # First define a linearly decaying policy (decay starts at a special iteration)
        start_decay = conf.lr_start_decay_iter
        end_decay = conf.max_iters
        # def lr_function(n_iter):
        #     return 1 - max(0, 1.0 * (n_iter - start_decay) / (conf.max_iters - start_decay))
        lr_function = LRPolicy(start_decay, end_decay)

        # Define learning rate schedulers
        self.lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(self.optimizer_G, lr_function)
        self.lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(self.optimizer_D, lr_function)

    def train_g(self):
        # Zeroize gradients
        self.optimizer_G.zero_grad()
        self.optimizer_D.zero_grad()

        # Determine output size of G (dynamic change)
        output_size, random_affine = random_size(orig_size=self.input_tensor.shape[2:],
                                                 curriculum=self.conf.curriculum,
                                                 i=self.cur_iter,
                                                 iter_for_max_range=self.conf.iter_for_max_range,
                                                 must_divide=self.conf.must_divide,
                                                 min_scale=self.conf.min_scale,
                                                 max_scale=self.conf.max_scale,
                                                 max_transform_magniutude=self.conf.max_transform_magnitude)

        # Add noise to G input for better generalization (make it ignore the 1/255 binning)
        self.input_tensor_noised = self.input_tensor + (torch.rand_like(self.input_tensor) - 0.5) * 2.0 / 255

        # Generator forward pass
        self.G_pred = self.G.forward(self.input_tensor_noised, output_size=output_size, random_affine=random_affine)

        # Run generator result through discriminator forward pass
        self.scale_weights = get_scale_weights(i=self.cur_iter,
                                               max_i=self.conf.D_scale_weights_iter_for_even_scales,
                                               start_factor=self.conf.D_scale_weights_sigma,
                                               input_shape=self.G_pred.shape[2:],
                                               min_size=self.conf.D_min_input_size,
                                               num_scales_limit=self.conf.D_max_num_scales,
                                               scale_factor=self.conf.D_scale_factor)
        d_pred_fake = self.D.forward(self.G_pred, self.scale_weights)

        # If the reconstruction loss is used, run G again to reconstruct the input, then calculate the reconstruction loss
        if self.conf.reconstruct_loss_stop_iter > self.cur_iter:
            self.reconstruct = self.G.forward(self.G_pred, output_size=self.input_tensor.shape[2:],
                                              random_affine=-random_affine)
            self.loss_G_reconstruct = self.criterionReconstruction(self.reconstruct, self.input_tensor, self.loss_mask)

        # Calculate generator loss, based on discriminator prediction on generator result
        self.loss_G_GAN = self.criterionGAN(d_pred_fake, is_d_input_real=True)

        # Generator final loss
        # Weighted average of the two losses (if indicated to use reconstruction loss)
        if self.conf.reconstruct_loss_stop_iter < self.cur_iter:
            self.loss_G = self.loss_G_GAN
        else:
            self.loss_G = (self.conf.reconstruct_loss_proportion * self.loss_G_reconstruct + self.loss_G_GAN)

        # Calculate gradients
        # Note that the gradients are propagated from the loss through the discriminator and then through the generator
        self.loss_G.backward()

        # Update weights
        # Note that only generator weights are updated (by definition of the G optimizer)
        self.optimizer_G.step()

        # Extra training for the inverse G. The difference between this and the reconstruction above is the .detach(),
        # which makes this training affect only the inverse mapping of G and not the regular forward pass.
        if self.cur_iter > self.conf.G_extra_inverse_train_start_iter:
            for _ in range(self.conf.G_extra_inverse_train):
                self.optimizer_G.zero_grad()
                self.inverse = self.G.forward(self.G_pred.detach(), output_size=self.input_tensor.shape[2:],
                                              random_affine=-random_affine)
                self.loss_G_inverse = (self.criterionReconstruction(self.inverse, self.input_tensor, self.loss_mask) *
                                       self.conf.G_extra_inverse_train_ratio)
                self.loss_G_inverse.backward()
                self.optimizer_G.step()

        # Update learning rate scheduler
        self.lr_scheduler_G.step()

    def train_d(self):
        # Zeroize gradients
        self.optimizer_D.zero_grad()

        # Adding noise to D input to prevent overfitting to 1/255 bins
        real_example_with_noise = self.real_example + (torch.rand_like(self.real_example[-1]) - 0.5) * 2.0 / 255.0

        # Discriminator forward pass over real example
        self.d_pred_real = self.D.forward(real_example_with_noise, self.scale_weights)

        # Adding noise to D input to prevent overfitting to 1/255 bins
        # Note that the generator result is detached so that gradients do not propagate back through the generator
        g_pred_with_noise = self.G_pred.detach() + (torch.rand_like(self.G_pred) - 0.5) * 2.0 / 255

        # Discriminator forward pass over generated example
        self.d_pred_fake = self.D.forward(g_pred_with_noise, self.scale_weights)

        # Calculate discriminator loss
        self.loss_D_fake = self.criterionGAN(self.d_pred_fake, is_d_input_real=False)
        self.loss_D_real = self.criterionGAN(self.d_pred_real, is_d_input_real=True)
        self.loss_D = (self.loss_D_real + self.loss_D_fake) * 0.5

        # Calculate gradients
        # Note that gradients do not propagate back through the generator
        self.loss_D.backward()

        # Update weights
        # Note that only discriminator weights are updated (by definition of the D optimizer)
        self.optimizer_D.step()

        # Update learning rate scheduler
        self.lr_scheduler_D.step()

    def train_one_iter(self, cur_iter, input_tensors):
        # Set inputs as random crops
        input_crops = []
        mask_crops = []
        real_example_crops = []
        mask_flag = False
        for input_tensor in input_tensors:
            real_example_crops += self.RandCrop.forward([input_tensor])

            if np.random.rand() < self.conf.crop_swap_probability:
                swapped_input_tensor, loss_mask = self.SwapCrops.forward(input_tensor)
                [input_crop, mask_crop] = self.RandCrop.forward([swapped_input_tensor, loss_mask])
                input_crops.append(input_crop)
                mask_crops.append(mask_crop)
                mask_flag = True
            else:
                input_crops.append(real_example_crops[-1])

        self.input_tensor = torch.cat(input_crops)
        self.real_example = torch.cat(real_example_crops)
        self.loss_mask = torch.cat(mask_crops) if mask_flag else None

        # Update current iteration
        self.cur_iter = cur_iter

        # Run a single forward-backward pass on the model and update weights
        # One global iteration includes several iterations of the generator and several of the discriminator
        # (not necessarily equal)
        for _ in range(self.conf.G_iters):
            self.train_g()

        for _ in range(self.conf.D_iters):
            self.train_d()

        # Accumulate stats
        # Accumulating as cuda tensors is much more efficient than passing info from GPU to CPU at every iteration
        self.losses_G_gan[cur_iter % self.conf.print_freq] = self.loss_G_GAN.item()
        self.losses_D_fake[cur_iter % self.conf.print_freq] = self.loss_D_fake.item()
        self.losses_D_real[cur_iter % self.conf.print_freq] = self.loss_D_real.item()
        if self.conf.reconstruct_loss_stop_iter > self.cur_iter:
            self.losses_G_reconstruct[cur_iter % self.conf.print_freq] = self.loss_G_reconstruct.item()
```
Generator
```python
class Generator(nn.Module):
    """ Architecture of the Generator, uses res-blocks """

    def __init__(self, base_channels=64, n_blocks=6, n_downsampling=3, use_bias=True, skip_flag=True):
        super(Generator, self).__init__()

        # Determine whether to use skip connections
        self.skip = skip_flag

        # Entry block
        # First conv-block, no stride so image dims are kept and channels dim is expanded (pad-conv-norm-relu)
        self.entry_block = nn.Sequential(nn.ReflectionPad2d(3),
                                         nn.utils.spectral_norm(nn.Conv2d(3, base_channels, kernel_size=7, bias=use_bias)),
                                         normalization_layer(base_channels),
                                         nn.LeakyReLU(0.2, True))

        # Geometric transformation
        self.geo_transform = GeoTransform()

        # Downscaling
        # A sequence of strided conv-blocks. Image dims shrink by 2, channels dim expands by 2 at each block
        self.downscale_block = RescaleBlock(n_downsampling, 0.5, base_channels, True)

        # Bottleneck
        # A sequence of res-blocks
        bottleneck_block = []
        for _ in range(n_blocks):
            bottleneck_block += [ResnetBlock(base_channels * 2 ** n_downsampling, use_bias=use_bias)]
        self.bottleneck_block = nn.Sequential(*bottleneck_block)

        # Upscaling
        # A sequence of transposed-conv-blocks. Image dims expand by 2, channels dim shrinks by 2 at each block
        self.upscale_block = RescaleBlock(n_downsampling, 2.0, base_channels, True)

        # Final block
        # No stride so image dims are kept and channels dim shrinks to 3 (output image channels)
        self.final_block = nn.Sequential(nn.ReflectionPad2d(3),
                                         nn.Conv2d(base_channels, 3, kernel_size=7),
                                         nn.Tanh())

    def forward(self, input_tensor, output_size, random_affine):
        # A condition for having the output at the same size as the scaled input is having an even output_size

        # Entry block
        feature_map = self.entry_block(input_tensor)

        # Change scale to output scale by interpolation
        if random_affine is None:
            feature_map = f.interpolate(feature_map, size=output_size, mode='bilinear')
        else:
            feature_map = self.geo_transform.forward(feature_map, output_size, random_affine)

        # Downscale block
        feature_map, downscales = self.downscale_block.forward(feature_map, return_all_scales=self.skip)

        # Bottleneck (res-blocks)
        feature_map = self.bottleneck_block(feature_map)

        # Upscale block
        feature_map, _ = self.upscale_block.forward(feature_map, pyramid=downscales, skip=self.skip)

        # Final block
        output_tensor = self.final_block(feature_map)

        return output_tensor
```
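Once trained, retargeting is a single feed-forward pass: call the generator with a new output_size and no random affine. A small sketch, assuming x is a 1x3xHxW tensor in [-1, 1] and the target size respects the must_divide constraint used during training:

```python
# Retarget x to a new size/aspect ratio while keeping its internal patch statistics
new_h, new_w = 256, 512  # example target shape (assumed compatible with must_divide)
with torch.no_grad():
    y = G(x, output_size=(new_h, new_w), random_affine=None)
```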
Discriminator
```python
class MultiScaleDiscriminator(nn.Module):
    def __init__(self, real_crop_size, max_n_scales=9, scale_factor=2, base_channels=128, extra_conv_layers=0):
        super(MultiScaleDiscriminator, self).__init__()
        self.base_channels = base_channels
        self.scale_factor = scale_factor
        self.min_size = 16
        self.extra_conv_layers = extra_conv_layers

        # We want the max number of scales to fit the size of the real examples;
        # further scaling would create networks that only train on fake examples
        self.max_n_scales = np.min([int(np.ceil(np.log(np.min(real_crop_size) * 1.0 / self.min_size)
                                                / np.log(self.scale_factor))), max_n_scales])

        # Prepare a list of all the networks for all the wanted scales
        self.nets = nn.ModuleList()

        # Create a network for each scale
        for _ in range(self.max_n_scales):
            self.nets.append(self.make_net())

    def make_net(self):
        base_channels = self.base_channels
        net = []

        # Entry block
        net += [nn.utils.spectral_norm(nn.Conv2d(3, base_channels, kernel_size=3, stride=1)),
                nn.BatchNorm2d(base_channels),
                nn.LeakyReLU(0.2, True)]

        # Downscaling block
        # A strided conv-block. Image dims shrink by 2, channels dim expands by 2
        net += [nn.utils.spectral_norm(nn.Conv2d(base_channels, base_channels * 2, kernel_size=3, stride=2)),
                nn.BatchNorm2d(base_channels * 2),
                nn.LeakyReLU(0.2, True)]

        # Regular conv-block
        net += [nn.utils.spectral_norm(nn.Conv2d(in_channels=base_channels * 2,
                                                 out_channels=base_channels * 2, kernel_size=3, bias=True)),
                nn.BatchNorm2d(base_channels * 2),
                nn.LeakyReLU(0.2, True)]

        # Additional (optional) conv-blocks
        for _ in range(self.extra_conv_layers):
            net += [nn.utils.spectral_norm(nn.Conv2d(in_channels=base_channels * 2,
                                                     out_channels=base_channels * 2, kernel_size=3, bias=True)),
                    nn.BatchNorm2d(base_channels * 2),
                    nn.LeakyReLU(0.2, True)]

        # Final conv-block
        # Ends with a Sigmoid to get a range of 0-1
        net += nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(base_channels * 2, 1, kernel_size=1)),
                             nn.Sigmoid())

        # Make it a valid layers sequence and return
        return nn.Sequential(*net)

    def forward(self, input_tensor, scale_weights):
        aggregated_result_maps_from_all_scales = self.nets[0](input_tensor) * scale_weights[0]
        map_size = aggregated_result_maps_from_all_scales.shape[2:]

        # Run all nets over all scales and aggregate the interpolated results
        for net, scale_weight, i in zip(self.nets[1:], scale_weights[1:], range(1, len(scale_weights))):
            downscaled_image = f.interpolate(input_tensor, scale_factor=self.scale_factor ** (-i), mode='bilinear')
            result_map_for_current_scale = net(downscaled_image)
            upscaled_result_map_for_current_scale = f.interpolate(result_map_for_current_scale,
                                                                  size=map_size,
                                                                  mode='bilinear')
            aggregated_result_maps_from_all_scales += upscaled_result_map_for_current_scale * scale_weight

        return aggregated_result_maps_from_all_scales
```
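For completeness, a small usage sketch of the multi-scale discriminator on its own; the uniform scale_weights below are just a stand-in for the schedule that get_scale_weights produces during training:

```python
# Score a (here random) image with all discriminator scales and aggregate the per-patch maps
D = MultiScaleDiscriminator(real_crop_size=(256, 256), max_n_scales=5, scale_factor=2, base_channels=128)
weights = [1.0 / D.max_n_scales] * D.max_n_scales  # uniform weights as a simple stand-in
patch_map = D(torch.randn(1, 3, 256, 256), weights)  # aggregated per-patch realness map
```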