Info
- Title: SinGAN: Learning a Generative Model From a Single Natural Image
- Task: Image Synthesis
- Authors: Tamar Rott Shaham, Tali Dekel, Tomer Michaeli
- Date: May 2019
- Arxiv: 1905.01164
- Published: ICCV 2019 (Best Paper Award)
Highlights
- Single-image learning pipeline: the model is trained on one image, with no external dataset
- Wide range of applications: super-resolution, paint-to-image, harmonization, etc.
Abstract
We introduce SinGAN, an unconditional generative model that can be learned from a single natural image. Our model is trained to capture the internal distribution of patches within the image, and is then able to generate high quality, diverse samples that carry the same visual content as the image. SinGAN contains a pyramid of fully convolutional GANs, each responsible for learning the patch distribution at a different scale of the image. This allows generating new samples of arbitrary size and aspect ratio, that have significant variability, yet maintain both the global structure and the fine textures of the training image. In contrast to previous single image GAN schemes, our approach is not limited to texture images, and is not conditional (i.e. it generates samples from noise). User studies confirm that the generated samples are commonly confused to be real images. We illustrate the utility of SinGAN in a wide range of image manipulation tasks.
Motivation & Design
Multi-Scale Pipeline.
Our model consists of a pyramid of GANs, where both training and inference are done in a coarse-to-fine fashion. At each scale, $G_n$ learns to generate image samples in which all the overlapping patches cannot be distinguished by the discriminator $D_n$ from the patches in the down-sampled training image $x_n$; the effective patch size decreases as we go up the pyramid (marked in yellow on the original image in the paper's figure). The input to $G_n$ is a random noise image $z_n$ and the image generated at the previous, coarser scale, $\tilde{x}_{n+1}$, upsampled to the current resolution (except at the coarsest level, which is purely generative). The generation process at level $n$ involves all generators $\{G_N, \dots, G_n\}$ and all noise maps $\{z_N, \dots, z_n\}$ up to this level.
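To make the coarse-to-fine pyramid concrete, here is a minimal sketch of how the per-scale image sizes could be derived. The helper name `pyramid_sizes`, the 25 px coarsest size, and the scale factor of 0.75 are illustrative assumptions; they are not taken from the code in these notes.

```python
import math

def pyramid_sizes(height, width, scale_factor=0.75, min_size=25):
    """Return (h, w) per scale, ordered coarsest first and finest last.

    Hypothetical helper: scale_factor and min_size are assumed values.
    """
    shorter = min(height, width)
    # Largest number of downscaling steps that keeps the shorter side >= min_size.
    num_steps = max(0, math.floor(math.log(min_size / shorter, scale_factor)))
    sizes = []
    for n in range(num_steps, -1, -1):  # n = N (coarsest) ... 0 (finest)
        s = scale_factor ** n
        sizes.append((max(1, round(height * s)), max(1, round(width * s))))
    return sizes

sizes = pyramid_sizes(250, 250)
print(len(sizes), sizes[0], sizes[-1])
```

Each entry corresponds to one GAN in the pyramid: the first size is where generation starts from pure noise, and every later scale refines an upsampled version of the previous output.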
Single Scale Generation
At each scale $n$, the image from the previous scale, $\tilde{x}_{n+1}$, is upsampled and added to the input noise map $z_n$. The result is fed into 5 conv layers, whose output is a residual image that is added back to $(\tilde{x}_{n+1})\uparrow^{r}$. This sum is the output $\tilde{x}_n$ of $G_n$.
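The description above can be written compactly as follows. The symbol $\psi_n$ (the 5-layer convolutional network inside $G_n$) follows the paper's notation and is not otherwise defined in these notes; in the code further down, the noise is additionally scaled by `opt.noise_amp` before being added.

```latex
\tilde{x}_N = G_N(z_N), \qquad
\tilde{x}_n = (\tilde{x}_{n+1})\uparrow^{r} + \psi_n\!\left(z_n + (\tilde{x}_{n+1})\uparrow^{r}\right), \quad n < N
```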
Loss Functions
Adversarial Loss and Reconstruction Loss
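As a sketch of how the two terms combine, the per-scale objective in the paper has the form below. The symbols $\alpha$, $z^*$ (the fixed noise map used for reconstruction) and $\tilde{x}^{\mathrm{rec}}_{n+1}$ (the reconstruction from the previous scale) follow the paper's notation; in the code below, $\alpha$ corresponds to `opt.alpha` and the adversarial term is a Wasserstein loss with gradient penalty.

```latex
\min_{G_n}\max_{D_n}\; \mathcal{L}_{\mathrm{adv}}(G_n, D_n) + \alpha\,\mathcal{L}_{\mathrm{rec}}(G_n)

\mathcal{L}_{\mathrm{rec}} =
\begin{cases}
\left\| G_n\!\left(0,\ (\tilde{x}^{\mathrm{rec}}_{n+1})\uparrow^{r}\right) - x_n \right\|^2, & n < N \\
\left\| G_N(z^*) - x_N \right\|^2, & n = N
\end{cases}
```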
Experiments & Ablation Study
SinGAN can be used in various image manipulation tasks, including: transforming a painting (clipart) into a realistic photo, rearranging and editing objects in the image, harmonizing a new object into an image, image super-resolution, and creating an animation from a single input. In all these cases, the model observes only the training image (first row of the paper's figure) and is trained in the same manner for all applications, with no architectural changes or further tuning.
Code
Training Process
import math
import torch

# functions, imresize, init_models and train_single_scale come from the SinGAN code base.
def train(opt,Gs,Zs,reals,NoiseAmp):
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_,opt.scale1,opt)
    # build the image pyramid of the training image
    reals = functions.creat_reals_pyramid(real,reals,opt)
    nfc_prev = 0
    # train one GAN per scale, starting with the coarsest (no previous generators)
    while scale_num<opt.stop_scale+1:
        # double the number of filters every 4 scales, capped at 128
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        D_curr,G_curr = init_models(opt)
        # warm-start from the previous scale when the architecture is unchanged
        if (nfc_prev==opt.nfc):
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_,scale_num-1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_,scale_num-1)))
        z_curr,in_s,G_curr = train_single_scale(D_curr,G_curr,reals,Gs,Zs,in_s,NoiseAmp,opt)
        # freeze the trained networks of this scale
        G_curr = functions.reset_grads(G_curr,False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr,False)
        D_curr.eval()
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        scale_num+=1
        nfc_prev = opt.nfc
        del D_curr,G_curr
    return
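The filter-count update inside `train` is easy to misread, so here is a small standalone sketch of the schedule it produces. The function name `channel_schedule` and the starting width `nfc_init=32` are assumptions for illustration; the formula itself mirrors the `opt.nfc` line above.

```python
import math

def channel_schedule(num_scales, nfc_init=32, cap=128):
    """Per-scale filter count, mirroring the opt.nfc update in train().

    nfc_init=32 is an assumed starting value.
    """
    widths, reuse_prev = [], []
    nfc_prev = 0
    for scale_num in range(num_scales):
        nfc = min(nfc_init * pow(2, math.floor(scale_num / 4)), cap)
        widths.append(nfc)
        # train() loads the previous scale's weights only when the width is unchanged
        reuse_prev.append(nfc_prev == nfc)
        nfc_prev = nfc
    return widths, reuse_prev

widths, reuse_prev = channel_schedule(10)
print(widths)
print(reuse_prev)
```

Under these assumptions the width doubles every four scales until it reaches the cap, and the warm start is skipped exactly at the scales where the width changes (including the very first scale).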
Training on single scale
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# functions and draw_concat come from the SinGAN code base.
def train_single_scale(netD,netG,reals,Gs,Zs,in_s,NoiseAmp,opt,centers=None):
    real = reals[len(Gs)]
    opt.nzx = real.shape[2]#+(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[3]#+(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size-1)*(opt.num_layer-1))*opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        opt.nzx = real.shape[2]+(opt.ker_size-1)*(opt.num_layer)
        opt.nzy = real.shape[3]+(opt.ker_size-1)*(opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))
    alpha = opt.alpha
    fixed_noise = functions.generate_noise([opt.nc_z,opt.nzx,opt.nzy])
    # float fill value so the tensor is floating point on all PyTorch versions
    z_opt = torch.full(fixed_noise.shape, 0.0, device=opt.device)
    z_opt = m_noise(z_opt)
    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,milestones=[1600],gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,milestones=[1600],gamma=opt.gamma)
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []
    for epoch in range(opt.niter):
        # coarsest scale: single-channel noise replicated over the 3 image channels
        if (Gs == []) and (opt.mode != 'SR_train'):
            z_opt = functions.generate_noise([1,opt.nzx,opt.nzy])
            z_opt = m_noise(z_opt.expand(1,3,opt.nzx,opt.nzy))
            noise_ = functions.generate_noise([1,opt.nzx,opt.nzy])
            noise_ = m_noise(noise_.expand(1,3,opt.nzx,opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z,opt.nzx,opt.nzy])
            noise_ = m_noise(noise_)
        ############################
        # (1) Update D network: maximize D(x) - D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()
            output = netD(real).to(opt.device)
            #D_real_map = output.detach()
            errD_real = -output.mean()#-a
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()
            # train with fake
            if (j==0) and (epoch == 0):
                if (Gs == []) and (opt.mode != 'SR_train'):
                    # coarsest scale: there is no previous image, so start from zeros
                    prev = torch.full([1,opt.nc_z,opt.nzx,opt.nzy], 0.0, device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.full([1,opt.nc_z,opt.nzx,opt.nzy], 0.0, device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs,Zs,reals,NoiseAmp,in_s,'rand',m_noise,m_image,opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs,Zs,reals,NoiseAmp,in_s,'rec',m_noise,m_image,opt)
                    # noise amplitude is proportional to the reconstruction error of the previous scales
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init*RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs,Zs,reals,NoiseAmp,in_s,'rand',m_noise,m_image,opt)
                prev = m_image(prev)
            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev,centers)
                plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
            if (Gs == []) and (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp*noise_+prev
            fake = netG(noise.detach(),prev)
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()
            gradient_penalty = functions.calc_gradient_penalty(netD, real, fake, opt.lambda_grad)
            gradient_penalty.backward()
            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()
        errD2plot.append(errD.detach())
        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            #D_fake_map = output.detach()
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha!=0:
                # reconstruction loss: the fixed noise z_opt must reproduce the real image
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
                Z_opt = opt.noise_amp*z_opt+z_prev
                rec_loss = alpha*loss(netG(Z_opt.detach(),z_prev),real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0
            optimizerG.step()
        # step the LR schedulers once per epoch, after the optimizer updates
        schedulerD.step()
        schedulerG.step()
    return z_opt,in_s,netG
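The receptive-field and padding expressions at the top of `train_single_scale` can be checked with a few lines of plain arithmetic. This is a sketch only: `ker_size=3`, `num_layer=5`, and unpadded stride-1 convolutions are assumed values, and the function names are made up for illustration.

```python
def receptive_field(ker_size, num_layer, stride=1):
    # Same expression as opt.receptive_field in train_single_scale
    return ker_size + ((ker_size - 1) * (num_layer - 1)) * stride

def pad_per_side(ker_size, num_layer):
    # Same expression as pad_noise / pad_image
    return int(((ker_size - 1) * num_layer) / 2)

def output_size(input_size, ker_size, num_layer):
    # Assumes unpadded, stride-1 convolutions: each layer trims (ker_size - 1) pixels
    return input_size - (ker_size - 1) * num_layer

ker_size, num_layer = 3, 5  # assumed values
rf = receptive_field(ker_size, num_layer)
pad = pad_per_side(ker_size, num_layer)
restored = output_size(64 + 2 * pad, ker_size, num_layer)
print(rf, pad, restored)
```

Under these assumptions, zero-padding the input once by `pad` pixels per side exactly compensates for the pixels trimmed by the convolution stack, so the generator output has the same spatial size as the unpadded image.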
Models
import torch
import torch.nn as nn

class ConvBlock(nn.Sequential):
    # Conv2d -> BatchNorm2d -> LeakyReLU(0.2)
    def __init__(self, in_channel, out_channel, ker_size, padd, stride):
        super(ConvBlock,self).__init__()
        self.add_module('conv',nn.Conv2d(in_channel ,out_channel,kernel_size=ker_size,stride=stride,padding=padd))
        self.add_module('norm',nn.BatchNorm2d(out_channel))
        self.add_module('LeakyRelu',nn.LeakyReLU(0.2, inplace=True))
def weights_init(m):
    # match on the module class names (Conv2d, BatchNorm2d); lowercase 'conv'/'norm' would never match
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('Norm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
class WDiscriminator(nn.Module):
    def __init__(self, opt):
        super(WDiscriminator, self).__init__()
        self.is_cuda = torch.cuda.is_available()
        N = int(opt.nfc)
        self.head = ConvBlock(opt.nc_im,N,opt.ker_size,opt.padd_size,1)
        self.body = nn.Sequential()
        for i in range(opt.num_layer-2):
            N = int(opt.nfc/pow(2,(i+1)))
            block = ConvBlock(max(2*N,opt.min_nfc),max(N,opt.min_nfc),opt.ker_size,opt.padd_size,1)
            self.body.add_module('block%d'%(i+1),block)
        # single-channel output map: one score per patch
        self.tail = nn.Conv2d(max(N,opt.min_nfc),1,kernel_size=opt.ker_size,stride=1,padding=opt.padd_size)

    def forward(self,x):
        x = self.head(x)
        x = self.body(x)
        x = self.tail(x)
        return x
class GeneratorConcatSkip2CleanAdd(nn.Module):
    def __init__(self, opt):
        super(GeneratorConcatSkip2CleanAdd, self).__init__()
        self.is_cuda = torch.cuda.is_available()
        N = opt.nfc
        self.head = ConvBlock(opt.nc_im,N,opt.ker_size,opt.padd_size,1) #GenConvTransBlock(opt.nc_z,N,opt.ker_size,opt.padd_size,opt.stride)
        self.body = nn.Sequential()
        for i in range(opt.num_layer-2):
            N = int(opt.nfc/pow(2,(i+1)))
            block = ConvBlock(max(2*N,opt.min_nfc),max(N,opt.min_nfc),opt.ker_size,opt.padd_size,1)
            self.body.add_module('block%d'%(i+1),block)
        self.tail = nn.Sequential(
            nn.Conv2d(max(N,opt.min_nfc),opt.nc_im,kernel_size=opt.ker_size,stride =1,padding=opt.padd_size),
            nn.Tanh()
        )

    def forward(self,x,y):
        # x: noise plus upsampled previous image, y: upsampled previous image
        x = self.head(x)
        x = self.body(x)
        x = self.tail(x)
        # center-crop y to the spatial size of x, then add the residual
        ind = int((y.shape[2]-x.shape[2])/2)
        y = y[:,:,ind:(y.shape[2]-ind),ind:(y.shape[3]-ind)]
        return x+y
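Two details of the models are easier to see with numbers: how the channel counts of the body blocks shrink, and how `forward` crops `y` before the residual addition. The sketch below replays both computations in plain Python; the helper names and the example values (`nfc=128`, `min_nfc=32`, `num_layer=5`, a 74-pixel padded input) are assumptions chosen for illustration.

```python
def block_channels(nfc, min_nfc, num_layer):
    """(in, out) channel pairs of the body blocks, following the loop in the models."""
    pairs = []
    for i in range(num_layer - 2):
        n = int(nfc / pow(2, i + 1))
        pairs.append((max(2 * n, min_nfc), max(n, min_nfc)))
    return pairs

def center_crop_bounds(y_size, x_size):
    """Start/stop indices used in forward() to crop y to the size of x."""
    ind = int((y_size - x_size) / 2)
    return ind, y_size - ind

body = block_channels(nfc=128, min_nfc=32, num_layer=5)
bounds = center_crop_bounds(74, 64)
print(body, bounds)
```

With these values the body halves the channel count per block until it hits `min_nfc`, and the crop removes the same number of border pixels from each side of `y` so that it lines up with the generator's residual output.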