PyTorch Code for pix2pix

 

Info

Title: Image-to-Image Translation with Conditional Adversarial Networks


Prerequisites

  • Linux or macOS
  • Python 3
  • CPU or NVIDIA GPU + CUDA CuDNN

Getting Started

Installation

  • Clone this repo:
    git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
    cd pytorch-CycleGAN-and-pix2pix
    
  • Install PyTorch 0.4+ and other dependencies (e.g., torchvision, visdom and dominate).
    • For pip users, please type the command pip install -r requirements.txt.
    • For Conda users, we provide an installation script ./scripts/conda_deps.sh. Alternatively, you can create a new Conda environment using conda env create -f environment.yml.
    • For Docker users, we provide the pre-built Docker image and Dockerfile. Please refer to our Docker page.

Train/test

  • Download a pix2pix dataset (e.g., facades):
    bash ./datasets/download_pix2pix_dataset.sh facades
    
  • Train a model:
    #!./scripts/train_pix2pix.sh
    python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
    
  • To view training results and loss plots, run python -m visdom.server and click the URL http://localhost:8097. To see more intermediate results, check out ./checkpoints/facades_pix2pix/web/index.html.

  • Test the model (bash ./scripts/test_pix2pix.sh):
    #!./scripts/test_pix2pix.sh
    python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
    
  • The test results will be saved to an HTML file here: ./results/facades_pix2pix/test_latest/index.html. You can find more scripts in the scripts directory.
  • To train and test pix2pix-based colorization models, please add --model colorization and --dataset_mode colorization. See our training tips for more details.

Apply a pre-trained model

Download a pre-trained model with ./scripts/download_pix2pix_model.sh.

  • Check here for all the available pix2pix models. For example, if you would like to download the label2photo model trained on the Facades dataset,
    bash ./scripts/download_pix2pix_model.sh facades_label2photo
    
  • Download the pix2pix facades dataset:
    bash ./datasets/download_pix2pix_dataset.sh facades
    
  • Then generate the results using
    python test.py --dataroot ./datasets/facades/ --direction BtoA --model pix2pix --name facades_label2photo_pretrained
    
  • Note that we specify --direction BtoA because the Facades dataset's A-to-B direction is photos to labels.

  • If you would like to apply a pre-trained model to a collection of input images (rather than image pairs), please use the --model test option. See ./scripts/test_single.sh for how to apply a model to Facades label maps (stored in the directory facades/testB).

  • See a list of currently available models at ./scripts/download_pix2pix_model.sh.

Training/test Tips

Training/test options

Please see options/train_options.py and options/base_options.py for the training flags; see options/test_options.py and options/base_options.py for the test flags. There are some model-specific flags as well, which are added in the model files, such as the --lambda_A option in models/cycle_gan_model.py. The default values of these options are also adjusted in the model files.

CPU/GPU (default --gpu_ids 0)

Please set --gpu_ids -1 to use CPU mode; set --gpu_ids 0,1,2 for multi-GPU mode. You need a large batch size (e.g., --batch_size 32) to benefit from multiple GPUs.

Visualization

During training, the current results can be viewed using two methods. First, if you set --display_id > 0, the results and loss plots will appear on a local graphics web server launched by visdom. To use this, you should have visdom installed and a server running via the command python -m visdom.server. The default server URL is http://localhost:8097. display_id corresponds to the window ID that is displayed on the visdom server. The visdom display functionality is turned on by default; to avoid the extra overhead of communicating with visdom, set --display_id -1. Second, the intermediate results are saved to [opt.checkpoints_dir]/[opt.name]/web/ as an HTML file. To avoid this, set --no_html.

Preprocessing

Images can be resized and cropped in different ways using the --preprocess option. The default option 'resize_and_crop' resizes the image to (opt.load_size, opt.load_size) and then takes a random crop of size (opt.crop_size, opt.crop_size). 'crop' skips the resizing step and only performs random cropping. 'scale_width' resizes the image to have width opt.crop_size while keeping the aspect ratio. 'scale_width_and_crop' first resizes the image to have width opt.load_size and then does random cropping of size (opt.crop_size, opt.crop_size). 'none' tries to skip all of these preprocessing steps. However, if the image size is not a multiple of some number that depends on the number of downsamplings in the generator, you will get an error because the size of the output image may differ from the size of the input image. Therefore, the 'none' option still tries to adjust the image size to be a multiple of 4; you might need a bigger adjustment if you change the generator architecture. Please see data/base_dataset.py to see how all of this is implemented.
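
For intuition, the default 'resize_and_crop' behaves roughly like the torchvision pipeline below. This is only an illustrative sketch with the usual pix2pix defaults (load_size 286, crop_size 256), not the repository's actual transform code, which lives in data/base_dataset.py.

import torchvision.transforms as transforms

# Rough equivalent of --preprocess resize_and_crop (illustrative sketch only).
load_size, crop_size = 286, 256
transform = transforms.Compose([
    transforms.Resize((load_size, load_size)),               # resize to load_size x load_size
    transforms.RandomCrop(crop_size),                        # take a random crop_size x crop_size crop
    transforms.RandomHorizontalFlip(),                       # the repo disables this with --no_flip
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # map RGB to [-1, 1] to match the Tanh output
])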

Fine-tuning/resume training

To fine-tune a pre-trained model, or resume the previous training, use the --continue_train flag. The program will then load the model based on epoch. By default, the program will initialize the epoch count as 1. Set --epoch_count <int> to specify a different starting epoch count.

Prepare your own datasets for CycleGAN

You need to create two directories to host images from domain A /path/to/data/trainA and from domain B /path/to/data/trainB. Then you can train the model with the dataset flag --dataroot /path/to/data. Optionally, you can create hold-out test datasets at /path/to/data/testA and /path/to/data/testB to test your model on unseen images.

Prepare your own datasets for pix2pix

Pix2pix's training requires paired data. We provide a Python script to generate training data in the form of pairs of images {A, B}, where A and B are two different depictions of the same underlying scene. For example, these might be pairs {label map, photo} or {bw image, color image}. Then we can learn to translate A to B or B to A:

Create a folder /path/to/data with subdirectories A and B. A and B should each have their own subdirectories train, val, test, etc. In /path/to/data/A/train, put training images in style A. In /path/to/data/B/train, put the corresponding images in style B. Repeat the same for the other data splits (val, test, etc.).

Corresponding images in a pair {A,B} must be the same size and have the same filename, e.g., /path/to/data/A/train/1.jpg is considered to correspond to /path/to/data/B/train/1.jpg.

Once the data is formatted this way, call:

python datasets/combine_A_and_B.py --fold_A /path/to/data/A --fold_B /path/to/data/B --fold_AB /path/to/data

This will combine each pair of images (A,B) into a single image file, ready for training.
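
"Combining" here simply means concatenating A and B side by side into one AB image per pair. The following is a minimal sketch of what the script produces for a single pair (illustrative only; it uses PIL/NumPy rather than the script's own code, and the paths follow the layout described above):

import numpy as np
from PIL import Image

# Paste corresponding A and B images side by side into one AB image,
# as datasets/combine_A_and_B.py does for every pair.
im_A = np.array(Image.open('/path/to/data/A/train/1.jpg'))
im_B = np.array(Image.open('/path/to/data/B/train/1.jpg'))
im_AB = np.concatenate([im_A, im_B], axis=1)   # same height, widths added; A on the left
Image.fromarray(im_AB).save('/path/to/data/train/1.jpg')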

About image size

Since the generator architecture in CycleGAN involves a series of downsampling/upsampling operations, the sizes of the input and output images may not match if the input size is not a multiple of 4. As a result, you may get a runtime error because the L1 identity loss cannot be computed on images of different sizes. Therefore, we slightly resize each image so that its side lengths become multiples of 4, even with the --preprocess none option. For the same reason, --crop_size needs to be a multiple of 4.
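
The adjustment under --preprocess none amounts to rounding each side length to the nearest multiple of the required base (4 here). A minimal sketch of that rounding (the actual logic lives in data/base_dataset.py):

def adjust_to_multiple(w, h, base=4):
    """Round width and height to the nearest multiples of `base`
    (sketch of the resizing applied under --preprocess none)."""
    new_w = max(base, int(round(w / base)) * base)
    new_h = max(base, int(round(h / base)) * base)
    return new_w, new_h

print(adjust_to_multiple(1023, 766))  # (1024, 768)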

Training/Testing with high res images

CycleGAN is quite memory-intensive, as four networks (two generators and two discriminators) need to be loaded on one GPU, so a large image cannot be entirely loaded. In this case, we recommend training with cropped images. For example, to generate 1024px results, you can train with --preprocess scale_width_and_crop --load_size 1024 --crop_size 360 and test with --preprocess scale_width --load_size 1024. This ensures that training and testing are done at the same scale. At test time, you can afford a higher resolution because you don't need to load all of the networks.

About loss curve

Unfortunately, the loss curve does not reveal much information when training GANs, and CycleGAN is no exception. To check whether the training has converged, we recommend periodically generating a few samples and looking at them.

About batch size

For all experiments in the paper, we set the batch size to 1. If you have enough GPU memory, you can use a larger batch size with batch norm or instance norm. (Note that the default batchnorm does not work well with multi-GPU training; you may consider using synchronized batchnorm instead.) But please be aware that this can affect training: in particular, even with instance normalization, different batch sizes can lead to different results. Moreover, increasing --crop_size may be a good alternative to increasing the batch size.

Notes on Colorization

There is no need to run combine_A_and_B.py for colorization. Instead, you need to prepare natural images and set --dataset_mode colorization and --model colorization in the script. The program will automatically convert each RGB image into Lab color space and create an L -> ab image pair during training. Also set --input_nc 1 and --output_nc 2. The training and test directories should be organized as /your/data/train and /your/data/test. See the example scripts scripts/train_colorization.sh and scripts/test_colorization.sh for more details.
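
For reference, the L -> ab pairing amounts to the conversion sketched below (illustrative only, using scikit-image; the repository performs the equivalent conversion internally, and the file path is a placeholder):

import numpy as np
from PIL import Image
from skimage import color

# Split an RGB image into the 1-channel L input and the 2-channel ab target,
# matching --input_nc 1 and --output_nc 2.
rgb = np.array(Image.open('/your/data/train/example.jpg').convert('RGB'))  # placeholder path
lab = color.rgb2lab(rgb).astype(np.float32)
L_channel = lab[:, :, 0:1]    # lightness -> generator input
ab_channels = lab[:, :, 1:3]  # color     -> generator target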

Notes on Extracting Edges

We provide Python and MATLAB scripts to extract coarse edges from photos. Run scripts/edges/batch_hed.py to compute HED edges, then run scripts/edges/PostprocessHED.m to simplify the edges with additional post-processing steps. Check the code documentation for more details.

Evaluating Labels2Photos on Cityscapes

We provide scripts for running the evaluation of the Labels2Photos task on the Cityscapes validation set. We assume that you have installed Caffe (and pycaffe) on your system. If not, see the official website for installation instructions. Once Caffe is successfully installed, download the pre-trained FCN-8s semantic segmentation model (512MB) by running

bash ./scripts/eval_cityscapes/download_fcn8s.sh

Then make sure ./scripts/eval_cityscapes/ is in your system's Python path. If not, run the following command to add it:

export PYTHONPATH=${PYTHONPATH}:./scripts/eval_cityscapes/

Now you can run the following command to evaluate your predictions:

python ./scripts/eval_cityscapes/evaluate.py --cityscapes_dir /path/to/original/cityscapes/dataset/ --result_dir /path/to/your/predictions/ --output_dir /path/to/output/directory/

Images stored under --result_dir should contain your model predictions on the Cityscapes validation split, and have the original Cityscapes naming convention (e.g., frankfurt_000001_038418_leftImg8bit.png). The script will output a text file under --output_dir containing the metric.

Core Design

pix2pix PyTorch

Model

import torch
from .base_model import BaseModel
from . import networks


class Pix2PixModel(BaseModel):
    """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data.

    The model training requires '--dataset_mode aligned' dataset.
    By default, it uses a '--netG unet_256' U-Net generator,
    a '--netD basic' discriminator (PatchGAN),
    and a '--gan_mode vanilla' GAN loss (the cross-entropy objective used in the original GAN paper).

    pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
    """

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B = self.netG(self.real_A)  # G(A)

    def backward_D(self):
        """Calculate GAN loss for the discriminator"""
        # Fake; stop backprop to the generator by detaching fake_B
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)  # we use conditional GANs; we need to feed both input and output to the discriminator
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)
        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)
        # combine loss and calculate gradients
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Calculate GAN and L1 loss for the generator"""
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)
        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
        # combine loss and calculate gradients
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()
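
The excerpt above shows the forward pass and the two loss computations but omits the step that ties them together during training. As a rough sketch of that step (not part of the excerpt; it assumes the model was set up with optimizer_G, optimizer_D, and the set_requires_grad helper from BaseModel), one training iteration updates D first and then G:

    def optimize_parameters(self):
        """One training iteration (sketch): update D, then G."""
        self.forward()                             # compute fake images: G(A)
        # update D
        self.set_requires_grad(self.netD, True)    # enable backprop for D
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()
        # update G
        self.set_requires_grad(self.netD, False)   # freeze D while optimizing G
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()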

ResnetGenerator

import functools

import torch.nn as nn


class ResnetGenerator(nn.Module):
    """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.

    We adapt Torch code and ideas from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        """Construct a Resnet-based generator

        Parameters:
            input_nc (int)      -- the number of channels in input images
            output_nc (int)     -- the number of channels in output images
            ngf (int)           -- the number of filters in the last conv layer
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers
            n_blocks (int)      -- the number of ResNet blocks
            padding_type (str)  -- the name of padding layer in conv layers: reflect | replicate | zero
        """
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2 ** n_downsampling
        for i in range(n_blocks):       # add ResNet blocks

            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        for i in range(n_downsampling):  # add upsampling layers
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Standard forward"""
        return self.model(input)
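
A quick shape check (a sketch; it assumes the repository's models.networks module is importable, e.g., when run from the repo root). With two stride-2 downsampling stages, the input side lengths should be multiples of 4, and the output has the same spatial size as the input:

import torch
from models.networks import ResnetGenerator  # import path assumed

net = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, n_blocks=9)  # the 'resnet_9blocks' configuration
net.eval()
with torch.no_grad():
    y = net(torch.randn(1, 3, 256, 256))
print(y.shape)  # torch.Size([1, 3, 256, 256]) -- spatial size preserved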


UnetGenerator

class UnetGenerator(nn.Module):
    """Create a Unet-based generator"""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet generator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
                               an image of size 128x128 will become of size 1x1 at the bottleneck
            ngf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer

        We construct the U-Net from the innermost layer to the outermost layer.
        It is a recursive process.
        """
        super(UnetGenerator, self).__init__()
        # construct unet structure
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)  # add the innermost layer
        for i in range(num_downs - 5):          # add intermediate layers with ngf * 8 filters
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        # gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)  # add the outermost layer

    def forward(self, input):
        """Standard forward"""
        return self.model(input)
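
A quick shape check (a sketch; the models.networks import path is assumed). num_downs=8 matches the default unet_256 generator: the spatial size is halved eight times down to a 1x1 bottleneck, so the input side length should be a multiple of 256:

import torch
from models.networks import UnetGenerator  # import path assumed

net = UnetGenerator(input_nc=3, output_nc=3, num_downs=8)  # the 'unet_256' configuration
net.eval()
with torch.no_grad():
    y = net(torch.randn(1, 3, 256, 256))   # 256x256 -> 1x1 bottleneck -> 256x256
print(y.shape)                              # torch.Size([1, 3, 256, 256])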

NLayerDiscriminator

class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator"""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
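
pix2pix's default '--netD basic' discriminator is this class with n_layers=3, the 70x70 PatchGAN. Its output is not a single scalar but a grid of patch-level real/fake scores, and in the conditional setting it sees A and B concatenated along the channel axis. A quick shape check (a sketch; the models.networks import path is assumed):

import torch
from models.networks import NLayerDiscriminator  # import path assumed

netD = NLayerDiscriminator(input_nc=3 + 3, ndf=64, n_layers=3)  # 70x70 PatchGAN on a concatenated (A, B) pair
netD.eval()
with torch.no_grad():
    pair = torch.cat((torch.randn(1, 3, 256, 256), torch.randn(1, 3, 256, 256)), 1)
    print(netD(pair).shape)  # torch.Size([1, 1, 30, 30]) -- one score per overlapping 70x70 patch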


PixelDiscriminator

class PixelDiscriminator(nn.Module):
    """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""

    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
        """Construct a 1x1 PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer
        """
        super(PixelDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        self.net = [
            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]

        self.net = nn.Sequential(*self.net)

    def forward(self, input):
        """Standard forward."""
        return self.net(input)
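
By contrast, the pixelGAN uses only 1x1 convolutions, so it classifies each pixel independently and its prediction map keeps the input's spatial size. A quick shape check (a sketch; the models.networks import path is assumed):

import torch
from models.networks import PixelDiscriminator  # import path assumed

netD = PixelDiscriminator(input_nc=6)  # conditional (A, B) pair: 3 + 3 channels
netD.eval()
with torch.no_grad():
    print(netD(torch.randn(1, 6, 256, 256)).shape)  # torch.Size([1, 1, 256, 256]) -- one score per pixel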