Info
- Title: StackGAN: Text to Photo-realistic Image Synthesis with Stacked Generative Adversarial Networks
- Task: Text-to-Image
- Author: Han Zhang, Tao Xu, Hongsheng Li, Shaoting Zhang, Xiaogang Wang, Xiaolei Huang, Dimitris Metaxas
- Date: Dec. 2016
- Arxiv: 1612.03242
- Published: ICCV 2017
Highlights
- Proposes a novel Stacked Generative Adversarial Network (StackGAN) for synthesizing photo-realistic images from text descriptions. It decomposes the difficult problem of generating high-resolution images into more manageable subproblems and significantly improves the state of the art. StackGAN is the first method to generate 256×256 images with photo-realistic details from text descriptions.
- Proposes a new Conditioning Augmentation technique that stabilizes conditional GAN training and improves the diversity of the generated samples.
Abstract
Synthesizing high-quality images from text descriptions is a challenging problem in computer vision and has many practical applications. Samples generated by existing text-to-image approaches can roughly reflect the meaning of the given descriptions, but they fail to contain necessary details and vivid object parts. In this paper, we propose Stacked Generative Adversarial Networks (StackGAN) to generate 256x256 photo-realistic images conditioned on text descriptions. We decompose the hard problem into more manageable sub-problems through a sketch-refinement process. The Stage-I GAN sketches the primitive shape and colors of the object based on the given text description, yielding Stage-I low-resolution images. The Stage-II GAN takes Stage-I results and text descriptions as inputs, and generates high-resolution images with photo-realistic details. It is able to rectify defects in Stage-I results and add compelling details with the refinement process. To improve the diversity of the synthesized images and stabilize the training of the conditional-GAN, we introduce a novel Conditioning Augmentation technique that encourages smoothness in the latent conditioning manifold. Extensive experiments and comparisons with state-of-the-arts on benchmark datasets demonstrate that the proposed method achieves significant improvements on generating photo-realistic images conditioned on text descriptions.
Motivation & Design
Overview
- Stage-I GAN: it sketches the primitive shape and basic colors of the object conditioned on the given text description, and draws the background layout from a random noise vector, yielding a low-resolution image.
- Stage-II GAN: it corrects defects in the low-resolution image from Stage-I and completes details of the object by reading the text description again, producing a high-resolution photo-realistic image.
The Stage-I generator draws a low-resolution image by sketching rough shape and basic colors of the object from the given text and painting the background from a random noise vector. Conditioned on Stage-I results, the Stage-II generator corrects defects and adds compelling details into Stage-I results, yielding a more realistic high-resolution image.
Stage-I GAN Loss
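Following the StackGAN paper, Stage-I trains the discriminator $D_0$ and the generator $G_0$ by alternately maximizing $\mathcal{L}_{D_0}$ and minimizing $\mathcal{L}_{G_0}$. The notation below is reproduced from the paper as best recalled; $\hat{c}_0$ is the conditioning variable sampled from $\mathcal{N}(\mu_0(\varphi_t), \Sigma_0(\varphi_t))$:

$$
\begin{aligned}
\mathcal{L}_{D_0} &= \mathbb{E}_{(I_0, t) \sim p_{data}}\left[\log D_0(I_0, \varphi_t)\right] + \mathbb{E}_{z \sim p_z,\, t \sim p_{data}}\left[\log\left(1 - D_0(G_0(z, \hat{c}_0), \varphi_t)\right)\right] \\
\mathcal{L}_{G_0} &= \mathbb{E}_{z \sim p_z,\, t \sim p_{data}}\left[\log\left(1 - D_0(G_0(z, \hat{c}_0), \varphi_t)\right)\right] + \lambda\, D_{KL}\left(\mathcal{N}(\mu_0(\varphi_t), \Sigma_0(\varphi_t)) \,\|\, \mathcal{N}(0, I)\right)
\end{aligned}
$$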
where the real image $I_0$ and the text description $t$ are drawn from the true data distribution $p_{data}$, $z$ is a noise vector randomly sampled from a given distribution $p_z$ (a Gaussian distribution in this paper), and $\lambda$ is a regularization parameter that balances the two terms of the generator loss.
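The term weighted by $\lambda$ is the KL divergence between the conditioning Gaussian and the standard normal. A minimal plain-Python sketch of Conditioning Augmentation follows, assuming a diagonal covariance parameterized by a log-variance vector. The function names `sample_condition` and `kl_to_standard_normal` are hypothetical, and plain lists stand in for tensors:

```python
import math
import random


def sample_condition(mu, logvar, rng=random):
    """Reparameterized sample c = mu + sigma * eps, with eps ~ N(0, I)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mu, logvar)]


def kl_to_standard_normal(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv)
                      for m, lv in zip(mu, logvar))
```

The KL term is zero exactly when the predicted mean is 0 and the variance is 1, so it pulls the conditioning distribution toward the standard normal while the sampling step injects small perturbations around the text embedding.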
Stage-II GAN Loss
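Following the StackGAN paper, Stage-II conditions on the low-resolution result $s_0 = G_0(z, \hat{c}_0)$ and on the Gaussian latent variable $\hat{c}$, and trains $D$ and $G$ by alternately maximizing $\mathcal{L}_D$ and minimizing $\mathcal{L}_G$ (notation reproduced from the paper as best recalled):

$$
\begin{aligned}
\mathcal{L}_{D} &= \mathbb{E}_{(I, t) \sim p_{data}}\left[\log D(I, \varphi_t)\right] + \mathbb{E}_{s_0 \sim p_{G_0},\, t \sim p_{data}}\left[\log\left(1 - D(G(s_0, \hat{c}), \varphi_t)\right)\right] \\
\mathcal{L}_{G} &= \mathbb{E}_{s_0 \sim p_{G_0},\, t \sim p_{data}}\left[\log\left(1 - D(G(s_0, \hat{c}), \varphi_t)\right)\right] + \lambda\, D_{KL}\left(\mathcal{N}(\mu(\varphi_t), \Sigma(\varphi_t)) \,\|\, \mathcal{N}(0, I)\right)
\end{aligned}
$$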
Unlike the original GAN formulation, the random noise $z$ is not used in this stage, under the assumption that the randomness has already been preserved by $s_0$, the Stage-I result. The Gaussian conditioning variables $\hat{c}$ used in this stage and $\hat{c}_0$ used in the Stage-I GAN share the same pre-trained text encoder, which generates the same text embedding $\varphi_t$.
Experiments & Ablation Study
Code
StackGAN in PyTorch
StackGAN Stage-1 Generator
import torch
import torch.nn as nn

# cfg, CA_NET, upBlock and conv3x3 are defined elsewhere in the StackGAN PyTorch code.


class STAGE1_G(nn.Module):
    def __init__(self):
        super(STAGE1_G, self).__init__()
        self.gf_dim = cfg.GAN.GF_DIM * 8
        self.ef_dim = cfg.GAN.CONDITION_DIM
        self.z_dim = cfg.Z_DIM
        self.define_module()

    def define_module(self):
        ninput = self.z_dim + self.ef_dim
        ngf = self.gf_dim
        # TEXT.DIMENSION -> GAN.CONDITION_DIM
        self.ca_net = CA_NET()
        # -> ngf x 4 x 4
        self.fc = nn.Sequential(
            nn.Linear(ninput, ngf * 4 * 4, bias=False),
            nn.BatchNorm1d(ngf * 4 * 4),
            nn.ReLU(True))
        # ngf x 4 x 4 -> ngf/2 x 8 x 8
        self.upsample1 = upBlock(ngf, ngf // 2)
        # -> ngf/4 x 16 x 16
        self.upsample2 = upBlock(ngf // 2, ngf // 4)
        # -> ngf/8 x 32 x 32
        self.upsample3 = upBlock(ngf // 4, ngf // 8)
        # -> ngf/16 x 64 x 64
        self.upsample4 = upBlock(ngf // 8, ngf // 16)
        # -> 3 x 64 x 64
        self.img = nn.Sequential(
            conv3x3(ngf // 16, 3),
            nn.Tanh())

    def forward(self, text_embedding, noise):
        # Sample the conditioning vector from the text embedding
        c_code, mu, logvar = self.ca_net(text_embedding)
        # Concatenate noise and conditioning vector along the feature axis
        z_c_code = torch.cat((noise, c_code), 1)
        h_code = self.fc(z_c_code)
        h_code = h_code.view(-1, self.gf_dim, 4, 4)
        h_code = self.upsample1(h_code)
        h_code = self.upsample2(h_code)
        h_code = self.upsample3(h_code)
        h_code = self.upsample4(h_code)
        # state size 3 x 64 x 64
        fake_img = self.img(h_code)
        return None, fake_img, mu, logvar
StackGAN Stage-1 Discriminator
class STAGE1_D(nn.Module):
    def __init__(self):
        super(STAGE1_D, self).__init__()
        self.df_dim = cfg.GAN.DF_DIM
        self.ef_dim = cfg.GAN.CONDITION_DIM
        self.define_module()

    def define_module(self):
        ndf, nef = self.df_dim, self.ef_dim
        self.encode_img = nn.Sequential(
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            # state size (ndf*8) x 4 x 4
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.get_cond_logits = D_GET_LOGITS(ndf, nef)
        self.get_uncond_logits = None

    def forward(self, image):
        img_embedding = self.encode_img(image)
        return img_embedding
StackGAN Stage-2 Generator
class STAGE2_G(nn.Module):
    def __init__(self, STAGE1_G):
        super(STAGE2_G, self).__init__()
        self.gf_dim = cfg.GAN.GF_DIM
        self.ef_dim = cfg.GAN.CONDITION_DIM
        self.z_dim = cfg.Z_DIM
        self.STAGE1_G = STAGE1_G
        # fix parameters of stageI GAN
        for param in self.STAGE1_G.parameters():
            param.requires_grad = False
        self.define_module()

    def _make_layer(self, block, channel_num):
        layers = []
        for i in range(cfg.GAN.R_NUM):
            layers.append(block(channel_num))
        return nn.Sequential(*layers)

    def define_module(self):
        ngf = self.gf_dim
        # TEXT.DIMENSION -> GAN.CONDITION_DIM
        self.ca_net = CA_NET()
        # --> 4ngf x 16 x 16
        self.encoder = nn.Sequential(
            conv3x3(3, ngf),
            nn.ReLU(True),
            nn.Conv2d(ngf, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.Conv2d(ngf * 2, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True))
        self.hr_joint = nn.Sequential(
            conv3x3(self.ef_dim + ngf * 4, ngf * 4),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True))
        self.residual = self._make_layer(ResBlock, ngf * 4)
        # --> 2ngf x 32 x 32
        self.upsample1 = upBlock(ngf * 4, ngf * 2)
        # --> ngf x 64 x 64
        self.upsample2 = upBlock(ngf * 2, ngf)
        # --> ngf // 2 x 128 x 128
        self.upsample3 = upBlock(ngf, ngf // 2)
        # --> ngf // 4 x 256 x 256
        self.upsample4 = upBlock(ngf // 2, ngf // 4)
        # --> 3 x 256 x 256
        self.img = nn.Sequential(
            conv3x3(ngf // 4, 3),
            nn.Tanh())

    def forward(self, text_embedding, noise):
        # Stage-I output is treated as a fixed input (no gradient flows back)
        _, stage1_img, _, _ = self.STAGE1_G(text_embedding, noise)
        stage1_img = stage1_img.detach()
        encoded_img = self.encoder(stage1_img)

        # Replicate the conditioning vector over the 16 x 16 feature map
        c_code, mu, logvar = self.ca_net(text_embedding)
        c_code = c_code.view(-1, self.ef_dim, 1, 1)
        c_code = c_code.repeat(1, 1, 16, 16)
        i_c_code = torch.cat([encoded_img, c_code], 1)
        h_code = self.hr_joint(i_c_code)
        h_code = self.residual(h_code)

        h_code = self.upsample1(h_code)
        h_code = self.upsample2(h_code)
        h_code = self.upsample3(h_code)
        h_code = self.upsample4(h_code)

        fake_img = self.img(h_code)
        return stage1_img, fake_img, mu, logvar
StackGAN Stage-2 Discriminator
class STAGE2_D(nn.Module):
    def __init__(self):
        super(STAGE2_D, self).__init__()
        self.df_dim = cfg.GAN.DF_DIM
        self.ef_dim = cfg.GAN.CONDITION_DIM
        self.define_module()

    def define_module(self):
        ndf, nef = self.df_dim, self.ef_dim
        self.encode_img = nn.Sequential(
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),  # 128 * 128 * ndf
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),  # 64 * 64 * ndf * 2
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),  # 32 * 32 * ndf * 4
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),  # 16 * 16 * ndf * 8
            nn.Conv2d(ndf * 8, ndf * 16, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 16),
            nn.LeakyReLU(0.2, inplace=True),  # 8 * 8 * ndf * 16
            nn.Conv2d(ndf * 16, ndf * 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 32),
            nn.LeakyReLU(0.2, inplace=True),  # 4 * 4 * ndf * 32
            conv3x3(ndf * 32, ndf * 16),
            nn.BatchNorm2d(ndf * 16),
            nn.LeakyReLU(0.2, inplace=True),  # 4 * 4 * ndf * 16
            conv3x3(ndf * 16, ndf * 8),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True)  # 4 * 4 * ndf * 8
        )
        self.get_cond_logits = D_GET_LOGITS(ndf, nef, bcondition=True)
        self.get_uncond_logits = D_GET_LOGITS(ndf, nef, bcondition=False)

    def forward(self, image):
        img_embedding = self.encode_img(image)
        return img_embedding
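The spatial sizes noted in the discriminator comments follow from the standard convolution output-size formula, out = floor((in + 2 * padding - kernel) / stride) + 1. A small check in plain Python, assuming square inputs and dilation 1; `conv_out_size` and `downsample_chain` are hypothetical helper names used only for illustration:

```python
def conv_out_size(size, kernel, stride, padding):
    """Output side length of a square convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def downsample_chain(size, num_layers, kernel=4, stride=2, padding=1):
    """Side lengths after each of `num_layers` identical convolutions."""
    sizes = []
    for _ in range(num_layers):
        size = conv_out_size(size, kernel, stride, padding)
        sizes.append(size)
    return sizes


print(downsample_chain(64, 4))   # Stage-I discriminator: [32, 16, 8, 4]
print(downsample_chain(256, 6))  # Stage-II discriminator: [128, 64, 32, 16, 8, 4]
```

Each 4x4 convolution with stride 2 and padding 1 halves the resolution, so the 64x64 Stage-I input needs four such layers and the 256x256 Stage-II input needs six to reach a 4x4 feature map. A 3x3 convolution with stride 1 and padding 1 leaves the size unchanged.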
Training Process
# (1) Prepare training data
real_img_cpu, txt_embedding = data
real_imgs = Variable(real_img_cpu)
txt_embedding = Variable(txt_embedding)
if cfg.CUDA:
    real_imgs = real_imgs.cuda()
    txt_embedding = txt_embedding.cuda()

# (2) Generate fake images
noise.data.normal_(0, 1)
inputs = (txt_embedding, noise)
_, fake_imgs, mu, logvar = \
    nn.parallel.data_parallel(netG, inputs, self.gpus)

# (3) Update D network
netD.zero_grad()
errD, errD_real, errD_wrong, errD_fake = \
    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                               real_labels, fake_labels,
                               mu, self.gpus)
errD.backward()
optimizerD.step()

# (4) Update G network
netG.zero_grad()
errG = compute_generator_loss(netD, fake_imgs,
                              real_labels, mu, self.gpus)
# KL regularization from Conditioning Augmentation, weighted by cfg.TRAIN.COEFF.KL
kl_loss = KL_loss(mu, logvar)
errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
errG_total.backward()
optimizerG.step()
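The three discriminator terms returned in step (3) correspond to a real image with matching text, a real image with mismatched text, and a generated image with matching text. A plain-Python sketch of how such terms can be combined with binary cross-entropy is given below. The 0.5 weighting of the wrong and fake terms and the helper names `bce` and `discriminator_loss` are assumptions made for illustration and are not taken from the training code above:

```python
import math


def bce(probs, target):
    """Mean binary cross-entropy of predicted probabilities vs. a constant target."""
    eps = 1e-12  # clamp to avoid log(0)
    total = 0.0
    for p in probs:
        p = min(max(p, eps), 1.0 - eps)
        total += -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))
    return total / len(probs)


def discriminator_loss(real_probs, wrong_probs, fake_probs):
    """Combine real/wrong/fake pair losses (the weighting is an assumption)."""
    err_real = bce(real_probs, 1.0)    # real image, matching text -> label 1
    err_wrong = bce(wrong_probs, 0.0)  # real image, mismatched text -> label 0
    err_fake = bce(fake_probs, 0.0)    # generated image, matching text -> label 0
    err_total = err_real + 0.5 * (err_wrong + err_fake)
    return err_total, err_real, err_wrong, err_fake
```

The mismatched-text term pushes the discriminator to judge image-text agreement as well as image realism, which gives the generator a learning signal tied to the description.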