PyTorch Code for BicycleGAN

 

Info

Title: Toward Multimodal Image-to-Image Translation

Prerequisites

  • Linux or macOS
  • Python 3
  • CPU or NVIDIA GPU + CUDA and cuDNN

Getting Started

Installation

  • Clone this repo:
    git clone -b master --single-branch https://github.com/junyanz/BicycleGAN.git
    cd BicycleGAN
    
  • Install PyTorch and dependencies from http://pytorch.org
  • Install the Python libraries visdom, dominate, and moviepy.

For pip users:

bash ./scripts/install_pip.sh

For conda users:

bash ./scripts/install_conda.sh

Use a Pre-trained Model

  • Download some test photos (e.g., edges2shoes):
    bash ./datasets/download_testset.sh edges2shoes
    
  • Download a pre-trained model (e.g., edges2shoes):
    bash ./pretrained_models/download_model.sh edges2shoes
    
  • Generate results with the model
    bash ./scripts/test_edges2shoes.sh
    

    The test results will be saved to an HTML file: ./results/edges2shoes/val/index.html.

  • Generate results with synchronized latent vectors
    bash ./scripts/test_edges2shoes.sh --sync
    

    Results can be found at ./results/edges2shoes/val_sync/index.html.
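
"Synchronized" is read here as: the same set of latent vectors is reused for every input image, so that the k-th output of each input shares one latent code. The following is a minimal sketch of that idea in plain Python, not the repository's test code. The names sample_z, latent_codes_per_image, n_samples, and nz are hypothetical and chosen only for illustration.

```python
import random


def sample_z(n_samples, nz, rng):
    """Draw n_samples latent vectors of length nz from a standard normal."""
    return [[rng.gauss(0.0, 1.0) for _ in range(nz)] for _ in range(n_samples)]


def latent_codes_per_image(n_images, n_samples, nz, sync, seed=0):
    """Return one list of latent vectors per input image.

    sync=True  -> every image reuses the same n_samples vectors.
    sync=False -> every image gets its own freshly sampled vectors.
    """
    rng = random.Random(seed)
    if sync:
        shared = sample_z(n_samples, nz, rng)
        return [shared for _ in range(n_images)]
    return [sample_z(n_samples, nz, rng) for _ in range(n_images)]


codes_sync = latent_codes_per_image(n_images=3, n_samples=4, nz=8, sync=True)
codes_indep = latent_codes_per_image(n_images=3, n_samples=4, nz=8, sync=False)
```

With synchronized codes, differences within a column of the results page come from the input image alone, which makes it easier to compare how one latent code behaves across inputs.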

Generate Morphing Videos

  • We can also produce a morphing video similar to the demo GIF and YouTube video.
    bash ./scripts/video_edges2shoes.sh
    

    Results can be found at ./videos/edges2shoes/.
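
A morphing video of this kind can be produced by keeping the input image fixed and moving smoothly between latent vectors, decoding one frame per intermediate vector. The sketch below assumes plain linear interpolation; the video script may use a different scheme. The function name lerp_latents and its parameters are hypothetical.

```python
def lerp_latents(z0, z1, num_frames):
    """Return num_frames latent vectors moving linearly from z0 to z1 (both included)."""
    if num_frames < 2:
        raise ValueError("num_frames must be at least 2")
    if len(z0) != len(z1):
        raise ValueError("z0 and z1 must have the same length")
    frames = []
    for i in range(num_frames):
        t = i / (num_frames - 1)  # t runs from 0.0 to 1.0
        frames.append([(1.0 - t) * a + t * b for a, b in zip(z0, z1)])
    return frames


# Five frames between two 2-dimensional latent vectors.
frames = lerp_latents([0.0, 1.0], [1.0, -1.0], num_frames=5)
```

Each vector in frames would then be passed to the generator together with the same input image to render one video frame.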

Model Training

  • To train a model, download the training images (e.g., edges2shoes).
    bash ./datasets/download_dataset.sh edges2shoes
    
  • Train a model:
    bash ./scripts/train_edges2shoes.sh
    
  • To view training results and loss plots, run python -m visdom.server and open http://localhost:8097 in a browser. For more intermediate results, see ./checkpoints/edges2shoes_bicycle_gan/web/index.html.
  • See more training details for other datasets in ./scripts/train.sh.

Datasets (from pix2pix)

Download the datasets using the following script. Many of the datasets are collected by other researchers. Please cite their papers if you use the data.

  • Download the test set.
    bash ./datasets/download_testset.sh dataset_name
    
  • Download the training and test sets.
    bash ./datasets/download_dataset.sh dataset_name
    
  • facades: 400 images from the CMP Facades dataset. [Citation]
  • maps: 1096 training images scraped from Google Maps.
  • edges2shoes: 50K training images from the UT Zappos50K dataset. Edges are computed by the HED edge detector plus post-processing. [Citation]
  • edges2handbags: 137K Amazon handbag images from the iGAN project. Edges are computed by the HED edge detector plus post-processing. [Citation]
  • night2day: around 20K natural scene images from the Transient Attributes dataset. [Citation]

Models

Download the pre-trained models with the following script.

bash ./pretrained_models/download_model.sh model_name
  • edges2shoes (edge -> photo) trained on the UT Zappos50K dataset.
  • edges2handbags (edge -> photo) trained on Amazon handbag images.
    bash ./pretrained_models/download_model.sh edges2handbags
    bash ./datasets/download_testset.sh edges2handbags
    bash ./scripts/test_edges2handbags.sh
    
  • night2day (nighttime scene -> daytime scene) trained on images from around 100 webcams.
    bash ./pretrained_models/download_model.sh night2day
    bash ./datasets/download_testset.sh night2day
    bash ./scripts/test_night2day.sh
    
  • facades (facade label -> facade photo) trained on the CMP Facades dataset.
    bash ./pretrained_models/download_model.sh facades
    bash ./datasets/download_testset.sh facades
    bash ./scripts/test_facades.sh
    
  • maps (map photo -> aerial photo) trained on 1096 training images scraped from Google Maps.
    bash ./pretrained_models/download_model.sh maps
    bash ./datasets/download_testset.sh maps
    bash ./scripts/test_maps.sh
    

Core Design

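import torch

from .base_model import BaseModel


# Excerpt of the model class: __init__, encode, get_z_random, and the
# network/optimizer setup are omitted here.
class BiCycleGANModel(BaseModel):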
    def forward(self):
        # get real images
        half_size = self.opt.batch_size // 2
        # A1, B1 for encoded; A2, B2 for random
        self.real_A_encoded = self.real_A[0:half_size]
        self.real_B_encoded = self.real_B[0:half_size]
        self.real_B_random = self.real_B[half_size:]
        # get encoded z
        self.z_encoded, self.mu, self.logvar = self.encode(self.real_B_encoded)
        # get random z
        self.z_random = self.get_z_random(self.real_A_encoded.size(0), self.opt.nz)
        # generate fake_B_encoded
        self.fake_B_encoded = self.netG(self.real_A_encoded, self.z_encoded)
        # generate fake_B_random
        self.fake_B_random = self.netG(self.real_A_encoded, self.z_random)
        if self.opt.conditional_D:  # conditional D: concatenate input A with B along the channel dim
            self.fake_data_encoded = torch.cat([self.real_A_encoded, self.fake_B_encoded], 1)
            self.real_data_encoded = torch.cat([self.real_A_encoded, self.real_B_encoded], 1)
            self.fake_data_random = torch.cat([self.real_A_encoded, self.fake_B_random], 1)
            self.real_data_random = torch.cat([self.real_A[half_size:], self.real_B_random], 1)
        else:
            self.fake_data_encoded = self.fake_B_encoded
            self.fake_data_random = self.fake_B_random
            self.real_data_encoded = self.real_B_encoded
            self.real_data_random = self.real_B_random

        # compute z_predict
        if self.opt.lambda_z > 0.0:
            self.mu2, logvar2 = self.netE(self.fake_B_random)  # mu2 is a point estimate

    def backward_D(self, netD, real, fake):
        # Fake, stop backprop to the generator by detaching fake_B
        pred_fake = netD(fake.detach())
        # real
        pred_real = netD(real)
        loss_D_fake, _ = self.criterionGAN(pred_fake, False)
        loss_D_real, _ = self.criterionGAN(pred_real, True)
        # Combined loss
        loss_D = loss_D_fake + loss_D_real
        loss_D.backward()
        return loss_D, [loss_D_fake, loss_D_real]

    def backward_G_GAN(self, fake, netD=None, ll=0.0):
        if ll > 0.0:
            pred_fake = netD(fake)
            loss_G_GAN, _ = self.criterionGAN(pred_fake, True)
        else:
            loss_G_GAN = 0
        return loss_G_GAN * ll

    def backward_EG(self):
        # 1. G(A) should fool D
        self.loss_G_GAN = self.backward_G_GAN(self.fake_data_encoded, self.netD, self.opt.lambda_GAN)
        if self.opt.use_same_D:
            self.loss_G_GAN2 = self.backward_G_GAN(self.fake_data_random, self.netD, self.opt.lambda_GAN2)
        else:
            self.loss_G_GAN2 = self.backward_G_GAN(self.fake_data_random, self.netD2, self.opt.lambda_GAN2)
        # 2. KL loss
        if self.opt.lambda_kl > 0.0:
            self.loss_kl = torch.sum(1 + self.logvar - self.mu.pow(2) - self.logvar.exp()) * (-0.5 * self.opt.lambda_kl)
        else:
            self.loss_kl = 0
        # 3. reconstruction |fake_B - real_B|
        if self.opt.lambda_L1 > 0.0:
            self.loss_G_L1 = self.criterionL1(self.fake_B_encoded, self.real_B_encoded) * self.opt.lambda_L1
        else:
            self.loss_G_L1 = 0.0

        self.loss_G = self.loss_G_GAN + self.loss_G_GAN2 + self.loss_G_L1 + self.loss_kl
        self.loss_G.backward(retain_graph=True)

    def update_D(self):
        self.set_requires_grad([self.netD, self.netD2], True)
        # update D1
        if self.opt.lambda_GAN > 0.0:
            self.optimizer_D.zero_grad()
            self.loss_D, self.losses_D = self.backward_D(self.netD, self.real_data_encoded, self.fake_data_encoded)
            if self.opt.use_same_D:
                self.loss_D2, self.losses_D2 = self.backward_D(self.netD, self.real_data_random, self.fake_data_random)
            self.optimizer_D.step()

        if self.opt.lambda_GAN2 > 0.0 and not self.opt.use_same_D:
            self.optimizer_D2.zero_grad()
            self.loss_D2, self.losses_D2 = self.backward_D(self.netD2, self.real_data_random, self.fake_data_random)
            self.optimizer_D2.step()

    def backward_G_alone(self):
        # 4. latent reconstruction |E(G(A, z_random)) - z_random|
        if self.opt.lambda_z > 0.0:
            self.loss_z_L1 = torch.mean(torch.abs(self.mu2 - self.z_random)) * self.opt.lambda_z
            self.loss_z_L1.backward()
        else:
            self.loss_z_L1 = 0.0

    def update_G_and_E(self):
        # update G and E
        self.set_requires_grad([self.netD, self.netD2], False)
        self.optimizer_E.zero_grad()
        self.optimizer_G.zero_grad()
        self.backward_EG()
        self.optimizer_G.step()
        self.optimizer_E.step()
        # update G only (E also receives gradients from this loss, but only optimizer_G steps)
        if self.opt.lambda_z > 0.0:
            self.optimizer_G.zero_grad()
            self.optimizer_E.zero_grad()
            self.backward_G_alone()
            self.optimizer_G.step()
            
    def optimize_parameters(self):
        self.forward()
        self.update_G_and_E()
        self.update_D()
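
The two latent-space terms in the code above (the KL loss in backward_EG and the latent L1 loss in backward_G_alone) can be checked by hand on small vectors. The sketch below mirrors those two expressions in plain Python, without torch, so the arithmetic is easy to follow. The function names and the weights lambda_kl and lambda_z used here are illustrative assumptions, not the repository's defaults.

```python
import math


def kl_to_standard_normal(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, I)), summed over dimensions.

    Same expression as the torch.sum(...) * -0.5 term above, before weighting.
    """
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def latent_l1(mu2, z_random):
    """Mean absolute difference between the re-encoded code mu2 and the sampled z."""
    return sum(abs(a - b) for a, b in zip(mu2, z_random)) / len(z_random)


# Hypothetical weights, for illustration only.
lambda_kl = 0.01
lambda_z = 0.5

# A posterior equal to the prior has zero KL.
kl_zero = kl_to_standard_normal([0.0, 0.0], [0.0, 0.0])

# Shifting one mean by 1 (unit variance) costs 0.5 * 1^2 = 0.5.
kl_shift = kl_to_standard_normal([1.0], [0.0])

# |1 - 0| and |2 - 4| average to 1.5.
z_err = latent_l1([1.0, 2.0], [0.0, 4.0])

loss_kl = lambda_kl * kl_shift
loss_z_L1 = lambda_z * z_err
```

The KL term pulls the encoder's output distribution toward the standard normal that z_random is sampled from, while the latent L1 term asks the encoder to recover the sampled code from the generated image.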