Project page | Paper | GTC 2019 demo | YouTube Demo of GauGAN
## Installation
Clone this repo.

```bash
git clone https://github.com/NVlabs/SPADE.git
cd SPADE/
```

This code requires PyTorch 1.0 and Python 3+. Please install dependencies by

```bash
pip install -r requirements.txt
```

This code also requires the Synchronized-BatchNorm-PyTorch repository.

```bash
cd models/networks/
git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm .
cd ../../
```
## Generating Images Using Pretrained Model
Once the dataset is ready, the result images can be generated using pretrained models.

1. Download the tar of the pretrained models from the Google Drive Folder, save it in `checkpoints/`, and run

    ```bash
    cd checkpoints
    tar xvf checkpoints.tar.gz
    cd ../
    ```

2. Generate images using the pretrained model.

    ```bash
    python test.py --name [type]_pretrained --dataset_mode [dataset] --dataroot [path_to_dataset]
    ```

    `[type]_pretrained` is the directory name of the checkpoint downloaded in Step 1, which should be one of `coco_pretrained`, `ade20k_pretrained`, and `cityscapes_pretrained`. `[dataset]` can be one of `coco`, `ade20k`, and `cityscapes`, and `[path_to_dataset]` is the path to the dataset. If you are running in CPU mode, append `--gpu_ids -1`.

3. The output images are stored at `./results/[type]_pretrained/` by default. You can view them using the autogenerated HTML file in the directory.
## Training New Models
New models can be trained with the following commands.
1. Prepare the dataset. To train on the datasets shown in the paper, you can download the datasets and use the `--dataset_mode` option, which will choose which subclass of `BaseDataset` is loaded. For custom datasets, the easiest way is to use `./data/custom_dataset.py` by specifying the option `--dataset_mode custom`, along with `--label_dir [path_to_labels] --image_dir [path_to_images]`. You also need to specify options such as `--label_nc` for the number of label classes in the dataset, `--contain_dontcare_label` to specify whether it has an unknown label, or `--no_instance` to denote that the dataset doesn't have instance maps. (A rough sketch of what such a paired label/image dataset looks like is shown after the training commands below.)

2. Train.

    ```bash
    # To train on the Facades or COCO dataset, for example.
    python train.py --name [experiment_name] --dataset_mode facades --dataroot [path_to_facades_dataset]
    python train.py --name [experiment_name] --dataset_mode coco --dataroot [path_to_coco_dataset]
    ```
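For orientation, here is a rough, hypothetical sketch of the kind of paired data that `--label_dir` and `--image_dir` are expected to point at: one integer label map per RGB image, matched here by sorted filename order. This is an illustration only, not the code in `./data/custom_dataset.py`, which defines the actual pairing and option handling.

```python
# Illustrative only: a minimal paired label/image dataset of the kind that
# --label_dir / --image_dir are expected to point at. Assumes labels and
# images correspond one-to-one when both directories are listed in sorted
# order; ./data/custom_dataset.py defines the authoritative behavior.
import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class PairedLabelImageDataset(Dataset):  # hypothetical helper, not part of SPADE
    def __init__(self, label_dir, image_dir, load_size=256):
        self.label_paths = sorted(os.path.join(label_dir, f) for f in os.listdir(label_dir))
        self.image_paths = sorted(os.path.join(image_dir, f) for f in os.listdir(image_dir))
        assert len(self.label_paths) == len(self.image_paths), "every label map needs an image"
        self.load_size = load_size
        self.image_transform = transforms.Compose([
            transforms.Resize((load_size, load_size)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.label_paths)

    def __getitem__(self, idx):
        # Label maps stay integer class indices (cf. --label_nc) and are resized
        # with nearest-neighbor so class ids are not interpolated.
        label = Image.open(self.label_paths[idx]).resize(
            (self.load_size, self.load_size), Image.NEAREST)
        image = Image.open(self.image_paths[idx]).convert('RGB')
        return {
            'label': torch.from_numpy(np.array(label)).long(),
            'image': self.image_transform(image),
        }
```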
## Testing
Testing is similar to testing pretrained models.
```bash
python test.py --name [name_of_experiment] --dataset_mode [dataset_mode] --dataroot [path_to_dataset]
```

Use `--results_dir` to specify the output directory. `--how_many` will specify the maximum number of images to generate. By default, it loads the latest checkpoint. It can be changed using `--which_epoch`.
## Code Structure
- `train.py`, `test.py`: the entry points for training and testing.
- `trainers/pix2pix_trainer.py`: harnesses and reports the progress of training.
- `models/pix2pix_model.py`: creates the networks and computes the losses.
- `models/networks/`: defines the architecture of all models.
- `options/`: creates option lists using the `argparse` package. More options are dynamically added in other files as well. Please see the section below.
- `data/`: defines the classes for loading images and label maps.
## Options
This code repo contains many options. Some options belong to only one specific model, and some options have different default values depending on other options. To address this, the `BaseOptions` class dynamically loads and sets options depending on which model, network, and dataset are used. This is done by calling the static method `modify_commandline_options` of various classes. It takes in the `parser` of the `argparse` package and modifies the list of options. For example, since the COCO-Stuff dataset contains a special label "unknown", when the COCO-Stuff dataset is used, it sets `--contain_dontcare_label` automatically at `data/coco_dataset.py`. You can take a look at `def gather_options()` of `options/base_options.py`, or `models/networks/__init__.py`, to get a sense of how this works.
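As a hedged illustration of that pattern (not code copied from this repository), a dataset or model class exposes a static `modify_commandline_options` that receives the parser, adds its own flags, and overrides defaults declared elsewhere. The option name and values below are made up for the example.

```python
# Sketch of the modify_commandline_options pattern. The flag --my_option and
# the default values here are illustrative, not SPADE's real options; see
# data/coco_dataset.py for an actual example that sets --contain_dontcare_label.
import argparse


class ExampleDataset:
    @staticmethod
    def modify_commandline_options(parser, is_train):
        # Add an option that only this dataset understands...
        parser.add_argument('--my_option', type=int, default=3,
                            help='illustrative dataset-specific option')
        # ...and override defaults of options declared elsewhere.
        parser.set_defaults(contain_dontcare_label=True)
        return parser


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--contain_dontcare_label', action='store_true')
    parser = ExampleDataset.modify_commandline_options(parser, is_train=True)
    opt = parser.parse_args([])
    print(opt.contain_dontcare_label, opt.my_option)  # True 3
```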
## VAE-Style Training with an Encoder For Style Control and Multi-Modal Outputs
To train our model along with an image encoder to enable multi-modal outputs as in Figure 15 of the paper, please use `--use_vae`. The model will create `netE` in addition to `netG` and `netD` and train with a KL-Divergence loss.
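For reference, the KL-Divergence term in VAE-style training is the standard closed-form KL between the encoder's Gaussian posterior and a unit Gaussian prior. The snippet below is a generic sketch of that loss and of the reparameterization step, not code copied from this repository.

```python
# Generic sketch of the KL-Divergence loss used in VAE-style training: it pushes
# the encoder's posterior N(mu, sigma^2) toward the unit Gaussian prior N(0, I).
# Not copied from this repository.
import torch


def kld_loss(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over all dimensions.
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())


def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # In a VAE-style setup the encoder predicts (mu, logvar); the style code fed
    # to the generator is sampled with the reparameterization trick so gradients
    # flow back through mu and logvar.
    std = torch.exp(0.5 * logvar)
    return mu + std * torch.randn_like(std)
```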
## Core Design
In SPADE, the affine parameters (scale and bias) of the normalization layer are learned from the semantic segmentation map.
```python
import re

import torch.nn as nn
import torch.nn.functional as F

# SynchronizedBatchNorm2d comes from the Synchronized-BatchNorm-PyTorch repo
# cloned into models/networks/sync_batchnorm during installation.
from models.networks.sync_batchnorm import SynchronizedBatchNorm2d


class SPADE(nn.Module):
    def __init__(self, config_text, norm_nc, label_nc):
        super().__init__()

        # config_text looks like 'spade<norm_type><ks>x<ks>', e.g. 'spadesyncbatch3x3'
        assert config_text.startswith('spade')
        parsed = re.search(r'spade(\D+)(\d)x\d', config_text)
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))

        if param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'syncbatch':
            self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError('%s is not a recognized param-free norm type in SPADE'
                             % param_free_norm_type)

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        nhidden = 128

        pw = ks // 2
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
            nn.ReLU()
        )
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)

    def forward(self, x, segmap):
        # Part 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)

        # Part 2. produce scaling and bias conditioned on semantic map
        segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
        actv = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # apply scale and bias
        out = normalized * (1 + gamma) + beta

        return out
```
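A quick, hypothetical usage example (shapes and the config string are chosen only for illustration): a SPADE layer that normalizes a 64-channel feature map conditioned on a 35-class one-hot segmentation map.

```python
# Hypothetical usage of the SPADE layer above; shapes and config string are
# illustrative, not the generator's actual configuration.
import torch

spade = SPADE('spadeinstance3x3', norm_nc=64, label_nc=35)

x = torch.randn(1, 64, 32, 32)         # activations to be normalized
segmap = torch.zeros(1, 35, 256, 256)  # one-hot segmentation map
segmap[:, 0] = 1.0                     # pretend every pixel is class 0

out = spade(x, segmap)  # segmap is downsampled to 32x32 inside forward()
print(out.shape)        # torch.Size([1, 64, 32, 32])
```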
## Related
- PyTorch Code for vid2vid
- PyTorch Code for BicycleGAN
- PyTorch Code for pix2pixHD
- PyTorch Code for CycleGAN
- PyTorch Code for SPADE
- PyTorch Code for pix2pix
- Image to Image Translation(1): pix2pix, S+U, CycleGAN, UNIT, BicycleGAN, and StarGAN
- Image to Image Translation(2): pix2pixHD, MUNIT, DRIT, vid2vid, SPADE, INIT, and FUNIT
- Deep Generative Models(Part 1): Taxonomy and VAEs
- Deep Generative Models(Part 2): Flow-based Models(include PixelCNN)
- Deep Generative Models(Part 3): GANs