Unsupervised Person Image Generation with Semantic Parsing Transformation - Sijie Song - CVPR 2019

 

Info

  • Title: Unsupervised Person Image Generation with Semantic Parsing Transformation
  • Task: Image Generation
  • Author: Sijie Song, Wei Zhang, Jiaying Liu, Tao Mei
  • Arxiv: 1904.03379
  • Published: CVPR 2019

Highlights

We propose to address the unsupervised person image generation problem. Accordingly, the problem is decomposed into semantic parsing transformation (H_S) and appearance generation (H_A).

We design a training schema that carefully optimizes H_S and H_A in an end-to-end manner, yielding better semantic maps and further improving the pose-guided image generation results.

Abstract

In this paper, we address unsupervised pose-guided person image generation, which is known to be challenging due to non-rigid deformation. Unlike previous methods that learn a hard direct mapping between human bodies, we propose a new pathway that decomposes the hard mapping into two more accessible subtasks, namely, semantic parsing transformation and appearance generation. First, a semantic generative network is proposed to transform between semantic parsing maps, in order to simplify the non-rigid deformation learning. Second, an appearance generative network learns to synthesize semantic-aware textures. Third, we demonstrate that training our framework in an end-to-end manner further refines the semantic maps and the final results accordingly. Our method is generalizable to other semantic-aware person image generation tasks, e.g., clothing texture transfer and controlled image manipulation. Experimental results demonstrate the superiority of our method on the DeepFashion and Market-1501 datasets, especially in preserving clothing attributes and body shapes.

Motivation & Design

The semantic parsing transformation module first generates a semantic map under the target pose, which provides a crucial prior for the human body shape and clothing attributes. Guided by the predicted semantic map and the reference image, the appearance generation module then synthesizes textures for the final output image.
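
As a rough illustration of this two-stage pipeline, the sketch below wires two toy networks into the same roles. The module definitions and layer widths are assumptions for exposition (the 9-channel parsing and 18-channel pose layout matches the repository code further below), not the authors' actual architectures.

import torch
import torch.nn as nn

class SemanticGenerator(nn.Module):
    # H_S: predict a semantic parsing map under the target pose
    def __init__(self, in_ch=9 + 18 + 18, out_ch=9):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, 64, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(64, out_ch, 3, padding=1))

    def forward(self, src_parsing, src_pose, tgt_pose):
        x = torch.cat([src_parsing, src_pose, tgt_pose], dim=1)
        return self.net(x).softmax(dim=1)  # per-pixel label distribution

class AppearanceGenerator(nn.Module):
    # H_A: synthesize textures guided by the reference image and predicted parsing
    def __init__(self, in_ch=3 + 9 + 18, out_ch=3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, 64, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(64, out_ch, 3, padding=1))

    def forward(self, ref_img, tgt_parsing, tgt_pose):
        x = torch.cat([ref_img, tgt_parsing, tgt_pose], dim=1)
        return self.net(x).tanh()

H_S, H_A = SemanticGenerator(), AppearanceGenerator()
ref_img = torch.randn(1, 3, 256, 256)      # reference person image
src_parsing = torch.randn(1, 9, 256, 256)  # source parsing map (9 classes)
src_pose = torch.randn(1, 18, 256, 256)    # 18 keypoint heatmaps
tgt_pose = torch.randn(1, 18, 256, 256)
tgt_parsing = H_S(src_parsing, src_pose, tgt_pose)  # stage 1: parsing under target pose
out = H_A(ref_img, tgt_parsing, tgt_pose)           # stage 2: 1 x 3 x 256 x 256 image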

(Figure: overview of the two-stage framework.)

Experiments & Ablation Study

(Figures: qualitative comparisons and ablation results on DeepFashion and Market-1501.)

Code

PyTorch


Generator

class Global_Generator(nn.Module):
    def __init__(self, input_nc, pose_dim, image_size, nfilters_enc, nfilters_dec, warp_skip, use_input_pose=True):
        super(Global_Generator, self).__init__()
        self.input_nc = input_nc
        # number of skip connections
        self.num_skips = 1 if warp_skip == 'None' else 2
        self.warp_skip = warp_skip
        self.pose_dim = pose_dim
        self.nfilters_dec = nfilters_dec
        self.nfilters_enc = nfilters_enc
        self.image_size = image_size
        self.use_input_pose = use_input_pose
        # appearance encoder consumes image + parsing + input pose;
        # pose encoder consumes target pose + target parsing
        self.encoder_app = encoder(input_nc - self.pose_dim - 9, nfilters_enc)
        self.encoder_pose = encoder(self.pose_dim + 9, nfilters_enc)

        self.decoder = decoder(nfilters_dec, nfilters_enc, self.num_skips)
        # NOTE: pose_dim is hard-coded to 18 keypoint channels, overriding the constructor argument
        self.pose_dim = 18


    def get_imgpose(self, input, use_input_pose, pose_dim):
        # channel layout: [image (3) + parsing (9) | input pose (pose_dim) | target parsing (9) | target pose (pose_dim)]
        inp_img = input[:, :12]  # source image and its parsing map
        inp_pose = input[:, 12:12 + pose_dim] if use_input_pose else None

        tg_parsing = input[:, 12 + pose_dim:21 + pose_dim]  # target parsing
        tg_pose_index = 21 + pose_dim if use_input_pose else 6
        tg_pose = input[:, tg_pose_index:]

        return inp_img, inp_pose, tg_parsing, tg_pose

    def forward(self, input, warps, masks, fade_in_app, fade_in_pose, fade_in_alpha):

        inp_app, inp_pose, tg_parsing, tg_pose = self.get_imgpose(input, self.use_input_pose, self.pose_dim)
        inp_app = torch.cat([inp_app, inp_pose], dim=1)

        # fade in features from the higher-resolution local appearance encoder
        skips_app = self.encoder_app(inp_app, fade_in_app, fade_in_alpha)
        # len(nfilters_enc) feature maps of shape enc_c x h x w

        inp_pose = torch.cat([tg_pose, tg_parsing], dim=1)

        # fade in features from the higher-resolution local pose encoder
        skips_pose = self.encoder_pose(inp_pose, fade_in_pose, fade_in_alpha)
        # len(nfilters_enc) feature maps of shape enc_c x h x w

        # warp appearance skips to the target pose and concatenate with pose skips
        skips = self.concatenate_skips(skips_app, skips_pose, warps, masks)
        out, feat = self.decoder(skips)

        # return out and skips for local generator
        return out, feat, skips

    def concatenate_skips(self, skips_app, skips_pose, warps, masks):
        skips = []
        for i, (sk_app, sk_pose) in enumerate(zip(skips_app, skips_pose)):
            if i < 4:
                # warp appearance features into the target pose at the first four skip levels
                out = AffineTransformLayer(10 if self.warp_skip == 'mask' else 1, self.image_size, self.warp_skip)(sk_app, warps, masks)
                out = torch.cat([out, sk_pose], dim=1)
            else:
                out = torch.cat([sk_app, sk_pose], dim=1)
            skips.append(out)
        return skips


class Local_Generator(nn.Module):
    def __init__(self, input_nc, pose_dim, image_size, nfilters_enc, nfilters_dec, warp_skip, use_input_pose=True):
        super(Local_Generator, self).__init__()
        self.input_nc = input_nc
        # number of skip connections
        self.num_skips = 1 if warp_skip == 'None' else 2
        self.warp_skip = warp_skip
        self.pose_dim = pose_dim
        self.nfilters_dec = nfilters_dec
        self.nfilters_enc = nfilters_enc
        self.image_size = image_size
        self.use_input_pose = use_input_pose
        # NOTE: pose_dim is hard-coded to 18 keypoint channels, overriding the constructor argument
        self.pose_dim = 18

        ##### global generator (operates at 128 x 128) #####
        self.model_global = Global_Generator(self.input_nc, self.pose_dim, (128, 128), self.nfilters_enc, self.nfilters_dec, self.warp_skip, self.use_input_pose)
        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)


        # local enhancement layers operating at the full resolution
        self.model_local_encoder_app = local_encoder(input_nc - self.pose_dim - 9, nfilters_enc[:2])
        self.model_local_encoder_pose = local_encoder(self.pose_dim + 9, nfilters_enc[:2])
        self.model_local_decoder = local_decoder(nfilters_dec[-2:], nfilters_enc, self.num_skips)


    def get_imgpose(self, input, use_input_pose, pose_dim):
        # channel layout: [image (3) + parsing (9) | input pose (pose_dim) | target parsing (9) | target pose (pose_dim)]
        inp_img = input[:, :12]  # source image and its parsing map
        inp_pose = input[:, 12:12 + pose_dim] if use_input_pose else None

        tg_parsing = input[:, 12 + pose_dim:21 + pose_dim]  # target parsing
        tg_pose_index = 21 + pose_dim if use_input_pose else 6
        tg_pose = input[:, tg_pose_index:]

        return inp_img, inp_pose, tg_parsing, tg_pose

    def forward(self, input, down_input, warps, masks, warps_128, masks_128, fade_in_alpha):

        inp_app, inp_pose, tg_parsing, tg_pose = self.get_imgpose(input, self.use_input_pose, self.pose_dim)

        inp_app = torch.cat([inp_app, inp_pose], dim=1)
        local_skips_app = self.model_local_encoder_app(inp_app)
        #skips_app:[32 x 256 x 256, 64 x 128 x 128]


        inp_pose = torch.cat([tg_pose, tg_parsing], dim=1)
        local_skips_pose = self.model_local_encoder_pose(inp_pose)
        #skips_pose: [32 x 256 x 256, 64 x 128 x 128]

        # warp appearance skips to the target pose and concatenate with pose skips
        local_skips = self.concatenate_skips(local_skips_app, local_skips_pose, warps, masks)
        # local_skips: [(32 + 32) x 256 x 256, (64 + 64) x 128 x 128]

        # downsample input to feed global_generator
        global_output, global_feat, global_skips = self.model_global(down_input, warps_128, masks_128, local_skips_app[1], local_skips_pose[1], fade_in_alpha) 
        # 3 x 256 x 256, 128 x 128 x 128, [(64 + 64) x 128 x 128, ...]

        # concatenate the global decoder features with the global skip features
        local_skips[1] = torch.cat([global_feat, global_skips[0]], dim=1)
        #local_skips: [(32 + 32) x 256 x 256, (128 + 64 + 64) x 128 x 128]

        out = self.model_local_decoder(local_skips)

        # progressive fade-in of the local output over the global output
        out = fade_in_alpha * out + (1 - fade_in_alpha) * global_output
        return out

    def concatenate_skips(self, skips_app, skips_pose, warps, masks):
        skips = []
        for i, (sk_app, sk_pose) in enumerate(zip(skips_app, skips_pose)):
            out = AffineTransformLayer(10 if self.warp_skip == 'mask' else 1, self.image_size, self.warp_skip)(sk_app, warps, masks)
            out = torch.cat([out, sk_pose], dim=1)

            skips.append(out)
        return skips
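
The generators consume a single concatenated tensor that get_imgpose slices apart again. A quick shape check of that layout (pose_dim = 18; the 256 x 256 resolution is illustrative):

import torch

pose_dim = 18
src_img = torch.randn(1, 3, 256, 256)        # source RGB image
src_parsing = torch.randn(1, 9, 256, 256)    # source parsing map (9 classes)
src_pose = torch.randn(1, pose_dim, 256, 256)
tgt_parsing = torch.randn(1, 9, 256, 256)    # predicted target parsing
tgt_pose = torch.randn(1, pose_dim, 256, 256)

# layout expected by get_imgpose: [image + parsing | input pose | target parsing | target pose]
net_input = torch.cat([src_img, src_parsing, src_pose, tgt_parsing, tgt_pose], dim=1)
assert net_input.shape[1] == 12 + pose_dim + 9 + pose_dim  # input_nc = 57 channels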

Backward Process

def optimize_parameters(self):
    # forward pass
    self.forward()

    # freeze the pose detector and the downsampling op
    self.set_requires_grad([self.netpose_det], False)
    self.set_requires_grad([self.downsample], False)

    # G: freeze the discriminators while updating the generator
    self.set_requires_grad([self.netD, self.netD_face], False)
    self.optimizer_G.zero_grad()
    self.backward_G()
    # clip after backward_G so the freshly computed gradients are bounded
    torch.nn.utils.clip_grad_norm_(self.netG.parameters(), 100)
    self.optimizer_G.step()

    # D: unfreeze and update the discriminators
    self.set_requires_grad([self.netD, self.netD_face], True)
    self.optimizer_D.zero_grad()
    self.backward_D()
    self.optimizer_D.step()

def backward_pose_det(self):
    # pose consistency: keypoints detected on generated images should match the target pose
    self.loss_pose_det_A = self.criterionPose(self.fake_A_pose, self.pose_A) * 700
    self.loss_pose_det_B = self.criterionPose(self.fake_B_pose, self.pose_B) * 700
    loss_pose_det = self.loss_pose_det_B + self.loss_pose_det_A
    return loss_pose_det
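
criterionPose is not defined in this excerpt; a common choice, assumed here, is a mean-squared error between keypoint heatmaps detected on the generated image and the target heatmaps:

import torch
import torch.nn as nn

# hypothetical pose-consistency criterion: MSE over keypoint heatmaps
criterionPose = nn.MSELoss()
fake_pose = torch.rand(1, 18, 64, 64)  # heatmaps detected on a generated image
real_pose = torch.rand(1, 18, 64, 64)  # target heatmaps
loss = criterionPose(fake_pose, real_pose) * 700  # weight as in backward_pose_det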


def backward_D_basic(self, netD, real_img, fake_img, real_parsing, fake_parsing):
    # Real
    real = torch.cat((real_img, real_parsing), 1)
    pred_real = netD(real)
    loss_D_real = self.criterionGAN(pred_real, True)

    # Fake
    fake = torch.cat((fake_img, fake_parsing), 1)
    pred_fake = netD(fake.detach())
    loss_D_fake = self.criterionGAN(pred_fake, False)

    # Combined loss
    loss_D = (loss_D_real + loss_D_fake) * 0.5
    loss_D.backward()
    return loss_D

def backward_D_face(self, netD, real_img, fake_img):
    # Real
    pred_real = netD(real_img)
    loss_D_real = self.criterionGAN(pred_real, True)

    # Fake
    pred_fake = netD(fake_img.detach())
    loss_D_fake = self.criterionGAN(pred_fake, False)

    # Combined loss
    loss_D = (loss_D_real + loss_D_fake) * 0.5
    loss_D.backward()
    return loss_D
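
criterionGAN is also not shown; CycleGAN-style codebases typically implement it as a least-squares GAN loss against a constant real/fake target. A minimal sketch under that assumption:

import torch
import torch.nn as nn

class GANLoss(nn.Module):
    # LSGAN-style objective: MSE against an all-ones (real) or all-zeros (fake) target
    def __init__(self):
        super().__init__()
        self.loss = nn.MSELoss()

    def forward(self, prediction, target_is_real):
        target = torch.ones_like(prediction) if target_is_real else torch.zeros_like(prediction)
        return self.loss(prediction, target)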

def backward_D(self):
    # general discriminator: real_A is the only real sample available, since no
    # ground-truth image exists under the target pose in the unsupervised setting
    self.loss_D_A = self.backward_D_basic(self.netD, self.real_A, self.fake_A, self.A_parsing, self.A_parsing)
    self.loss_D_B = self.backward_D_basic(self.netD, self.real_A, self.fake_B, self.A_parsing, self.B_parsing)

    # face discriminator
    self.loss_D_A_face = self.backward_D_face(self.netD_face, self.real_A_face, self.fake_A_face)
    self.loss_D_B_face = self.backward_D_face(self.netD_face, self.real_A_face, self.fake_B_face)


def backward_G(self):

    self.D_fake_B = torch.cat((self.fake_B, self.B_parsing), 1)
    self.D_fake_A = torch.cat((self.fake_A, self.A_parsing), 1)

    # adversarial losses for the generator against the general and face discriminators
    self.loss_G_A = self.criterionGAN(self.netD(self.D_fake_B), True) + self.criterionGAN(self.netD_face(self.fake_B_face), True)
    self.loss_G_B = self.criterionGAN(self.netD(self.D_fake_A), True) + self.criterionGAN(self.netD_face(self.fake_A_face), True)


    # pose consistency loss
    self.loss_pose_det = self.backward_pose_det()

    # content loss (L2) between features of the reconstruction and the real image
    self.loss_content_loss = 0.03 * self.criterionIdt(self.fake_A_feat, self.real_A_feat)

    # semantic-aware loss
    self.loss_patch_style_real_A_fake_B = self.criterionSty(self.visibility * patch_gram_matrix(self.fake_B_feat, self.downsample_AtoB_masks), 
                                                            self.visibility * patch_gram_matrix(self.real_A_feat, self.downsample_BtoA_masks))
    self.loss_patch_style_fake_A_fake_B = (self.visibility * patch_gram_matrix(self.fake_A_feat, self.downsample_BtoA_masks) - 
                                           self.visibility * patch_gram_matrix(self.fake_B_feat, self.downsample_AtoB_masks)) ** 2
    self.loss_patch_style_fake_A_fake_B = self.loss_patch_style_fake_A_fake_B.mean()


    self.loss_patch_style = self.loss_patch_style_fake_A_fake_B + self.loss_patch_style_real_A_fake_B

    self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_pose_det + self.loss_content_loss + self.loss_patch_style

    self.loss_G.backward()
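
patch_gram_matrix is likewise not defined here. Conceptually, the semantic-aware loss computes one Gram matrix per semantic region, with features restricted by the downsampled parsing masks; below is a minimal sketch under that reading (the normalization and masking details are assumptions, not the authors' exact implementation):

import torch

def patch_gram_matrix(feat, masks):
    # feat:  B x C x H x W feature maps
    # masks: B x K x H x W soft masks, one per semantic region
    # returns B x K x C x C: one normalized Gram matrix per region
    B, C, H, W = feat.shape
    grams = []
    for k in range(masks.shape[1]):
        masked = (feat * masks[:, k:k + 1]).view(B, C, -1)  # keep only region k
        gram = torch.bmm(masked, masked.transpose(1, 2))    # B x C x C
        area = masks[:, k:k + 1].sum(dim=(1, 2, 3)).clamp(min=1).view(B, 1, 1)
        grams.append(gram / (C * area))                     # normalize by region size
    return torch.stack(grams, dim=1)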