Deep Convolutional GAN (DCGAN)
The deep convolutional adversarial pair learns a hierarchy of representations, from object parts to scenes, in both the generator and the discriminator. The DCGAN authors also use the learned features for novel tasks, demonstrating their applicability as general image representations.
DCGAN in PyTorch
Generator
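```python
import torch.nn as nn

# `opt` is the parsed command-line options of the training script
# (img_size, latent_dim, channels, ...).
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # Size of the first feature map; two 2x upsamplings restore img_size
        # (for img_size divisible by 4)
        self.init_size = opt.img_size // 4
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),  # 2nd positional argument is eps, i.e. eps=0.8
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),  # outputs in [-1, 1]
        )

    def forward(self, z):
        # Project the noise and reshape to (batch, 128, init_size, init_size)
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img
```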
Discriminator
```python
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )
        # The height and width of downsampled image
        ds_size = opt.img_size // 2 ** 4
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        return validity
```
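The discriminator sizes its final `Linear` layer with `ds_size = opt.img_size // 2 ** 4`, while the generator starts from `img_size // 4`. The plain-Python sketch below traces the spatial size through both networks to check those shortcuts. It uses the standard `Conv2d` output-size formula and assumes square inputs with kernel 3, padding 1 and dilation 1, as in the code above. The helper names are hypothetical.

```python
# Conv2d output size (dilation = 1):
#   out = floor((in + 2*padding - (kernel - 1) - 1) / stride) + 1

def conv_out(size: int, kernel: int = 3, stride: int = 1, padding: int = 1) -> int:
    """Spatial size after one Conv2d layer."""
    return (size + 2 * padding - (kernel - 1) - 1) // stride + 1


def generator_out_size(img_size: int) -> int:
    """Start at img_size // 4, then two (Upsample x2 -> stride-1 Conv2d) stages."""
    size = img_size // 4
    for _ in range(2):
        size = conv_out(size * 2)  # stride-1 conv keeps the upsampled size
    return conv_out(size)          # final Conv2d to opt.channels


def discriminator_out_size(img_size: int) -> int:
    """Four Conv2d(kernel=3, stride=2, padding=1) blocks."""
    size = img_size
    for _ in range(4):
        size = conv_out(size, stride=2)
    return size


for img_size in (32, 64, 28):
    actual = discriminator_out_size(img_size)
    assumed = img_size // 2 ** 4  # the ds_size used to build adv_layer
    print(img_size, generator_out_size(img_size), actual, assumed)
```

For 28×28 inputs the four stride-2 convolutions still produce 2×2 maps, but `ds_size` evaluates to 1. The flattened features would then have 512 entries instead of the 128 that `adv_layer` expects. With this formula, `img_size` should therefore be a multiple of 16 (e.g. 32 or 64).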
Training Process
```python
adversarial_loss = torch.nn.BCELoss()

# Inside the training loop: imgs / real_imgs hold a batch of real images, and
# valid / fake are target tensors filled with 1.0 and 0.0 (shape: batch x 1).

# -----------------
# Train Generator
# -----------------
optimizer_G.zero_grad()
# Sample noise as generator input
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
# Generate a batch of images
gen_imgs = generator(z)
# Loss measures generator's ability to fool the discriminator
g_loss = adversarial_loss(discriminator(gen_imgs), valid)
g_loss.backward()
optimizer_G.step()
# ---------------------
# Train Discriminator
# ---------------------
optimizer_D.zero_grad()
# Measure discriminator's ability to classify real from generated samples
real_loss = adversarial_loss(discriminator(real_imgs), valid)
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()
```
Least Squares GAN (LSGAN)
Idea & Design
The standard GAN uses a sigmoid cross-entropy loss for the discriminator to classify whether its input is real or fake. Once the discriminator classifies a generated sample as real, however, the generator has little reason to update it, even if the sample still lies far from the real data distribution. The sigmoid cross-entropy loss saturates and barely pushes such samples toward the real data, because its classification goal has already been achieved. Motivated by this, least squares GAN (LSGAN) replaces the sigmoid cross-entropy loss with a least squares loss. This loss directly penalizes fake samples that lie far from the decision boundary and thereby moves them closer to the real data distribution. LSGAN solves the following optimization problems:

$$
\begin{aligned}
\min_D V_{\text{LSGAN}}(D) &= \tfrac{1}{2}\,\mathbb{E}_{x \sim p_{\text{data}}(x)}\big[(D(x) - b)^2\big] + \tfrac{1}{2}\,\mathbb{E}_{z \sim p_z(z)}\big[(D(G(z)) - a)^2\big] \\
\min_G V_{\text{LSGAN}}(G) &= \tfrac{1}{2}\,\mathbb{E}_{z \sim p_z(z)}\big[(D(G(z)) - c)^2\big]
\end{aligned}
$$

where a and b are the target values the discriminator should output for generated and real samples, respectively, and c is the value the generator wants the discriminator to output for generated samples.

These equations use a least squares loss, under which the discriminator is forced to output designated values (b for real samples, a for generated samples) rather than a probability of the input being real or fake. In contrast to a sigmoid cross-entropy loss, a least squares loss therefore not only classifies real and generated samples but also pushes generated samples closer to the real data distribution.
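The plain-Python sketch below makes the saturation argument concrete. It compares the gradient of each generator loss with respect to the discriminator's raw score s for a generated sample that D already scores as real. It assumes the common non-saturating generator loss −log σ(s) for the standard GAN, and the LSGAN generator loss ½(s − c)² with c = 1. The function names are illustrative.

```python
import math

def sigmoid(s: float) -> float:
    return 1.0 / (1.0 + math.exp(-s))

def bce_generator_grad(s: float) -> float:
    """d/ds of -log(sigmoid(s)): standard GAN generator loss on a logit s."""
    return sigmoid(s) - 1.0

def ls_generator_grad(s: float, c: float = 1.0) -> float:
    """d/ds of 0.5 * (s - c)**2: LSGAN generator loss on a raw output s."""
    return s - c

# Larger s means D is more confident that the generated sample is real
for s in (0.0, 2.0, 5.0, 10.0):
    print(f"s={s:5.1f}  BCE grad={bce_generator_grad(s):+.5f}  LS grad={ls_generator_grad(s):+.1f}")
```

As s grows, the BCE gradient vanishes. The least squares gradient instead grows with the distance from c, so samples far beyond the decision boundary still receive a training signal.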
LSGAN Loss Function in PyTorch
LSGAN minimizes the mean squared error (MSE) instead of the binary cross-entropy (BCE):
```python
adversarial_loss = torch.nn.MSELoss()
real_loss = adversarial_loss(discriminator(real_imgs), valid)
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
d_loss = 0.5 * (real_loss + fake_loss)
```
Conditional GAN
Idea & Design
In the original GAN, we have no control over what is generated, since the output depends only on random noise. However, we can add a conditional input c to the random noise z so that the generated image is defined by G(c, z). Typically, the conditional input vector c is concatenated with the noise vector z, and the resulting vector is fed into the generator as in the original GAN. We can also apply other data augmentation to c and z. The meaning of the conditional input c is arbitrary: for example, it can be the class of the image, attributes of an object, or an embedding of a text description of the image we want to generate.
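Here is a minimal plain-Python sketch of the concatenation described above. It assumes c is a one-hot class label and z is standard Gaussian noise; `one_hot` and `conditional_input` are hypothetical helpers. The PyTorch generator below uses a learned `nn.Embedding` of the same width (`n_classes`) in place of the one-hot vector, so its first `Linear` layer likewise takes `latent_dim + n_classes` inputs.

```python
import random

def one_hot(label: int, n_classes: int) -> list[float]:
    """Encode a class label c as a one-hot vector of length n_classes."""
    vec = [0.0] * n_classes
    vec[label] = 1.0
    return vec

def conditional_input(label: int, n_classes: int, latent_dim: int,
                      rng: random.Random) -> list[float]:
    """Build the generator input [c, z]: one-hot condition, then Gaussian noise."""
    z = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
    return one_hot(label, n_classes) + z

rng = random.Random(0)
x = conditional_input(label=3, n_classes=10, latent_dim=100, rng=rng)
print(len(x))   # 110 = n_classes + latent_dim
print(x[:10])   # the one-hot part, with 1.0 at index 3
```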
Implementation
CGAN Generator with Label Embedding in PyTorch
```python
# img_shape = (opt.channels, opt.img_size, opt.img_size), defined in the training script
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # One learned n_classes-dimensional vector per class label
        self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim + opt.n_classes, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        # Concatenate label embedding and noise to produce input
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        img = self.model(gen_input)
        img = img.view(img.size(0), *img_shape)
        return img
```
CGAN Discriminator with Label Embedding in PyTorch
```python
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)
        self.model = nn.Sequential(
            nn.Linear(opt.n_classes + int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1),
        )

    def forward(self, img, labels):
        # Concatenate label embedding and image to produce input
        d_in = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels)), -1)
        validity = self.model(d_in)
        return validity
```
InfoGAN
Idea & Design
InfoGAN decomposes the input noise vector into a standard incompressible latent vector z and an additional latent code c that captures salient semantic features of real samples. InfoGAN then maximizes the mutual information between c and a generated sample $G(z, c)$ so that c captures noticeable features of the real data. In other words, the generator takes the concatenated input $(z, c)$ and maximizes the mutual information $I(c; G(z, c))$ between a given latent code c and the generated samples $G(z, c)$ to learn meaningful feature representations. However, evaluating $I(c; G(z, c))$ directly requires the posterior $p(c \mid x)$, which is intractable. InfoGAN therefore takes a variational approach: instead of maximizing $I(c; G(z, c))$ itself, it maximizes a lower bound on it.
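For reference, the InfoGAN paper's variational bound replaces the intractable posterior $p(c \mid x)$ with an auxiliary distribution $Q(c \mid x)$, which in practice shares layers with the discriminator. Here $V(D, G)$ is the standard GAN objective and $\lambda$ is a weighting hyperparameter:

$$
\begin{aligned}
I(c; G(z, c)) &\ge \mathbb{E}_{c \sim p(c),\, x \sim G(z, c)}\big[\log Q(c \mid x)\big] + H(c) = L_I(G, Q) \\
\min_{G, Q} \max_D V_{\text{InfoGAN}}(D, G, Q) &= V(D, G) - \lambda L_I(G, Q)
\end{aligned}
$$

Because $H(c)$ is fixed by the chosen prior, maximizing $L_I$ amounts to minimizing the negative log-likelihood of the sampled code under $Q$. In the code below this becomes a cross-entropy term for the categorical code and an MSE term for the continuous code. Up to constants, the MSE term corresponds to a Gaussian $Q$ with fixed variance.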
Both CGAN and InfoGAN learn the conditional distribution $p(x \mid c)$ given a condition vector c, but they differ in how they handle c. In CGAN, c is assumed to be semantically known (such as class labels), so it must be provided to both the generator and the discriminator during training. In InfoGAN, c is assumed to be unknown: it is sampled from a prior distribution $p(c)$, and the generating process is controlled through $I(c; G(z, c))$. As a result, the automatically inferred c in InfoGAN has much more freedom to capture features of the real data than c in CGAN, which is restricted to known information.
InfoGAN in PyTorch
InfoGAN Generator: input with class label and latent code
```python
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        input_dim = opt.latent_dim + opt.n_classes + opt.code_dim
        self.init_size = opt.img_size // 4  # Initial size before upsampling
        self.l1 = nn.Sequential(nn.Linear(input_dim, 128 * self.init_size ** 2))
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise, labels, code):
        gen_input = torch.cat((noise, labels, code), -1)
        out = self.l1(gen_input)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img
```
InfoGAN Discriminator: output real/fake, label and latent code
```python
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Returns layers of each discriminator block"""
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.conv_blocks = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )
        # The height and width of downsampled image
        ds_size = opt.img_size // 2 ** 4
        # Output layers
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1))
        # Raw class logits: nn.CrossEntropyLoss applies log-softmax internally,
        # so no nn.Softmax() here (apply softmax only when probabilities are needed)
        self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.n_classes))
        self.latent_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.code_dim))

    def forward(self, img):
        out = self.conv_blocks(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        label = self.aux_layer(out)
        latent_code = self.latent_layer(out)
        return validity, label, latent_code
```
Loss and Training of InfoGAN
Loss Function
```python
# Loss functions
adversarial_loss = torch.nn.MSELoss()
categorical_loss = torch.nn.CrossEntropyLoss()
continuous_loss = torch.nn.MSELoss()
```
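`torch.nn.CrossEntropyLoss` applies log-softmax to its input internally, so it expects raw logits. The plain-Python sketch below (function names are illustrative) shows what happens when a classification head already ends in a softmax. The probabilities are squashed a second time, and even a perfectly confident, correct prediction keeps a large loss.

```python
import math

def softmax(xs: list[float]) -> list[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits: list[float], target: int) -> float:
    """Per-sample value computed by CrossEntropyLoss: -log softmax(logits)[target]."""
    return -math.log(softmax(logits)[target])

logits = [10.0, 0.0, 0.0]  # confident and correct prediction for class 0
print(cross_entropy(logits, 0))           # close to 0
print(cross_entropy(softmax(logits), 0))  # softmax applied twice: about 0.55
```

With K classes, the double-softmax loss can never drop below log(1 + (K − 1)/e), about 0.55 for K = 3. A classification head paired with `CrossEntropyLoss` should therefore output logits.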
Generator and Discriminator Training
```python
# -----------------
# Train Generator
# -----------------
optimizer_G.zero_grad()
# Sample noise and labels as generator input
z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
label_input = to_categorical(np.random.randint(0, opt.n_classes, batch_size), num_columns=opt.n_classes)
code_input = Variable(FloatTensor(np.random.uniform(-1, 1, (batch_size, opt.code_dim))))
# Generate a batch of images
gen_imgs = generator(z, label_input, code_input)
# Loss measures generator's ability to fool the discriminator
validity, _, _ = discriminator(gen_imgs)
g_loss = adversarial_loss(validity, valid)
g_loss.backward()
optimizer_G.step()
# ---------------------
# Train Discriminator
# ---------------------
optimizer_D.zero_grad()
# Loss for real images
real_pred, _, _ = discriminator(real_imgs)
d_real_loss = adversarial_loss(real_pred, valid)
# Loss for fake images
fake_pred, _, _ = discriminator(gen_imgs.detach())
d_fake_loss = adversarial_loss(fake_pred, fake)
# Total discriminator loss
d_loss = (d_real_loss + d_fake_loss) / 2
d_loss.backward()
optimizer_D.step()
```
Information Loss in InfoGAN
```python
# ------------------
# Information Loss
# ------------------
optimizer_info.zero_grad()
# Sample labels
sampled_labels = np.random.randint(0, opt.n_classes, batch_size)
# Ground truth labels
gt_labels = Variable(LongTensor(sampled_labels), requires_grad=False)
# Sample noise, labels and code as generator input
z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
label_input = to_categorical(sampled_labels, num_columns=opt.n_classes)
code_input = Variable(FloatTensor(np.random.uniform(-1, 1, (batch_size, opt.code_dim))))
gen_imgs = generator(z, label_input, code_input)
_, pred_label, pred_code = discriminator(gen_imgs)
info_loss = lambda_cat * categorical_loss(pred_label, gt_labels) + lambda_con * continuous_loss(
    pred_code, code_input
)
info_loss.backward()
optimizer_info.step()
```
Auxiliary Classifier GAN (AC-GAN)
Idea & Design
To feed in more side information and to allow for semi-supervised learning, one can add a task-specific auxiliary classifier to the discriminator, so that the model is optimized on the original task as well as the additional one. The architecture of such a method is illustrated in the figure below, where C is the auxiliary classifier.
Adding auxiliary classifiers allows us to use pre-trained models (e.g. image classifiers trained on ImageNet). Experiments in the AC-GAN paper show that this approach can help generate sharper images and alleviate mode collapse. Auxiliary classifiers can also help in applications such as text-to-image synthesis and image-to-image translation.
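For reference, the AC-GAN paper writes the objective as two log-likelihood terms: one for the source S (real or fake) and one for the class label C.

$$
\begin{aligned}
L_S &= \mathbb{E}\big[\log P(S = \text{real} \mid X_{\text{real}})\big] + \mathbb{E}\big[\log P(S = \text{fake} \mid X_{\text{fake}})\big] \\
L_C &= \mathbb{E}\big[\log P(C = c \mid X_{\text{real}})\big] + \mathbb{E}\big[\log P(C = c \mid X_{\text{fake}})\big]
\end{aligned}
$$

The discriminator is trained to maximize $L_S + L_C$, and the generator is trained to maximize $L_C - L_S$. The PyTorch code below expresses these as losses to minimize: an adversarial loss for the source and an auxiliary classification loss for the class.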
Implementation
AC-GAN Generator in PyTorch
```python
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.label_emb = nn.Embedding(opt.n_classes, opt.latent_dim)
        self.init_size = opt.img_size // 4  # Initial size before upsampling
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        # Condition by multiplying the noise element-wise with a learned
        # latent_dim-wide label embedding (instead of concatenating)
        gen_input = torch.mul(self.label_emb(labels), noise)
        out = self.l1(gen_input)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img
```
AC-GAN Discriminator in PyTorch
```python
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Returns layers of each discriminator block"""
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.conv_blocks = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )
        # The height and width of downsampled image
        ds_size = opt.img_size // 2 ** 4
        # Output layers
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())
        # Raw class logits for a cross-entropy auxiliary loss (which applies
        # log-softmax internally), so no nn.Softmax() here
        self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.n_classes))

    def forward(self, img):
        out = self.conv_blocks(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        label = self.aux_layer(out)
        return validity, label
```
Discriminator Loss
```python
optimizer_D.zero_grad()
# Loss for real images
real_pred, real_aux = discriminator(real_imgs)
d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2
# Loss for fake images
fake_pred, fake_aux = discriminator(gen_imgs.detach())
d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels)) / 2
# Total discriminator loss
d_loss = (d_real_loss + d_fake_loss) / 2
d_loss.backward()
optimizer_D.step()
```
Boundary Equilibrium Generative Adversarial Networks (BEGAN)
Idea & Design
Boundary equilibrium GAN (BEGAN) uses an autoencoder as the discriminator. It assumes that the pixel-wise reconstruction losses are approximately normally distributed, an assumption motivated by the central limit theorem. Rather than matching data distributions directly, it matches the distributions of the autoencoder losses of real and generated samples by optimizing a lower bound on the Wasserstein distance between them. In BEGAN, the discriminator has two roles: to reconstruct real samples well, and to balance the generator and the discriminator via an equilibrium hyperparameter $\gamma = \mathbb{E}[L(G(z))] / \mathbb{E}[L(x)]$. $\gamma$ enters the objective function to prevent the discriminator from easily winning over the generator, which balances the power of the two components.
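The objective function of BEGAN is given below. Here $L(v) = \lvert v - D(v) \rvert^{\eta}$ is the autoencoder reconstruction loss ($\eta = 1$ in the code below), and $k_t \in [0, 1]$ controls how much weight the discriminator puts on generated samples:

$$
\begin{aligned}
\mathcal{L}_D &= L(x) - k_t\, L(G(z_D)) \\
\mathcal{L}_G &= L(G(z_G)) \\
k_{t+1} &= k_t + \lambda_k \big(\gamma L(x) - L(G(z_G))\big)
\end{aligned}
$$

Training progress can be tracked with the convergence measure $\mathcal{M}_{\text{global}} = L(x) + \lvert \gamma L(x) - L(G(z_G)) \rvert$.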
BEGAN Objective Function in PyTorch
```python
# k is initialized once before the training loop (e.g. k = 0.0).
# The discriminator is an autoencoder: discriminator(x) returns a reconstruction of x.

# -----------------
# Train Generator
# -----------------
optimizer_G.zero_grad()
# Sample noise as generator input
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
# Generate a batch of images
gen_imgs = generator(z)
# Generator loss L(G(z)): L1 reconstruction error of the generated images
g_loss = torch.mean(torch.abs(discriminator(gen_imgs) - gen_imgs))
g_loss.backward()
optimizer_G.step()
# ---------------------
# Train Discriminator
# ---------------------
optimizer_D.zero_grad()
# Reconstruction losses on real and generated images
d_real = discriminator(real_imgs)
d_fake = discriminator(gen_imgs.detach())
d_loss_real = torch.mean(torch.abs(d_real - real_imgs))
d_loss_fake = torch.mean(torch.abs(d_fake - gen_imgs.detach()))
d_loss = d_loss_real - k * d_loss_fake
d_loss.backward()
optimizer_D.step()
# ----------------
# Update weights
# ----------------
diff = torch.mean(gamma * d_loss_real - d_loss_fake)
# Update weight term for fake samples
k = k + lambda_k * diff.item()
k = min(max(k, 0), 1)  # Constraint to interval [0, 1]
# Update convergence metric (.item() converts the 0-dim tensor to a float)
M = (d_loss_real + torch.abs(diff)).item()
```
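The balance update at the end of the loop can be traced without PyTorch. This plain-Python sketch mirrors the code above on scalar losses. `began_step` is a hypothetical helper, and its defaults `gamma = 0.75` and `lambda_k = 0.001` are illustrative values.

```python
def began_step(k: float, loss_real: float, loss_fake: float,
               gamma: float = 0.75, lambda_k: float = 0.001):
    """One BEGAN balance update (sketch).

    loss_real: reconstruction loss L(x) on real images
    loss_fake: reconstruction loss L(G(z)) on generated images
    Returns (new k, convergence measure M_global).
    """
    # > 0 when generated images are reconstructed too well relative to gamma
    diff = gamma * loss_real - loss_fake
    k_new = min(max(k + lambda_k * diff, 0.0), 1.0)  # keep k in [0, 1]
    m_global = loss_real + abs(diff)
    return k_new, m_global


k = 0.0
k, m = began_step(k, loss_real=0.20, loss_fake=0.05)  # diff > 0: k increases
print(k, m)
k, m = began_step(k, loss_real=0.20, loss_fake=0.50)  # diff < 0: k decreases, clipped at 0
print(k, m)
```

When generated images are reconstructed too well relative to the target ratio ($L(G(z)) < \gamma L(x)$), k grows and the discriminator puts more weight on the fake term of its loss. In the opposite case k shrinks, always staying within [0, 1].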
Coupled GAN (CoGAN)
Unlike existing approaches, which require tuples of corresponding images from different domains in the training set, CoGAN learns a joint distribution of multi-domain images without any tuples of corresponding images. It only needs samples drawn from the marginal distributions. This is achieved by enforcing a weight-sharing constraint, which limits the network capacity and favors a joint-distribution solution over a product-of-marginals solution.
Coupled GAN in PyTorch
Coupled Generator
```python
class CoupledGenerators(nn.Module):
    def __init__(self):
        super(CoupledGenerators, self).__init__()
        self.init_size = opt.img_size // 4
        self.fc = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))
        self.shared_conv = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
        )
        self.G1 = nn.Sequential(
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )
        self.G2 = nn.Sequential(
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise):
        out = self.fc(noise)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img_emb = self.shared_conv(out)
        img1 = self.G1(img_emb)
        img2 = self.G2(img_emb)
        return img1, img2
```
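This plain-Python sketch puts a number on how weight sharing limits capacity. It counts the trainable parameters of the `CoupledGenerators` above and compares them with two independent generators of the same shape. It assumes `latent_dim = 100`, `img_size = 32` and `channels = 1` (illustrative values), and it counts only weights and biases, not BatchNorm running statistics.

```python
def linear_params(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out              # weight + bias

def conv_params(c_in: int, c_out: int, k: int = 3) -> int:
    return c_in * c_out * k * k + c_out      # weight + bias

def bn_params(c: int) -> int:
    return 2 * c                             # affine weight + bias

def coupled_generator_params(latent_dim=100, img_size=32, channels=1):
    init = img_size // 4
    # fc + shared_conv: used by both domains
    shared = (linear_params(latent_dim, 128 * init ** 2)
              + bn_params(128) + conv_params(128, 128) + bn_params(128))
    # G1 or G2: one domain-specific head
    branch = conv_params(128, 64) + bn_params(64) + conv_params(64, channels)
    return shared, branch

shared, branch = coupled_generator_params()
coupled = shared + 2 * branch     # one shared trunk, two heads
separate = 2 * (shared + branch)  # two independent generators
print(shared, branch, coupled, separate)
```

The whole difference comes from the shared trunk (`fc` and `shared_conv`). That is exactly where the two domains are forced to use the same high-level representation.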
Coupled GAN Generator Loss
```python
optimizer_G.zero_grad()
# Sample noise as generator input
z = Variable(Tensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
# Generate a batch of images
gen_imgs1, gen_imgs2 = coupled_generators(z)
# Determine validity of generated images
validity1, validity2 = coupled_discriminators(gen_imgs1, gen_imgs2)
g_loss = (adversarial_loss(validity1, valid) + adversarial_loss(validity2, valid)) / 2
g_loss.backward()
optimizer_G.step()
```
Coupled Discriminator
```python
class CoupledDiscriminators(nn.Module):
    def __init__(self):
        super(CoupledDiscriminators, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            block.extend([nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)])
            return block

        self.shared_conv = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )
        # The height and width of downsampled image
        ds_size = opt.img_size // 2 ** 4
        self.D1 = nn.Linear(128 * ds_size ** 2, 1)
        self.D2 = nn.Linear(128 * ds_size ** 2, 1)

    def forward(self, img1, img2):
        # Determine validity of first image
        out = self.shared_conv(img1)
        out = out.view(out.shape[0], -1)
        validity1 = self.D1(out)
        # Determine validity of second image
        out = self.shared_conv(img2)
        out = out.view(out.shape[0], -1)
        validity2 = self.D2(out)
        return validity1, validity2
```
Coupled GAN Discriminator Loss
```python
optimizer_D.zero_grad()
# Determine validity of real and generated images
validity1_real, validity2_real = coupled_discriminators(imgs1, imgs2)
validity1_fake, validity2_fake = coupled_discriminators(gen_imgs1.detach(), gen_imgs2.detach())
d_loss = (
    adversarial_loss(validity1_real, valid)
    + adversarial_loss(validity1_fake, fake)
    + adversarial_loss(validity2_real, valid)
    + adversarial_loss(validity2_fake, fake)
) / 4
d_loss.backward()
optimizer_D.step()
```
Wasserstein GAN (WGAN)
Idea & Design
Compared to the original GAN algorithm, WGAN makes the following changes:
- After every gradient update on the critic function, clamp the weights to a small fixed range, $[-c, c]$.
- Use a new loss function derived from the Wasserstein distance, with no logarithm anymore. The “discriminator” model no longer acts as a direct real/fake classifier but as a helper for estimating the Wasserstein metric between the real and generated data distributions.
- Empirically, the authors recommended the RMSProp optimizer for the critic rather than a momentum-based optimizer such as Adam, which could cause instability in model training. I haven’t seen a clear theoretical explanation for this point, though.
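The loss in the second point comes from the Kantorovich-Rubinstein duality. The Wasserstein-1 distance is a supremum over 1-Lipschitz functions $f$, and the critic plays the role of $f$. Weight clipping is a crude way to keep it Lipschitz: it only bounds the critic to some K-Lipschitz family, so the critic effectively estimates $K \cdot W$.

$$
W(\mathbb{P}_r, \mathbb{P}_g) = \sup_{\lVert f \rVert_L \le 1} \; \mathbb{E}_{x \sim \mathbb{P}_r}[f(x)] - \mathbb{E}_{x \sim \mathbb{P}_g}[f(x)]
$$

The critic maximizes this difference, so its loss is the negative, $-\mathbb{E}_{x \sim \mathbb{P}_r}[f(x)] + \mathbb{E}_{z}[f(G(z))]$. The generator minimizes $-\mathbb{E}_{z}[f(G(z))]$. These match `loss_D` and `loss_G` in the code below.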
WGAN in PyTorch
Training Process
```python
# Inside the loop over batches (i is the batch index)
optimizer_D.zero_grad()
# Sample noise as generator input
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
# Generate a batch of images (detached: this step only updates the critic)
fake_imgs = generator(z).detach()
# Critic loss: negative of the estimated Wasserstein distance
loss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))
loss_D.backward()
optimizer_D.step()
# Clip weights of discriminator to [-clip_value, clip_value]
for p in discriminator.parameters():
    p.data.clamp_(-opt.clip_value, opt.clip_value)
# Train the generator every n_critic iterations
if i % opt.n_critic == 0:
    optimizer_G.zero_grad()
    # Generate a batch of images
    gen_imgs = generator(z)
    # Adversarial loss
    loss_G = -torch.mean(discriminator(gen_imgs))
    loss_G.backward()
    optimizer_G.step()
```
Wasserstein GAN with Gradient Penalty (WGAN-GP)
Idea & Design
Wasserstein GAN (WGAN) makes progress toward stable training of GANs, but it can still sometimes generate only low-quality samples or fail to converge. The WGAN-GP authors found that these problems are often due to the use of weight clipping in WGAN to enforce a Lipschitz constraint on the critic, which can lead to undesired behavior. They propose an alternative to clipping weights: penalize the norm of the gradient of the critic with respect to its input.
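Written out, the WGAN-GP critic loss adds this penalty to the WGAN critic loss. The penalty is evaluated at random points $\hat{x}$ on straight lines between real samples $x$ and generated samples $\tilde{x}$:

$$
L = \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}[D(\tilde{x})] - \mathbb{E}_{x \sim \mathbb{P}_r}[D(x)] + \lambda\, \mathbb{E}_{\hat{x} \sim \mathbb{P}_{\hat{x}}}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big],
\qquad \hat{x} = \alpha x + (1 - \alpha)\tilde{x},\ \ \alpha \sim U[0, 1]
$$

The paper uses $\lambda = 10$; in the code below, this weight is `lambda_gp`.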
WGAN-GP in PyTorch
Training Process
```python
# Inside the loop over batches (i is the batch index)
optimizer_D.zero_grad()
# Sample noise as generator input
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
# Generate a batch of images
fake_imgs = generator(z)
# Real images
real_validity = discriminator(real_imgs)
# Fake images
fake_validity = discriminator(fake_imgs)
# Gradient penalty
gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data)
# Adversarial loss
d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
d_loss.backward()
optimizer_D.step()
# Clear the gradients that d_loss.backward() also put into the generator
# (fake_imgs was not detached above)
optimizer_G.zero_grad()
# Train the generator every n_critic steps
if i % opt.n_critic == 0:
    # Generate a batch of images
    fake_imgs = generator(z)
    # Loss measures generator's ability to fool the discriminator
    # Train on fake images
    fake_validity = discriminator(fake_imgs)
    g_loss = -torch.mean(fake_validity)
    g_loss.backward()
    optimizer_G.step()
```
Gradient Penalty in PyTorch
```python
import torch


def compute_gradient_penalty(D, real_samples, fake_samples):
    """Calculates the gradient penalty loss for WGAN GP"""
    # Random weight term for interpolation between real and fake samples
    # (one alpha per sample, broadcast over the (C, H, W) image dimensions)
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=real_samples.device)
    # Get random interpolation between real and fake samples
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    # Seed the backward pass with ones, one per critic output
    grad_outputs = torch.ones_like(d_interpolates)
    # Get gradient w.r.t. interpolates; create_graph=True keeps the penalty differentiable
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty
```
Code example from PyTorch-GAN.
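To see what `compute_gradient_penalty` measures without autograd, this plain-Python sketch works through a toy case. It interpolates between paired real and fake points, estimates the critic's gradient at each interpolate with central finite differences, and averages $(\lVert \nabla D(\hat{x}) \rVert_2 - 1)^2$. The toy critic and helper names are hypothetical. The toy critic $D(x) = \tfrac{1}{2}\lVert x \rVert^2$ has gradient $x$, so the result can be checked by hand.

```python
import math
import random

def critic(x: list[float]) -> float:
    """Toy critic D(x) = 0.5 * ||x||^2, whose exact gradient is x itself."""
    return 0.5 * sum(v * v for v in x)

def numerical_grad(f, x: list[float], eps: float = 1e-5) -> list[float]:
    """Central finite-difference gradient of a scalar function f at x."""
    grad = []
    for i in range(len(x)):
        hi = x[:i] + [x[i] + eps] + x[i + 1:]
        lo = x[:i] + [x[i] - eps] + x[i + 1:]
        grad.append((f(hi) - f(lo)) / (2 * eps))
    return grad

def gradient_penalty(f, real: list[list[float]], fake: list[list[float]],
                     rng: random.Random) -> float:
    """Mean of (||grad f(x_hat)||_2 - 1)^2 over random interpolates x_hat."""
    total = 0.0
    for r, g in zip(real, fake):
        alpha = rng.random()  # one alpha per sample pair
        x_hat = [alpha * a + (1 - alpha) * b for a, b in zip(r, g)]
        norm = math.sqrt(sum(v * v for v in numerical_grad(f, x_hat)))
        total += (norm - 1.0) ** 2
    return total / len(real)

real = [[1.0, 0.0], [0.0, 2.0]]
fake = [[0.0, 1.0], [2.0, 0.0]]
gp = gradient_penalty(critic, real, fake, random.Random(0))
print(gp)
```

A critic whose gradient norm is exactly 1 everywhere, such as a linear function with unit-norm weights, gets a penalty of zero. The quadratic toy critic is penalized wherever $\lVert \hat{x} \rVert \ne 1$.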