Info
Title: High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs
PyTorch Code | Project | Youtube | Paper | Note |
Prerequisites
- Linux or macOS
- Python 2 or 3
- NVIDIA GPU (11G memory or larger) + CUDA cuDNN
Getting Started
Installation
- Install PyTorch and dependencies from http://pytorch.org
- Install the Python library dominate:
pip install dominate
- Clone this repo:
git clone https://github.com/NVIDIA/pix2pixHD
cd pix2pixHD
Testing
- A few example Cityscapes test images are included in the datasets folder.
- Please download the pre-trained Cityscapes model from here (Google Drive link), and put it under ./checkpoints/label2city_1024p/
- Test the model (bash ./scripts/test_1024p.sh):
#!./scripts/test_1024p.sh
python test.py --name label2city_1024p --netG local --ngf 32 --resize_or_crop none
The test results will be saved to an HTML file here: ./results/label2city_1024p/test_latest/index.html.
More example scripts can be found in the scripts
directory.
Dataset
- We use the Cityscapes dataset. To train a model on the full dataset, please download it from the official website (registration required).
After downloading, please put it under the
datasets
folder in the same way the example images are provided.
Training
- Train a model at 1024 x 512 resolution (bash ./scripts/train_512p.sh):
#!./scripts/train_512p.sh
python train.py --name label2city_512p
- To view training results, please check out the intermediate results in ./checkpoints/label2city_512p/web/index.html. If you have TensorFlow installed, you can see TensorBoard logs in ./checkpoints/label2city_512p/logs by adding --tf_log to the training scripts.
Multi-GPU training
- Train a model using multiple GPUs (bash ./scripts/train_512p_multigpu.sh):
#!./scripts/train_512p_multigpu.sh
python train.py --name label2city_512p --batchSize 8 --gpu_ids 0,1,2,3,4,5,6,7
Note: this is not tested, and we trained our model using a single GPU only. Please use at your own discretion.
Training with Automatic Mixed Precision (AMP) for faster speed
- To train with mixed precision support, please first install apex from: https://github.com/NVIDIA/apex
- You can then train the model by adding --fp16. For example:
#!./scripts/train_512p_fp16.sh
python -m torch.distributed.launch train.py --name label2city_512p --fp16
In our test case, it trains about 80% faster with AMP on a Volta machine.
Core Design
class Pix2PixHDModel(BaseModel):
    def forward(self, label, inst, image, feat, infer=False):
        """Run one forward pass and compute the generator/discriminator losses.

        Args:
            label: label map, forwarded to ``self.encode_input``.
            inst: instance map, forwarded to ``self.encode_input``.
            image: real image, forwarded to ``self.encode_input``.
            feat: precomputed feature map, forwarded to ``self.encode_input``.
            infer: when true, the generated image is returned alongside the
                losses. Note that it is not forwarded to ``encode_input``.

        Returns:
            A two-element list: the result of ``self.loss_filter`` applied to
            (loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real,
            loss_D_fake), followed by the generated image if ``infer`` is
            true, otherwise ``None``.
        """
        # Encode Inputs
        input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)

        # Fake Generation
        if self.use_features:
            # Features are computed from the real image unless they were
            # precomputed and loaded.
            if not self.opt.load_features:
                feat_map = self.netE.forward(real_image, inst_map)
            input_concat = torch.cat((input_label, feat_map), dim=1)
        else:
            input_concat = input_label
        fake_image = self.netG.forward(input_concat)

        # Fake Detection and Loss
        pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
        loss_D_fake = self.criterionGAN(pred_fake_pool, False)

        # Real Detection and Loss
        pred_real = self.discriminate(input_label, real_image)
        loss_D_real = self.criterionGAN(pred_real, True)

        # GAN loss (Fake Passability Loss)
        # The discriminator is called directly here (not via self.discriminate)
        # so fake_image is NOT detached in this term.
        pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))
        loss_G_GAN = self.criterionGAN(pred_fake, True)

        # GAN feature matching loss
        loss_G_GAN_Feat = 0
        if not self.opt.no_ganFeat_loss:
            feat_weights = 4.0 / (self.opt.n_layers_D + 1)
            D_weights = 1.0 / self.opt.num_D
            for i in range(self.opt.num_D):
                # The last entry of each discriminator's output list is
                # skipped (presumably the final prediction rather than an
                # intermediate feature -- confirm against netD).
                for j in range(len(pred_fake[i])-1):
                    # Real features are detached: no gradient flows through them.
                    loss_G_GAN_Feat += D_weights * feat_weights * \
                        self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat

        # VGG feature matching loss
        loss_G_VGG = 0
        if not self.opt.no_vgg_loss:
            loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat

        # Only return the fake_B image if necessary to save bandwidth
        return [ self.loss_filter( loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake ), None if not infer else fake_image ]
encode_input & discriminate
def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
    """Move the raw inputs to the GPU and assemble the generator's label input.

    Args:
        label_map: label map tensor. When ``opt.label_nc`` is non-zero it is
            converted to a one-hot tensor with ``opt.label_nc`` channels.
        inst_map: instance map. Must not be ``None`` unless
            ``opt.no_instance`` is set, because ``.data`` is accessed on it.
        real_image: real image, or ``None`` to skip it.
        feat_map: precomputed feature map. Must not be ``None`` when both
            ``self.use_features`` and ``opt.load_features`` are set.
        infer: passed as the ``volatile`` flag of the returned label Variable.

    Returns:
        Tuple ``(input_label, inst_map, real_image, feat_map)``.
    """
    if self.opt.label_nc == 0:
        input_label = label_map.data.cuda()
    else:
        # create one-hot vector for label map
        size = label_map.size()
        oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
        input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
        input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
        if self.opt.data_type == 16:
            input_label = input_label.half()

    # get edges from instance map and append them as extra channel(s)
    if not self.opt.no_instance:
        inst_map = inst_map.data.cuda()
        edge_map = self.get_edges(inst_map)
        input_label = torch.cat((input_label, edge_map), dim=1)
    input_label = Variable(input_label, volatile=infer)

    # real images for training
    if real_image is not None:
        real_image = Variable(real_image.data.cuda())

    # instance map for feature encoding
    if self.use_features:
        # get precomputed feature maps
        if self.opt.load_features:
            feat_map = Variable(feat_map.data.cuda())
        # with label_feat the label map takes the place of the instance map
        if self.opt.label_feat:
            inst_map = label_map.cuda()

    return input_label, inst_map, real_image, feat_map
def discriminate(self, input_label, test_image, use_pool=False):
    """Run the discriminator on a (label, image) pair.

    Args:
        input_label: label tensor; concatenated with the image along dim 1.
        test_image: image to score. It is detached first, so no gradient
            flows back into whatever produced it.
        use_pool: when true, the concatenated input is first passed through
            ``self.fake_pool.query`` and the returned value is scored instead.

    Returns:
        Whatever ``self.netD.forward`` returns for the chosen input.
    """
    input_concat = torch.cat((input_label, test_image.detach()), dim=1)
    if use_pool:
        fake_query = self.fake_pool.query(input_concat)
        return self.netD.forward(fake_query)
    return self.netD.forward(input_concat)
precompute_feature_maps & encode_features
# /precompute_feature_maps.py
# Calls the encoder network directly on one batch: the image is moved to the
# GPU and wrapped in a Variable with volatile=True (legacy PyTorch flag), and
# the instance map is moved to the GPU and passed as the second argument.
feat_map = model.module.netE.forward(Variable(data['image'].cuda(), volatile=True), data['inst'].cuda())
# /models/pix2pixHD_model.py
# The encoder is built through define_G with netG='encoder'; positionally,
# opt.output_nc is its input channel count, opt.feat_num its output channel
# count, opt.nef its base filter count and opt.n_downsample_E its number of
# downsampling steps.
self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder',
opt.n_downsample_E, norm=opt.norm,
gpu_ids=self.gpu_ids)
# /encode_features.py
# Here the raw (CPU) image and instance map are handed over; encode_features
# performs the .cuda() transfers itself.
feat = model.module.encode_features(data['image'], data['inst'])
# /models/pix2pixHD_model.py
def encode_features(self, image, inst):
    """Collect one encoded feature vector per instance id found in ``inst``.

    Args:
        image: image tensor; moved to the GPU and wrapped in a volatile
            Variable before being fed to ``self.netE``.
        inst: instance map tensor; indexed as 4-D (dims 2 and 3 are taken as
            height and width).

    Returns:
        Dict mapping label -> numpy array of shape (k, feat_num + 1). Each
        row holds the ``feat_num`` encoder outputs sampled at one pixel of an
        instance, followed by that instance's normalised pixel count.
    """
    image = Variable(image.cuda(), volatile=True)
    feat_num = self.opt.feat_num
    h, w = inst.size()[2], inst.size()[3]
    block_num = 32
    feat_map = self.netE.forward(image, inst.cuda())
    inst_np = inst.cpu().numpy().astype(int)

    # Start every label in range(label_nc) with an empty (0-row) table.
    feature = {}
    for i in range(self.opt.label_nc):
        feature[i] = np.zeros((0, feat_num+1))

    for i in np.unique(inst_np):
        # Ids >= 1000 are reduced to id // 1000 (presumably a
        # class*1000+instance encoding -- confirm against the dataset).
        label = i if i < 1000 else i//1000
        idx = (inst == int(i)).nonzero()
        num = idx.size()[0]
        # Use the middle entry of the instance's pixel list as the sample point.
        idx = idx[num//2,:]
        val = np.zeros((1, feat_num+1))
        for k in range(feat_num):
            # idx[1] + k walks the feature channels (assumes inst has a
            # single channel so idx[1] is 0 -- TODO confirm).
            val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].data[0]
        # Last column: pixel count divided by (h * w // block_num).
        val[0, feat_num] = float(num) / (h * w // block_num)
        feature[label] = np.append(feature[label], val, axis=0)
    return feature
# /models/networks.py
def define_G(input_nc, output_nc, ngf, netG, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1,
             n_blocks_local=3, norm='instance', gpu_ids=[]):
    """Build, print, optionally move to GPU, and initialise a generator network.

    Args:
        input_nc: number of input channels.
        output_nc: number of output channels.
        ngf: base number of filters.
        netG: architecture name: ``'global'``, ``'local'`` or ``'encoder'``.
        n_downsample_global: downsampling steps (also used for the encoder).
        n_blocks_global: block count passed to the global/local generators.
        n_local_enhancers: enhancer count passed to ``LocalEnhancer``.
        n_blocks_local: block count passed to ``LocalEnhancer``.
        norm: normalisation type forwarded to ``get_norm_layer``.
        gpu_ids: GPU ids; when non-empty the network is moved to
            ``gpu_ids[0]``. The default list is never mutated here.

    Returns:
        The constructed network after ``weights_init`` has been applied.

    Raises:
        NotImplementedError: if ``netG`` is not a known architecture name.
    """
    norm_layer = get_norm_layer(norm_type=norm)
    if netG == 'global':
        netG = GlobalGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer)
    elif netG == 'local':
        netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global,
                             n_local_enhancers, n_blocks_local, norm_layer)
    elif netG == 'encoder':
        netG = Encoder(input_nc, output_nc, ngf, n_downsample_global, norm_layer)
    else:
        # Raising a bare string is a TypeError in Python 3; raise a real
        # exception so the message actually reaches the caller.
        raise NotImplementedError('generator not implemented!')
    print(netG)
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        netG.cuda(gpu_ids[0])
    netG.apply(weights_init)
    return netG
class Encoder(nn.Module):
    """Conv encoder-decoder whose output is averaged within each instance.

    The network is: reflection-pad + 7x7 conv, ``n_downsampling`` stride-2
    convs (doubling the channels each time), ``n_downsampling`` stride-2
    transposed convs (halving them again), then reflection-pad + 7x7 conv to
    ``output_nc`` channels followed by Tanh.
    """

    def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d):
        """Build the layer stack.

        Args:
            input_nc: number of input channels.
            output_nc: number of output (feature) channels.
            ngf: number of filters after the first convolution.
            n_downsampling: number of down- and up-sampling stages.
            norm_layer: normalisation layer constructor, called with a
                channel count.
        """
        super(Encoder, self).__init__()
        self.output_nc = output_nc

        model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf), nn.ReLU(True)]
        ### downsample
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                      norm_layer(ngf * mult * 2), nn.ReLU(True)]

        ### upsample
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
                      norm_layer(int(ngf * mult / 2)), nn.ReLU(True)]

        model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, input, inst):
        """Encode ``input`` and replace each instance region by its mean.

        Args:
            input: image batch fed to the conv stack.
            inst: instance-id map for the batch; indexed as 4-D with the
                same batch and spatial layout as the output (assumed to have
                a single channel -- the channel offset below relies on it).

        Returns:
            A tensor shaped like the conv stack's output where, per sample
            and per channel, every pixel of an instance holds that
            instance's mean value.
        """
        outputs = self.model(input)

        # instance-wise average pooling
        outputs_mean = outputs.clone()
        inst_list = np.unique(inst.cpu().numpy().astype(int))
        for i in inst_list:
            for b in range(input.size()[0]):
                indices = (inst[b:b+1] == int(i)).nonzero()  # n x 4
                # Ids are collected over the whole batch, so a given id may
                # be absent from this sample; skip it rather than taking the
                # mean of an empty selection (NaN that is never written).
                if indices.size(0) == 0:
                    continue
                # The slice inst[b:b+1] yields batch index 0; shift it to b.
                batch_idx = indices[:, 0] + b
                rows, cols = indices[:, 2], indices[:, 3]
                for j in range(self.output_nc):
                    chan_idx = indices[:, 1] + j
                    output_ins = outputs[batch_idx, chan_idx, rows, cols]
                    mean_feat = torch.mean(output_ins).expand_as(output_ins)
                    outputs_mean[batch_idx, chan_idx, rows, cols] = mean_feat
        return outputs_mean
Related
- PyTorch Code for vid2vid
- PyTorch Code for BicycleGAN
- PyTorch Code for pix2pixHD
- PyTorch Code for SPADE
- PyTorch Code for CycleGAN
- PyTorch Code for pix2pix
- Image to Image Translation(1): pix2pix, S+U, CycleGAN, UNIT, BicycleGAN, and StarGAN
- Image to Image Translation(2): pix2pixHD, MUNIT, DRIT, vid2vid, SPADE, INIT, and FUNIT
- Deep Generative Models(Part 1): Taxonomy and VAEs
- Deep Generative Models(Part 2): Flow-based Models(include PixelCNN)
- Deep Generative Models(Part 3): GANs