PyTorch Code for CycleGAN

 

Info

Title: Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks


Prerequisites

  • Linux or macOS
  • Python 3
  • CPU or NVIDIA GPU + CUDA CuDNN

Getting Started

Installation

  • Clone this repo:
    git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
    cd pytorch-CycleGAN-and-pix2pix
    
  • Install PyTorch 0.4+ and other dependencies (e.g., torchvision, visdom and dominate).
    • For pip users, please type the command pip install -r requirements.txt.
    • For Conda users, we provide an installation script ./scripts/conda_deps.sh. Alternatively, you can create a new Conda environment using conda env create -f environment.yml.
    • For Docker users, we provide the pre-built Docker image and Dockerfile. Please refer to our Docker page.

CycleGAN train/test

  • Download a CycleGAN dataset (e.g. maps):
    bash ./datasets/download_cyclegan_dataset.sh maps
    
  • To view training results and loss plots, run python -m visdom.server and click the URL http://localhost:8097.
  • Train a model:
    #!./scripts/train_cyclegan.sh
    python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
    

    To see more intermediate results, check out ./checkpoints/maps_cyclegan/web/index.html.

  • Test the model:
    #!./scripts/test_cyclegan.sh
    python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
    
  • The test results will be saved to an HTML file here: ./results/maps_cyclegan/latest_test/index.html.

Apply a pre-trained model

  • You can download a pretrained model (e.g. horse2zebra) with the following script:
    bash ./scripts/download_cyclegan_model.sh horse2zebra
    
  • The pretrained model is saved at ./checkpoints/{name}_pretrained/latest_net_G.pth. Check here for all the available CycleGAN models.
  • To test the model, you also need to download the horse2zebra dataset:
    bash ./datasets/download_cyclegan_dataset.sh horse2zebra
    
  • Then generate the results using
    python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout
    
  • The option --model test is used for generating CycleGAN results for one side only. This option automatically sets --dataset_mode single, which loads images from only one set. In contrast, --model cycle_gan loads both generators and produces results in both directions, which is sometimes unnecessary. The results will be saved at ./results/. Use --results_dir {directory_path_to_save_result} to specify the results directory.
  • For your own experiments, you might want to specify --netG, --norm, --no_dropout to match the generator architecture of the trained model.
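
If you want to script this step yourself rather than go through test.py, the following is a minimal sketch of loading such a pretrained generator directly. It assumes the released CycleGAN generators are resnet_9blocks networks with instance normalization and no dropout (as noted above); ResnetGenerator is the class from models/networks.py (shown later on this page), and the image path example.jpg is a placeholder.

    import functools

    import torch
    import torch.nn as nn
    import torchvision.transforms as transforms
    from PIL import Image

    from models.networks import ResnetGenerator

    # Released CycleGAN generators: ResNet with 9 blocks, instance norm, no dropout.
    norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    netG = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, norm_layer=norm_layer,
                           use_dropout=False, n_blocks=9)
    netG.load_state_dict(torch.load('./checkpoints/horse2zebra_pretrained/latest_net_G.pth',
                                    map_location='cpu'))
    netG.eval()

    # Preprocess one test image the way the single-image dataloader does:
    # resize to 256x256, convert to a tensor, normalize to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    img = transform(Image.open('datasets/horse2zebra/testA/example.jpg').convert('RGB'))

    with torch.no_grad():
        fake = netG(img.unsqueeze(0))  # translated image in [-1, 1], shape (1, 3, 256, 256)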

Overview of Code Structure

To help users better understand and use our codebase, we briefly describe the functionality and implementation of each package and each module. Please see the documentation in each file for more details. If you have questions, you may find useful information in the training/test tips and the frequently asked questions.

train.py is a general-purpose training script. It works for various models (with option --model: e.g., pix2pix, cycle_gan, colorization) and different datasets (with option --dataset_mode: e.g., aligned, unaligned, single, colorization). See the main README and training/test tips for more details.

test.py is a general-purpose test script. Once you have trained your model with train.py, you can use this script to test the model. It will load a saved model from --checkpoints_dir and save the results to --results_dir. See the main README and training/test tips for more details.
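
To make the relationship between these scripts and the packages below concrete, here is a heavily simplified sketch of what the training loop looks like, using only the create_dataset, create_model, and model.setup interfaces described in the following sections. The real train.py also handles visualization, checkpointing, and timing; the epoch count below is a placeholder.

    from options.train_options import TrainOptions
    from data import create_dataset
    from models import create_model

    opt = TrainOptions().parse()         # parse command-line options
    dataset = create_dataset(opt)        # build the dataset given opt.dataset_mode
    model = create_model(opt)            # build the model given opt.model
    model.setup(opt)                     # load/print networks, create schedulers

    n_epochs = 200                       # placeholder; the real script derives this from its options
    for epoch in range(n_epochs):
        for data in dataset:             # one batch per inner iteration
            model.set_input(data)        # unpack data from the dataloader
            model.optimize_parameters()  # forward, losses, backprop, weight update
        model.update_learning_rate()     # decay learning rates once per epoch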

data directory contains all the modules related to data loading and preprocessing. To add a custom dataset class called dummy, you need to add a file called dummy_dataset.py and define a subclass DummyDataset that inherits from BaseDataset. You need to implement four functions: __init__ (initialize the class; you need to first call BaseDataset.__init__(self, opt)), __len__ (return the size of the dataset), __getitem__ (get a data point), and optionally modify_commandline_options (add dataset-specific options and set default options). You can then use the dataset class by specifying the flag --dataset_mode dummy. See our template dataset class for an example, and the sketch after the file list below. Below we explain each file in detail.

  • __init__.py implements the interface between this package and training and test scripts. train.py and test.py call from data import create_dataset and dataset = create_dataset(opt) to create a dataset given the option opt.
  • base_dataset.py implements an abstract base class (ABC) for datasets. It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
  • image_folder.py implements an image folder class. We modify the official PyTorch image folder code so that this class can load images from both the current directory and its subdirectories.
  • template_dataset.py provides a dataset template with detailed documentation. Check out this file if you plan to implement your own dataset.
  • aligned_dataset.py includes a dataset class that can load image pairs. It assumes a single image directory /path/to/data/train, which contains image pairs in the form of {A,B}. See here on how to prepare aligned datasets. During test time, you need to prepare a directory /path/to/data/test as test data.
  • unaligned_dataset.py includes a dataset class that can load unaligned/unpaired datasets. It assumes two directories that host training images from domain A (/path/to/data/trainA) and from domain B (/path/to/data/trainB), respectively. You can then train the model with the dataset flag --dataroot /path/to/data. Similarly, you need to prepare two directories /path/to/data/testA and /path/to/data/testB for test time.
  • single_dataset.py includes a dataset class that can load a set of single images specified by the path --dataroot /path/to/data. It can be used for generating CycleGAN results for only one side with the model option --model test.
  • colorization_dataset.py implements a dataset class that can load a set of natural images in RGB and convert them into (L, ab) pairs in Lab color space. It is required by the pix2pix-based colorization model (--model colorization).
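
As referenced above, here is a minimal sketch of such a dummy dataset, i.e. a file data/dummy_dataset.py. It only illustrates the required methods: it loads single unlabeled images and relies on the BaseDataset, get_transform, and make_dataset helpers described in this list; everything else is a placeholder.

    from PIL import Image

    from data.base_dataset import BaseDataset, get_transform
    from data.image_folder import make_dataset


    class DummyDataset(BaseDataset):
        """A sketch of a custom dataset that loads single images from opt.dataroot."""

        def __init__(self, opt):
            BaseDataset.__init__(self, opt)                  # must be called first
            self.paths = sorted(make_dataset(opt.dataroot))  # list image paths under dataroot
            self.transform = get_transform(opt)              # resize/crop/normalize pipeline

        def __len__(self):
            """Return the number of images in the dataset."""
            return len(self.paths)

        def __getitem__(self, index):
            """Return one data point: an image tensor and its file path."""
            path = self.paths[index]
            image = Image.open(path).convert('RGB')
            return {'A': self.transform(image), 'A_paths': path}

        @staticmethod
        def modify_commandline_options(parser, is_train):
            """Optionally add dataset-specific options here (no-op in this sketch)."""
            return parser

With such a file in place, the dataset can be selected with --dataset_mode dummy.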

models directory contains modules related to objective functions, optimizations, and network architectures. To add a custom model class called dummy, you need to add a file called dummy_model.py and define a subclass DummyModel that inherits from BaseModel. You need to implement four functions: __init__ (initialize the class; you need to first call BaseModel.__init__(self, opt)), set_input (unpack data from the dataset and apply preprocessing), forward (generate intermediate results), optimize_parameters (calculate losses and gradients, and update network weights), and optionally modify_commandline_options (add model-specific options and set default options). You can then use the model class by specifying the flag --model dummy. See our template model class for an example, and the sketch after the file list below. Below we explain each file in detail.

  • __init__.py implements the interface between this package and training and test scripts. train.py and test.py call from models import create_model and model = create_model(opt) to create a model given the option opt. You also need to call model.setup(opt) to properly initialize the model.
  • base_model.py implements an abstract base class (ABC) for models. It also includes commonly used helper functions (e.g., setup, test, update_learning_rate, save_networks, load_networks), which can be later used in subclasses.
  • template_model.py provides a model template with detailed documentation. Check out this file if you plan to implement your own model.
  • pix2pix_model.py implements the pix2pix model, for learning a mapping from input images to output images given paired data. The model training requires the --dataset_mode aligned dataset. By default, it uses a --netG unet_256 U-Net generator, a --netD basic discriminator (PatchGAN), and a --gan_mode vanilla GAN loss (standard cross-entropy objective).
  • colorization_model.py implements a subclass of Pix2PixModel for image colorization (black & white image to colorful image). The model training requires the --dataset_mode colorization dataset. It trains a pix2pix model that maps from the L channel to the ab channels in Lab color space. By default, the colorization dataset automatically sets --input_nc 1 and --output_nc 2.
  • cycle_gan_model.py implements the CycleGAN model, for learning image-to-image translation without paired data. The model training requires --dataset_mode unaligned dataset. By default, it uses a --netG resnet_9blocks ResNet generator, a --netD basic discriminator (PatchGAN introduced by pix2pix), and a least-square GANs objective (--gan_mode lsgan).
  • networks.py module implements network architectures (both generators and discriminators), as well as normalization layers, initialization methods, optimization scheduler (i.e., learning rate policy), and GAN objective function (vanilla, lsgan, wgangp).
  • test_model.py implements a model that can be used to generate CycleGAN results for only one direction. This model will automatically set --dataset_mode single, which only loads the images from one set. See the test instruction for more details.
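
Likewise, a minimal sketch of such a dummy model, i.e. a file models/dummy_model.py, referenced above. Purely for illustration it trains a single generator with an L1 loss; the architecture, loss, and options mirror the template model, and everything beyond the four required methods is omitted.

    import torch

    from models.base_model import BaseModel
    from models import networks


    class DummyModel(BaseModel):
        """A sketch of a custom model: one generator trained with an L1 loss."""

        def __init__(self, opt):
            BaseModel.__init__(self, opt)                       # must be called first
            self.loss_names = ['G']                             # losses to print/plot (self.loss_G)
            self.visual_names = ['real_A', 'fake_B', 'real_B']  # images to display/save
            self.model_names = ['G']                            # networks to save/load (self.netG)
            self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG,
                                          opt.norm, not opt.no_dropout, opt.init_type,
                                          opt.init_gain, self.gpu_ids)
            if self.isTrain:
                self.criterionLoss = torch.nn.L1Loss()
                self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                    lr=opt.lr, betas=(opt.beta1, 0.999))
                self.optimizers.append(self.optimizer_G)        # schedulers are created in setup()

        def set_input(self, input):
            """Unpack one batch from the dataloader and move it to the right device."""
            self.real_A = input['A'].to(self.device)
            self.real_B = input['B'].to(self.device)

        def forward(self):
            """Generate intermediate results."""
            self.fake_B = self.netG(self.real_A)

        def optimize_parameters(self):
            """Compute the loss and gradients, then update the weights."""
            self.forward()
            self.optimizer_G.zero_grad()
            self.loss_G = self.criterionLoss(self.fake_B, self.real_B)
            self.loss_G.backward()
            self.optimizer_G.step()

With such a file in place, the model can be selected with --model dummy.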

options directory includes our option modules: training options, test options, and basic options (used in both training and test). TrainOptions and TestOptions are both subclasses of BaseOptions. They will reuse the options defined in BaseOptions.

  • __init__.py is required to make Python treat the directory options as containing packages.
  • base_options.py includes options that are used in both training and test. It also implements a few helper functions such as parsing, printing, and saving the options. It also gathers additional options defined in modify_commandline_options functions in both dataset class and model class.
  • train_options.py includes options that are only used during training time.
  • test_options.py includes options that are only used during test time.

util directory includes a miscellaneous collection of useful helper functions.

  • __init__.py is required to make Python treat the directory util as containing packages.
  • get_data.py provides a Python script for downloading CycleGAN and pix2pix datasets. Alternatively, you can use bash scripts such as download_pix2pix_model.sh and download_cyclegan_model.sh.
  • html.py implements a module that saves images into a single HTML file. It consists of functions such as add_header (add a text header to the HTML file), add_images (add a row of images to the HTML file), and save (save the HTML to disk). It is based on dominate, a Python library for creating and manipulating HTML documents using a DOM API.
  • image_pool.py implements an image buffer that stores previously generated images. This buffer enables us to update discriminators using a history of generated images rather than only the ones produced by the latest generators. The original idea was discussed in this paper. The size of the buffer is controlled by the flag --pool_size. A condensed sketch of the buffer logic appears after this list.
  • visualizer.py includes several functions that can display/save images and print/save logging information. It uses the Python library visdom for display and the Python library dominate (wrapped in html.py) for creating HTML files with images.
  • util.py consists of simple helper functions such as tensor2im (convert a tensor array to a numpy image array), diagnose_network (calculate and print the mean of the average absolute gradient values), and mkdirs (create multiple directories).
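
As mentioned in the image_pool.py entry above, here is a condensed sketch of the buffer logic, simplified from the actual module:

    import random

    import torch


    class ImagePool:
        """Buffer of previously generated images, mixed into discriminator updates."""

        def __init__(self, pool_size):
            self.pool_size = pool_size
            self.images = []

        def query(self, images):
            """For each image: with probability 0.5 return it directly; otherwise swap it
            with a randomly chosen stored image and return the old one instead."""
            if self.pool_size == 0:                        # a pool size of 0 disables the history
                return images
            out = []
            for image in images:
                image = image.unsqueeze(0)
                if len(self.images) < self.pool_size:      # buffer not yet full: store and return as-is
                    self.images.append(image)
                    out.append(image)
                elif random.random() > 0.5:                # return an old image, keep the new one stored
                    idx = random.randint(0, self.pool_size - 1)
                    old = self.images[idx].clone()
                    self.images[idx] = image
                    out.append(old)
                else:                                      # return the new image unchanged
                    out.append(image)
            return torch.cat(out, 0)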

Core Design

CycleGAN PyTorch

Model

class CycleGANModel(BaseModel):
    """
    This class implements the CycleGAN model, for learning image-to-image translation without paired data.

    The model training requires '--dataset_mode unaligned' dataset.
    By default, it uses a '--netG resnet_9blocks' ResNet generator,
    a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
    and a least-square GANs objective ('--gan_mode lsgan').

    CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
    """

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B = self.netG_A(self.real_A)  # G_A(A)
        self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
        self.fake_A = self.netG_B(self.real_B)  # G_B(B)
        self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))

    def backward_D_basic(self, netD, real, fake):
        """Calculate GAN loss for the discriminator

        Parameters:
            netD (network)      -- the discriminator D
            real (tensor array) -- real images
            fake (tensor array) -- images generated by a generator

        Return the discriminator loss.
        We also call loss_D.backward() to calculate the gradients.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss and calculate gradients
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Calculate GAN loss for discriminator D_A"""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        """Calculate GAN loss for discriminator D_B"""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        """Calculate the loss for generators G_A and G_B"""
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed: ||G_A(B) - B||
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed: ||G_B(A) - A||
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss || G_B(G_A(A)) - A||
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss || G_A(G_B(B)) - B||
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # combined loss and calculate gradients
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

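The excerpt above shows the forward pass and the individual loss terms, but not how they are invoked each iteration. Below is a condensed sketch of the per-iteration update that ties them together; set_requires_grad is a helper from BaseModel that toggles requires_grad on the given networks, and optimizer_G / optimizer_D are the two Adam optimizers created when the model is initialized.

    def optimize_parameters(self):
        """One training iteration: update G_A and G_B first, then D_A and D_B."""
        self.forward()                                              # compute fake and reconstructed images
        # Generators: freeze both discriminators so backward_G() only accumulates G gradients.
        self.set_requires_grad([self.netD_A, self.netD_B], False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # Discriminators: unfreeze them and update using the pooled fake images.
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()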

ResnetGenerator

class ResnetGenerator(nn.Module):
    """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.

    We adapt Torch code and ideas from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        """Construct a Resnet-based generator

        Parameters:
            input_nc (int)      -- the number of channels in input images
            output_nc (int)     -- the number of channels in output images
            ngf (int)           -- the number of filters in the last conv layer
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers
            n_blocks (int)      -- the number of ResNet blocks
            padding_type (str)  -- the name of padding layer in conv layers: reflect | replicate | zero
        """
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2 ** n_downsampling
        for i in range(n_blocks):       # add ResNet blocks

            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        for i in range(n_downsampling):  # add upsampling layers
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Standard forward"""
        return self.model(input)
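
Because the two stride-2 downsampling convolutions are mirrored by two stride-2 transposed convolutions, the output keeps the spatial size of the input. A quick shape check (a usage sketch; the instance-norm configuration matches the CycleGAN default described above):

import functools

import torch
import torch.nn as nn

norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
netG = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, norm_layer=norm_layer, n_blocks=9)

x = torch.randn(1, 3, 256, 256)  # a random "RGB image" batch
with torch.no_grad():
    print(netG(x).shape)         # torch.Size([1, 3, 256, 256]); outputs lie in [-1, 1] (Tanh)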


UnetGenerator

class UnetGenerator(nn.Module):
    """Create a Unet-based generator"""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet generator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
                               an image of size 128x128 will become of size 1x1 at the bottleneck
            ngf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer

        We construct the U-Net from the innermost layer to the outermost layer.
        It is a recursive process.
        """
        super(UnetGenerator, self).__init__()
        # construct unet structure
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)  # add the innermost layer
        for i in range(num_downs - 5):          # add intermediate layers with ngf * 8 filters
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        # gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)  # add the outermost layer

    def forward(self, input):
        """Standard forward"""
        return self.model(input)
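
Each UnetSkipConnectionBlock halves the spatial resolution on the way down, so the input height and width must be divisible by 2**num_downs. The standard unet_256 configuration uses num_downs=8, so a 256x256 input shrinks to 1x1 at the bottleneck. A small usage sketch:

import torch
import torch.nn as nn

netG = UnetGenerator(input_nc=3, output_nc=3, num_downs=8, ngf=64, norm_layer=nn.BatchNorm2d)

netG.eval()  # eval mode for a quick shape check with batch size 1
with torch.no_grad():
    print(netG(torch.randn(1, 3, 256, 256)).shape)  # torch.Size([1, 3, 256, 256])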

NLayerDiscriminator

class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator"""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
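
With the default n_layers=3 this is the 70x70 PatchGAN: instead of a single real/fake score per image, it outputs a map in which each element scores one overlapping patch of the input. A small shape check (usage sketch):

import torch
import torch.nn as nn

netD = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d)

netD.eval()
with torch.no_grad():
    print(netD(torch.randn(1, 3, 256, 256)).shape)  # torch.Size([1, 1, 30, 30]): one score per patch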


PixelDiscriminator

class PixelDiscriminator(nn.Module):
    """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""

    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
        """Construct a 1x1 PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer
        """
        super(PixelDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        self.net = [
            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]

        self.net = nn.Sequential(*self.net)

    def forward(self, input):
        """Standard forward."""
        return self.net(input)
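
Since every layer uses 1x1 convolutions, this discriminator classifies each pixel independently, and the output map keeps the spatial size of the input (usage sketch):

import torch
import torch.nn as nn

netD = PixelDiscriminator(input_nc=3, ndf=64, norm_layer=nn.BatchNorm2d)

netD.eval()
with torch.no_grad():
    print(netD(torch.randn(1, 3, 256, 256)).shape)  # torch.Size([1, 1, 256, 256]): one score per pixel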