U-GAT-IT: PyTorch Implementation


Info

U-GAT-IT: Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation

Junho Kim (NCSOFT), Minjae Kim (NCSOFT), Hyeonwoo Kang (NCSOFT), Kwanghee Lee (Boeing Korea)

Abstract: We propose a novel method for unsupervised image-to-image translation, which incorporates a new attention module and a new learnable normalization function in an end-to-end manner. The attention module guides our model to focus on more important regions distinguishing between source and target domains based on the attention map obtained by the auxiliary classifier. Unlike previous attention-based methods which cannot handle the geometric changes between domains, our model can translate both images requiring holistic changes and images requiring large shape changes. Moreover, our new AdaLIN (Adaptive Layer-Instance Normalization) function helps our attention-guided model to flexibly control the amount of change in shape and texture by learned parameters depending on datasets. Experimental results show the superiority of the proposed method compared to the existing state-of-the-art models with a fixed network architecture and hyper-parameters.

Usage

├── dataset
│   └── YOUR_DATASET_NAME
│       ├── trainA
│       │   ├── xxx.jpg (name and format don't matter)
│       │   ├── yyy.png
│       │   └── ...
│       ├── trainB
│       │   ├── zzz.jpg
│       │   ├── www.png
│       │   └── ...
│       ├── testA
│       │   ├── aaa.jpg
│       │   ├── bbb.png
│       │   └── ...
│       └── testB
│           ├── ccc.jpg
│           ├── ddd.png
│           └── ...

Train

> python main.py --dataset selfie2anime
  • If GPU memory is insufficient, set --light to True (see the example below)
    • This may reduce translation quality
    • The paper configuration uses --light False
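
For example, to train the lower-memory variant:

> python main.py --dataset selfie2anime --light True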

Test

> python main.py --dataset selfie2anime --phase test

Core Design

Attention Module based on CAM (Class Activation Map)

Class Activation Map:

def __init__(self, ...):
    ...
    # One-unit FC layers over globally pooled features act as an auxiliary classifier
    self.gap_fc = nn.Linear(ngf * mult, 1, bias=False)
    self.gmp_fc = nn.Linear(ngf * mult, 1, bias=False)
    # 1x1 convolution fuses the two attended feature maps back to ngf * mult channels
    self.conv1x1 = nn.Conv2d(ngf * mult * 2, ngf * mult, kernel_size=1, stride=1, bias=True)
    self.relu = nn.ReLU(True)

The weights of these fully connected layers are reused in the forward pass as per-channel importance scores:

def forward(self, input):
    x = self.DownBlock(input)

    # Global average pooling branch: the auxiliary classifier produces a domain
    # logit, and its FC weights re-weight the feature map channel by channel
    gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)
    gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
    gap_weight = list(self.gap_fc.parameters())[0]
    gap = x * gap_weight.unsqueeze(2).unsqueeze(3)

    # Global max pooling branch: same idea with max-pooled features
    gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
    gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
    gmp_weight = list(self.gmp_fc.parameters())[0]
    gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)

    # Concatenate both attended maps and fuse them with the 1x1 convolution
    cam_logit = torch.cat([gap_logit, gmp_logit], 1)
    x = torch.cat([gap, gmp], 1)
    x = self.relu(self.conv1x1(x))
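
In the referenced repository, the generator also returns an attention heatmap derived from these attended features; a minimal sketch of that step (continuing the forward pass above):

    # Collapse the attended features across channels to get a spatial attention map
    heatmap = torch.sum(x, dim=1, keepdim=True)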

AdaLIN (Adaptive Layer-Instance Normalization)

$$\mathrm{AdaLIN}(a,\gamma,\beta)=\gamma\cdot\left(\rho\cdot\hat{a}_I+(1-\rho)\cdot\hat{a}_L\right)+\beta$$

$$\hat{a}_I=\frac{a-\mu_I}{\sqrt{\sigma_I^2+\epsilon}},\qquad\hat{a}_L=\frac{a-\mu_L}{\sqrt{\sigma_L^2+\epsilon}}$$

$$\rho\leftarrow\mathrm{clip}_{[0,1]}(\rho-\tau\Delta\rho)$$

where μ_I, μ_L and σ_I, σ_L are the channel-wise and layer-wise mean and standard deviation respectively, γ and β are parameters generated by the fully connected sub-network, τ is the learning rate, and Δρ is the parameter update vector (e.g., the gradient) determined by the optimizer. The values of ρ are constrained to [0, 1] simply by clipping at the parameter-update step. The generator learns to push ρ toward 1 for tasks where instance normalization is more important, and toward 0 for tasks where layer normalization is more important. ρ is initialized to 1 in the residual blocks of the decoder and to 0 in its up-sampling blocks.
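
This clipping is applied after every optimizer step via self.Rho_clipper in the training loop below; a minimal sketch of such a clipper:

class RhoClipper(object):
    """Clamp the rho parameter of every AdaILN / ILN module into [clip_min, clip_max]."""

    def __init__(self, clip_min, clip_max):
        self.clip_min = clip_min
        self.clip_max = clip_max

    def __call__(self, module):
        # Used as net.apply(RhoClipper(0, 1)), so it visits every sub-module
        if hasattr(module, 'rho'):
            module.rho.data = module.rho.data.clamp(self.clip_min, self.clip_max)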

γ and β are learned by a fully connected sub-network:

if self.light:
    # Light version: the feature map is globally average-pooled before this
    # MLP, so its input width is only ngf * mult
    FC = [nn.Linear(ngf * mult, ngf * mult, bias=False),
          nn.ReLU(True),
          nn.Linear(ngf * mult, ngf * mult, bias=False),
          nn.ReLU(True)]
else:
    # Full (paper) version: the entire flattened feature map is fed in,
    # which costs far more memory
    FC = [nn.Linear(img_size // mult * img_size // mult * ngf * mult, ngf * mult, bias=False),
          nn.ReLU(True),
          nn.Linear(ngf * mult, ngf * mult, bias=False),
          nn.ReLU(True)]
self.gamma = nn.Linear(ngf * mult, ngf * mult, bias=False)
self.beta = nn.Linear(ngf * mult, ngf * mult, bias=False)
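
In the generator's forward pass, this sub-network consumes the encoder features and emits one (γ, β) pair shared by all AdaLIN layers in the decoder. A sketch of that step, assuming the list above is wrapped as self.FC = nn.Sequential(*FC):

if self.light:
    # Pool first so the FC input is only ngf * mult wide
    x_ = torch.nn.functional.adaptive_avg_pool2d(x, 1)
    x_ = self.FC(x_.view(x_.shape[0], -1))
else:
    x_ = self.FC(x.view(x.shape[0], -1))
gamma, beta = self.gamma(x_), self.beta(x_)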

The AdaLIN module (adaILN in the code):

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter


class adaILN(nn.Module):
    def __init__(self, num_features, eps=1e-5):
        super(adaILN, self).__init__()
        self.eps = eps
        # rho blends instance and layer statistics; it is clipped to [0, 1] during training
        self.rho = Parameter(torch.Tensor(1, num_features, 1, 1))
        self.rho.data.fill_(0.9)

    def forward(self, input, gamma, beta):
        # Instance normalization statistics: per sample, per channel
        in_mean, in_var = torch.mean(input, dim=[2, 3], keepdim=True), torch.var(input, dim=[2, 3], keepdim=True)
        out_in = (input - in_mean) / torch.sqrt(in_var + self.eps)
        # Layer normalization statistics: per sample, across all channels
        ln_mean, ln_var = torch.mean(input, dim=[1, 2, 3], keepdim=True), torch.var(input, dim=[1, 2, 3], keepdim=True)
        out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps)
        # Learned interpolation between the two normalizations
        out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (1 - self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln
        # Apply the externally generated affine parameters
        out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)

        return out
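
A quick shape check (the sizes here are illustrative, not from the repository):

norm = adaILN(256)
feat = torch.randn(4, 256, 64, 64)                       # N, C, H, W
gamma, beta = torch.randn(4, 256), torch.randn(4, 256)   # one affine pair per sample
out = norm(feat, gamma, beta)                            # same shape as feat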

Training Process

for step in range(start_iter, self.iteration + 1):
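    # Linearly decay both learning rates to zero over the second half of training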
    if self.decay_flag and step > (self.iteration // 2):
        self.G_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2))
        self.D_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2))

    try:
        real_A, _ = next(trainA_iter)
    except StopIteration:
        trainA_iter = iter(self.trainA_loader)
        real_A, _ = next(trainA_iter)

    try:
        real_B, _ = next(trainB_iter)
    except StopIteration:
        trainB_iter = iter(self.trainB_loader)
        real_B, _ = next(trainB_iter)

    real_A, real_B = real_A.to(self.device), real_B.to(self.device)

    # Update D
    self.D_optim.zero_grad()

    fake_A2B, _, _ = self.genA2B(real_A)
    fake_B2A, _, _ = self.genB2A(real_B)

    real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
    real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
    real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
    real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

    fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
    fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
    fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
    fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

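    # LSGAN-style (least-squares) adversarial losses for the global (disG*) and
    # local (disL*) discriminators, each with an additional CAM attention branch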
    D_ad_loss_GA = self.MSE_loss(real_GA_logit, torch.ones_like(real_GA_logit).to(self.device)) + self.MSE_loss(fake_GA_logit, torch.zeros_like(fake_GA_logit).to(self.device))
    D_ad_cam_loss_GA = self.MSE_loss(real_GA_cam_logit, torch.ones_like(real_GA_cam_logit).to(self.device)) + self.MSE_loss(fake_GA_cam_logit, torch.zeros_like(fake_GA_cam_logit).to(self.device))
    D_ad_loss_LA = self.MSE_loss(real_LA_logit, torch.ones_like(real_LA_logit).to(self.device)) + self.MSE_loss(fake_LA_logit, torch.zeros_like(fake_LA_logit).to(self.device))
    D_ad_cam_loss_LA = self.MSE_loss(real_LA_cam_logit, torch.ones_like(real_LA_cam_logit).to(self.device)) + self.MSE_loss(fake_LA_cam_logit, torch.zeros_like(fake_LA_cam_logit).to(self.device))
    D_ad_loss_GB = self.MSE_loss(real_GB_logit, torch.ones_like(real_GB_logit).to(self.device)) + self.MSE_loss(fake_GB_logit, torch.zeros_like(fake_GB_logit).to(self.device))
    D_ad_cam_loss_GB = self.MSE_loss(real_GB_cam_logit, torch.ones_like(real_GB_cam_logit).to(self.device)) + self.MSE_loss(fake_GB_cam_logit, torch.zeros_like(fake_GB_cam_logit).to(self.device))
    D_ad_loss_LB = self.MSE_loss(real_LB_logit, torch.ones_like(real_LB_logit).to(self.device)) + self.MSE_loss(fake_LB_logit, torch.zeros_like(fake_LB_logit).to(self.device))
    D_ad_cam_loss_LB = self.MSE_loss(real_LB_cam_logit, torch.ones_like(real_LB_cam_logit).to(self.device)) + self.MSE_loss(fake_LB_cam_logit, torch.zeros_like(fake_LB_cam_logit).to(self.device))

    D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA + D_ad_loss_LA + D_ad_cam_loss_LA)
    D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB + D_ad_loss_LB + D_ad_cam_loss_LB)

    Discriminator_loss = D_loss_A + D_loss_B
    Discriminator_loss.backward()
    self.D_optim.step()

    # Update G
    self.G_optim.zero_grad()

    fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
    fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)

    fake_A2B2A, _, _ = self.genB2A(fake_A2B)
    fake_B2A2B, _, _ = self.genA2B(fake_B2A)

    fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
    fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

    fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
    fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
    fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
    fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

    G_ad_loss_GA = self.MSE_loss(fake_GA_logit, torch.ones_like(fake_GA_logit).to(self.device))
    G_ad_cam_loss_GA = self.MSE_loss(fake_GA_cam_logit, torch.ones_like(fake_GA_cam_logit).to(self.device))
    G_ad_loss_LA = self.MSE_loss(fake_LA_logit, torch.ones_like(fake_LA_logit).to(self.device))
    G_ad_cam_loss_LA = self.MSE_loss(fake_LA_cam_logit, torch.ones_like(fake_LA_cam_logit).to(self.device))
    G_ad_loss_GB = self.MSE_loss(fake_GB_logit, torch.ones_like(fake_GB_logit).to(self.device))
    G_ad_cam_loss_GB = self.MSE_loss(fake_GB_cam_logit, torch.ones_like(fake_GB_cam_logit).to(self.device))
    G_ad_loss_LB = self.MSE_loss(fake_LB_logit, torch.ones_like(fake_LB_logit).to(self.device))
    G_ad_cam_loss_LB = self.MSE_loss(fake_LB_cam_logit, torch.ones_like(fake_LB_cam_logit).to(self.device))

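    # Cycle-consistency: A -> B -> A (and B -> A -> B) should reconstruct the input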
    G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A)
    G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B)

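    # Identity: a generator given an image already in its output domain should return it unchanged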
    G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
    G_identity_loss_B = self.L1_loss(fake_B2B, real_B)

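    # CAM loss for the generator's auxiliary classifier: inputs from the other domain
    # should score 1, inputs already in the generator's output domain should score 0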
    G_cam_loss_A = self.BCE_loss(fake_B2A_cam_logit, torch.ones_like(fake_B2A_cam_logit).to(self.device)) + self.BCE_loss(fake_A2A_cam_logit, torch.zeros_like(fake_A2A_cam_logit).to(self.device))
    G_cam_loss_B = self.BCE_loss(fake_A2B_cam_logit, torch.ones_like(fake_A2B_cam_logit).to(self.device)) + self.BCE_loss(fake_B2B_cam_logit, torch.zeros_like(fake_B2B_cam_logit).to(self.device))

    G_loss_A =  self.adv_weight * (G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA + G_ad_cam_loss_LA) + self.cycle_weight * G_recon_loss_A + self.identity_weight * G_identity_loss_A + self.cam_weight * G_cam_loss_A
    G_loss_B = self.adv_weight * (G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB + G_ad_cam_loss_LB) + self.cycle_weight * G_recon_loss_B + self.identity_weight * G_identity_loss_B + self.cam_weight * G_cam_loss_B

    Generator_loss = G_loss_A + G_loss_B
    Generator_loss.backward()
    self.G_optim.step()

    # Clip the rho parameters of AdaILN / ILN into [0, 1] after the optimizer step (see RhoClipper above)
    self.genA2B.apply(self.Rho_clipper)
    self.genB2A.apply(self.Rho_clipper)

Code from znxlwm/UGATIT-pytorch (https://github.com/znxlwm/UGATIT-pytorch).