COCO-GAN: Generation by Parts via Conditional Coordinating - Chieh Hubert Lin - ICCV 2019



  • Title: COCO-GAN: Generation by Parts via Conditional Coordinating
  • Task: Image Generation
  • Author: Chieh Hubert Lin, Chia-Che Chang, Yu-Sheng Chen, Da-Cheng Juan, Wei Wei, Hwann-Tzong Chen
  • Date: Apr. 2019
  • Arxiv: 1904.00284
  • Published: ICCV 2019


  • Discriminator with auxiliary tasks for content consistency (predicting the latent vector z from the last feature map of D) and spatial consistency (predicting the macro patch coordinates)


Humans can only interact with part of the surrounding environment due to biological restrictions. Therefore, we learn to reason about the spatial relationships across a series of observations to piece together the surrounding environment. Inspired by such behavior and the fact that machines also have computational constraints, we propose COnditional COordinate GAN (COCO-GAN), in which the generator generates images by parts based on their spatial coordinates as the condition. On the other hand, the discriminator learns to justify realism across multiple assembled patches by global coherence, local appearance, and edge-crossing continuity. Although full images are never generated during training, we show that COCO-GAN can produce state-of-the-art-quality full images during inference. We further demonstrate a variety of novel applications enabled by teaching the network to be aware of coordinates. First, we perform extrapolation to the learned coordinate manifold and generate off-the-boundary patches. Combining these with the originally generated full image, COCO-GAN can produce images that are larger than training samples, which we call “beyond-boundary generation”. We then showcase panorama generation within a cylindrical coordinate system that inherently preserves horizontally cyclic topology. On the computation side, COCO-GAN has a built-in divide-and-conquer paradigm that reduces memory requirements during training and inference, provides high parallelism, and can generate parts of images on demand.

Motivation & Design

The authors propose COnditional COordinate GAN (COCO-GAN) of which the generator generates images by parts based on their spatial coordinates as the condition. On the other hand, the discriminator learns to justify realism across multiple assembled patches by global coherence, local appearance, and edge-crossing continuity.


For COCO-GAN training, the latent vectors are duplicated multiple times, concatenated with micro coordinates, and fed to the generator to generate micro patches. Multiple micro patches are then concatenated to form a larger macro patch. The discriminator learns to discriminate between real and fake macro patches and, as an auxiliary task, to predict the coordinate of the macro patch. Note that none of the models requires full images during training.
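The input preparation described above can be sketched in plain Python. This is only an illustration under stated assumptions: the helper names (`micro_coords_for_macro`, `dup_z_for_macro`, `generator_inputs`), the row-major patch order, and the normalization of grid indices to [-1, 1] are hypothetical and not taken from the original code.

```python
def micro_coords_for_macro(top, left, ratio, grid):
    """Coordinates of the ratio*ratio micro patches inside one macro patch.

    (top, left) is the grid index of the macro patch's first micro patch and
    `grid` is the number of micro patches per image side (must be > 1).
    Assumption: indices are mapped linearly to [-1, 1].
    """
    def norm(i):
        return 2.0 * i / (grid - 1) - 1.0

    return [(norm(top + dy), norm(left + dx))
            for dy in range(ratio) for dx in range(ratio)]


def dup_z_for_macro(z_batch, ratio):
    """Repeat each latent vector once per micro patch of its macro patch."""
    return [list(z) for z in z_batch for _ in range(ratio * ratio)]


def generator_inputs(z_batch, macro_positions, ratio, grid):
    """Concatenate every duplicated latent vector with its micro coordinate."""
    z_dup = dup_z_for_macro(z_batch, ratio)
    coords = [c for (top, left) in macro_positions
              for c in micro_coords_for_macro(top, left, ratio, grid)]
    return [z + list(c) for z, c in zip(z_dup, coords)]


# One latent vector, one macro patch made of 2x2 micro patches on a 4x4 grid.
rows = generator_inputs([[0.1, 0.2]], [(0, 0)], ratio=2, grid=4)
```

Each row of `rows` is one generator input: the shared latent vector followed by the coordinate of a single micro patch, so all micro patches of a macro patch are generated from the same z.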


During the testing phase, the micro patches produced by the generator are directly combined into a full image as the final output. Again, none of the models requires full images. Furthermore, the generated images are high-quality without any post-processing other than a simple concatenation.
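The "simple concatenation" at test time can be illustrated with a small stand-alone sketch. The function `concat_patches` is hypothetical, and it assumes single-channel patches stored as nested lists in row-major grid order; the actual `patch_handler` used in the code is not shown in these notes.

```python
def concat_patches(patches, grid_rows, grid_cols):
    """Stitch equally sized 2D patches (row-major order) into one 2D image."""
    assert len(patches) == grid_rows * grid_cols
    full = []
    for gr in range(grid_rows):
        # Patches that sit side by side in this row of the grid.
        row_patches = patches[gr * grid_cols:(gr + 1) * grid_cols]
        patch_h = len(row_patches[0])
        for y in range(patch_h):
            line = []
            for p in row_patches:
                line.extend(p[y])  # append the y-th pixel row of each patch
            full.append(line)
    return full


# Four constant 2x2 patches arranged on a 2x2 grid.
patches = [[[v, v], [v, v]] for v in range(4)]
full_image = concat_patches(patches, grid_rows=2, grid_cols=2)
```

The same routine with a smaller grid would build a macro patch from micro patches during training.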


Model Graph

g_builder = GeneratorBuilder(config)
d_builder = DiscriminatorBuilder(config)
cp_builder = SpatialPredictorBuilder(config)
zp_builder = ContentPredictorBuilder(config)

def build_graph(self):

    # Input nodes
    # Note: the input node name was wrong in the checkpoint 
    self.micro_coord_fake = tf.placeholder(tf.float32, [None, self.spatial_dim], name='micro_coord_fake')
    self.macro_coord_fake = tf.placeholder(tf.float32, [None, self.spatial_dim], name='macro_coord_fake')
    self.micro_coord_real = tf.placeholder(tf.float32, [None, self.spatial_dim], name='micro_coord_real')
    self.macro_coord_real = tf.placeholder(tf.float32, [None, self.spatial_dim], name='macro_coord_real')

    # Reversing angle for cylindrical coordinate is complicated, directly pass values here
    self.y_angle_ratio = tf.placeholder(tf.float32, [None, 1], name='y_angle_ratio') 
    self.z = tf.placeholder(tf.float32, [None, self.z_dim], name='z')

    # Real part
    self.real_macro = self.patch_handler.concat_micro_patches_gpu(
        self.real_micro, ratio_over_micro=self.ratio_macro_to_micro)
    (self.disc_real, disc_real_h) = self.d_builder(self.real_macro, self.macro_coord_real, is_training=True)
    self.c_real_pred = self.cp_builder(disc_real_h, is_training=True)
    self.z_real_pred = self.zp_builder(disc_real_h, is_training=True)

    # Fake part
    z_dup_macro = self._dup_z_for_macro(self.z)
    self.gen_micro = self.g_builder(z_dup_macro, self.micro_coord_fake, is_training=True)
    self.gen_macro = self.patch_handler.concat_micro_patches_gpu(
        self.gen_micro, ratio_over_micro=self.ratio_macro_to_micro)
    (self.disc_fake, disc_fake_h) = self.d_builder(self.gen_macro, self.macro_coord_fake, is_training=True)
    self.c_fake_pred = self.cp_builder(disc_fake_h, is_training=True)
    self.z_fake_pred = self.zp_builder(disc_fake_h, is_training=True)

    # Patch-Guided Image Generation graph
    if self._train_content_prediction_model():
        (_, disc_real_h_rec) = self.d_builder(self.real_macro, None, is_training=False)
        estim_z = self.zp_builder(disc_real_h_rec, is_training=False)
        # I didn't especially handle this.
        # if self.config["log_params"]["merge_micro_patches_in_cpu"]:
        (_, self.rec_full) = self.generate_full_image_gpu(self.z)

    print(" [Build] Composing Loss Functions ")

    print(" [Build] Creating Optimizers ")

def _compose_losses(self):

    # Content consistency loss
    self.code_loss = tf.reduce_mean(self.code_loss_w * tf.losses.absolute_difference(self.z, self.z_fake_pred))

    # Spatial consistency loss (reduce later)
    self.coord_mse_real = self.coord_loss_w * tf.losses.mean_squared_error(self.macro_coord_real, self.c_real_pred, reduction=NO_REDUCTION)
    self.coord_mse_fake = self.coord_loss_w * tf.losses.mean_squared_error(self.macro_coord_fake, self.c_fake_pred, reduction=NO_REDUCTION)

    self.coord_mse_real = tf.reduce_mean(self.coord_mse_real)
    self.coord_mse_fake = tf.reduce_mean(self.coord_mse_fake)
    self.coord_loss = self.coord_mse_real + self.coord_mse_fake

    # WGAN loss
    self.adv_real = - tf.reduce_mean(self.disc_real)
    self.adv_fake = tf.reduce_mean(self.disc_fake)
    self.d_adv_loss = self.adv_real + self.adv_fake
    self.g_adv_loss = - self.adv_fake

    # Gradient penalty loss of WGAN-GP
    gradient_penalty, self.gp_slopes = self._calc_gradient_penalty()
    self.gp_loss = self.config["loss_params"]["gp_lambda"] * gradient_penalty

    # Total loss
    self.d_loss = self.d_adv_loss + self.gp_loss + self.coord_loss + self.code_loss
    self.g_loss = self.g_adv_loss + self.coord_loss + self.code_loss
    self.q_loss = self.g_adv_loss + self.code_loss

    # Wasserstein distance for visualization
    self.w_dist = - self.adv_real - self.adv_fake
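
For reference, the objectives composed in `_compose_losses` can be written out term by term. The symbols are chosen here for readability and are not the paper's notation: $w_z$ and $w_c$ stand for `code_loss_w` and `coord_loss_w`, $\lambda$ for `gp_lambda`, hats mark the predictions of the auxiliary heads, and $\mathrm{GP}$ is whatever `_calc_gradient_penalty` returns (its definition is not shown here).

```latex
\begin{aligned}
L_{\text{code}}  &= w_z \, \mathbb{E}\big[\, |z - \hat{z}_{\text{fake}}| \,\big] \\
L_{\text{coord}} &= w_c \, \mathbb{E}\big[ (c_{\text{real}} - \hat{c}_{\text{real}})^2 \big]
                  + w_c \, \mathbb{E}\big[ (c_{\text{fake}} - \hat{c}_{\text{fake}})^2 \big] \\
L_D^{\text{adv}} &= -\mathbb{E}[D(x_{\text{real}})] + \mathbb{E}[D(x_{\text{fake}})] \\
L_G^{\text{adv}} &= -\mathbb{E}[D(x_{\text{fake}})] \\
L_D &= L_D^{\text{adv}} + \lambda \, \mathrm{GP} + L_{\text{coord}} + L_{\text{code}} \\
L_G &= L_G^{\text{adv}} + L_{\text{coord}} + L_{\text{code}} \\
L_Q &= L_G^{\text{adv}} + L_{\text{code}} \\
W   &= \mathbb{E}[D(x_{\text{real}})] - \mathbb{E}[D(x_{\text{fake}})]
\end{aligned}
```

Here $x_{\text{real}}$ and $x_{\text{fake}}$ are real and generated macro patches, and $W$ is the quantity logged as `w_dist`.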

snconv2d and snlinear are operators with Spectral Normalization.
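
As a rough illustration of what spectral normalization does to a weight matrix, the sketch below estimates the largest singular value by power iteration and divides the weights by it. It is a plain-Python toy under stated assumptions, not the `snconv2d`/`snlinear` implementation: the function names are hypothetical, and practical implementations usually keep a persistent vector `u` and run a single iteration per training step.

```python
import math


def _normalize(v, eps=1e-12):
    """Scale a vector to (approximately) unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / (norm + eps) for x in v]


def spectral_norm_estimate(w, n_iters=30):
    """Estimate the largest singular value of w (a list of rows)."""
    rows, cols = len(w), len(w[0])
    u = _normalize([1.0] * rows)
    for _ in range(n_iters):
        # v <- normalize(W^T u)
        v = _normalize([sum(w[i][j] * u[i] for i in range(rows))
                        for j in range(cols)])
        # u <- normalize(W v)
        u = _normalize([sum(w[i][j] * v[j] for j in range(cols))
                        for i in range(rows)])
    # sigma = u^T W v
    return sum(u[i] * w[i][j] * v[j]
               for i in range(rows) for j in range(cols))


def spectrally_normalize(w, n_iters=30):
    """Return w divided by its estimated largest singular value."""
    sigma = spectral_norm_estimate(w, n_iters)
    return [[x / sigma for x in row] for row in w]


w = [[3.0, 0.0], [0.0, 1.0]]
sigma = spectral_norm_estimate(w)
w_sn = spectrally_normalize(w)
```

After normalization the largest singular value of `w_sn` is approximately 1, which bounds how much the layer can stretch its input.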


class GeneratorBuilder(Model):
    def __init__(self, config):
        self.ngf_base = self.config["model_params"]["ngf_base"]
        self.num_extra_layers = self.config["model_params"]["g_extra_layers"]
        self.micro_patch_size = self.config["data_params"]["micro_patch_size"]
        self.c_dim = self.config["data_params"]["c_dim"]
        self.update_collection = "G_update_collection"

    def _cbn(self, x, y, is_training, scope=None):
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            ch = x.shape.as_list()[-1]
            gamma = slim.fully_connected(y, ch, activation_fn=None)
            beta = slim.fully_connected(y, ch, activation_fn=None)

            mean_rec = tf.get_variable("mean_recorder", [ch,], 
                initializer=tf.constant_initializer(np.zeros(ch)), trainable=False)
            var_rec  = tf.get_variable("var_recorder", [ch,], 
                initializer=tf.constant_initializer(np.ones(ch)), trainable=False)
            running_mean, running_var = tf.nn.moments(x, axes=[0, 1, 2])
            if is_training:
                new_mean_rec = 0.99 * mean_rec + 0.01 * running_mean
                new_var_rec  = 0.99 * var_rec  + 0.01 * running_var
                assign_mean_op = mean_rec.assign(new_mean_rec)
                assign_var_op  = var_rec.assign(new_var_rec)
                tf.add_to_collection(self.update_collection, assign_mean_op)
                tf.add_to_collection(self.update_collection, assign_var_op)
                mean = running_mean
                var  = running_var
            else:
                mean = mean_rec
                var  = var_rec

            # tiled_mean = tf.tile(tf.expand_dims(mean, 0), [tf.shape(x)[0], 1])
            # tiled_var  = tf.tile(tf.expand_dims(var , 0), [tf.shape(x)[0], 1])
            mean  = tf.reshape(mean, [1, 1, 1, ch])
            var   = tf.reshape(var , [1, 1, 1, ch])
            gamma = tf.reshape(gamma, [-1, 1, 1, ch])
            beta  = tf.reshape(beta , [-1, 1, 1, ch])
            out = (x-mean) / (var+_EPS) * gamma + beta
            return out 

    def _g_residual_block(self, x, y, n_ch, idx, is_training, resize=True):
        update_collection = self._get_update_collection(is_training)
        with tf.variable_scope("g_resblock_"+str(idx), reuse=tf.AUTO_REUSE):
            h = self._cbn(x, y, is_training, scope='g_resblock_cbn_1')
            h = tf.nn.relu(h)
            if resize:
                h = upscale(h, 2)
            h = snconv2d(h, n_ch, name='g_resblock_conv_1', update_collection=update_collection)
            h = self._cbn(h, y, is_training, scope='g_resblock_cbn_2')
            h = tf.nn.relu(h)
            h = snconv2d(h, n_ch, name='g_resblock_conv_2', update_collection=update_collection)

            if resize:
                sc = upscale(x, 2)
            else:
                sc = x
            sc = snconv2d(sc, n_ch, k_h=1, k_w=1, name='g_resblock_conv_sc', update_collection=update_collection)

            return h + sc

    def forward(self, z, coord, is_training):
        valid_sizes = {4, 8, 16, 32, 64, 128, 256}
        assert (self.micro_patch_size[0] in valid_sizes and self.micro_patch_size[1] in valid_sizes), \
            "I haven't test your micro patch size: {}".format(self.micro_patch_size)

        update_collection = self._get_update_collection(is_training)
        print(" [Build] Generator ; is_training: {}".format(is_training))
        with tf.variable_scope("G_generator", reuse=tf.AUTO_REUSE):
            init_sp = 2
            init_ngf_mult = 16
            cond = tf.concat([z, coord], axis=1)
            h = snlinear(cond, self.ngf_base*init_ngf_mult*init_sp*init_sp, 'g_z_fc', update_collection=update_collection)
            h = tf.reshape(h, [-1, init_sp, init_sp, self.ngf_base*init_ngf_mult])

            # Stacking residual blocks
            num_resize_layers = int(math.log(min(self.micro_patch_size), 2) - 1)
            num_total_layers  = num_resize_layers + self.num_extra_layers
            basic_layers = [8, 4, 2] 
            if num_total_layers>=len(basic_layers):
                num_replicate_layers = num_total_layers - len(basic_layers)
                ngf_mult_list = basic_layers + [1, ] * num_replicate_layers
            else:
                ngf_mult_list = basic_layers[:num_total_layers]
            print("\t ngf_mult_list = {}".format(ngf_mult_list))

            for idx, ngf_mult in enumerate(ngf_mult_list):
                n_ch = self.ngf_base * ngf_mult
                # Standard layers first
                if idx < num_resize_layers:
                    resize, is_extra = True, False
                # Extra layers do not resize spatial size
                else:
                    resize, is_extra = False, True
                h = self._g_residual_block(h, cond, n_ch, idx=idx, is_training=is_training, resize=resize)
                print("\t GResBlock: id={}, out_shape={}, resize={}, is_extra={}"
                    .format(idx, h.shape.as_list(), resize, is_extra))

            h = batch_norm(name="g_last_bn")(h, is_training=is_training)
            h = tf.nn.relu(h)
            h = snconv2d(h, self.c_dim, name='g_last_conv_2', update_collection=update_collection)
            return tf.nn.tanh(h)
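
The layer bookkeeping in `forward` is easier to follow with concrete numbers. The sketch below mirrors that arithmetic outside TensorFlow; `plan_generator_blocks` is a hypothetical helper, and it uses `math.log2` instead of `math.log(x, 2)` on the assumption that patch sizes are powers of two, as the `valid_sizes` check requires.

```python
import math


def plan_generator_blocks(micro_patch_size, num_extra_layers=0,
                          basic_layers=(8, 4, 2)):
    """Return one (channel multiplier, resize flag) pair per residual block.

    The generator starts from a 2x2 feature map, so reaching a side of
    length s takes log2(s) - 1 upsampling blocks.
    """
    num_resize = int(math.log2(min(micro_patch_size))) - 1
    num_total = num_resize + num_extra_layers
    basic = list(basic_layers)
    if num_total >= len(basic):
        # Pad with multiplier 1 once the basic schedule is exhausted.
        mults = basic + [1] * (num_total - len(basic))
    else:
        mults = basic[:num_total]
    # Blocks beyond num_resize are "extra" blocks that keep the spatial size.
    return [(m, idx < num_resize) for idx, m in enumerate(mults)]


plan_16 = plan_generator_blocks((16, 16))
plan_32_extra = plan_generator_blocks((32, 32), num_extra_layers=1)
```

For a 16x16 micro patch the plan has three upsampling blocks (2 -> 4 -> 8 -> 16) with multipliers 8, 4, 2; one extra layer on a 32x32 patch appends a non-resizing block with multiplier 1.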


class DiscriminatorBuilder(Model):
    def __init__(self, config):
        self.ndf_base = self.config["model_params"]["ndf_base"]
        self.num_extra_layers = self.config["model_params"]["d_extra_layers"]
        self.macro_patch_size = self.config["data_params"]["macro_patch_size"]

        self.update_collection = "D_update_collection"

    def _d_residual_block(self, x, out_ch, idx, is_training, resize=True, is_head=False):
        update_collection = self._get_update_collection(is_training)
        with tf.variable_scope("d_resblock_"+str(idx), reuse=tf.AUTO_REUSE):
            h = x
            if not is_head:
                h = tf.nn.relu(h)
            h = snconv2d(h, out_ch, name='d_resblock_conv_1', update_collection=update_collection)
            h = tf.nn.relu(h)
            h = snconv2d(h, out_ch, name='d_resblock_conv_2', update_collection=update_collection)
            if resize:
                h = slim.avg_pool2d(h, [2, 2])

            # Short cut
            s = x
            if resize:
                s = slim.avg_pool2d(s, [2, 2])
            s = snconv2d(s, out_ch, k_h=1, k_w=1, name='d_resblock_conv_sc', update_collection=update_collection)
            return h + s

    def forward(self, x, y=None, is_training=True):
        valid_sizes = {8, 16, 32, 64, 128, 256, 512}
        assert (self.macro_patch_size[0] in valid_sizes and self.macro_patch_size[1] in valid_sizes), \
            "I haven't test your macro patch size: {}".format(self.macro_patch_size)

        update_collection = self._get_update_collection(is_training)
        print(" [Build] Discriminator ; is_training: {}".format(is_training))
        with tf.variable_scope("D_discriminator", reuse=tf.AUTO_REUSE):

            num_resize_layers = int(math.log(min(self.macro_patch_size), 2) - 1)
            num_total_layers  = num_resize_layers + self.num_extra_layers
            basic_layers = [2, 4, 8, 8]
            if num_total_layers>=len(basic_layers):
                num_replicate_layers = num_total_layers - len(basic_layers)
                ndf_mult_list = [1, ] * num_replicate_layers + basic_layers
            else:
                ndf_mult_list = basic_layers[-num_total_layers:]
            print("\t ndf_mult_list = {}".format(ndf_mult_list))

            # Stack extra layers without resize first
            h = x
            for idx, ndf_mult in enumerate(ndf_mult_list):
                n_ch = self.ndf_base * ndf_mult
                # Head is fixed and goes first
                if idx==0:
                    is_head, resize, is_extra = True, True, False
                # Extra layers before standard layers
                elif idx<=self.num_extra_layers:
                    is_head, resize, is_extra = False, False, True
                # Last standard layer has no resize
                elif idx==len(ndf_mult_list)-1:
                    is_head, resize, is_extra = False, False, False
                # Standard layers
                else:
                    is_head, resize, is_extra = False, True, False
                h = self._d_residual_block(h, n_ch, idx=idx, is_training=is_training, resize=resize, is_head=is_head)
                print("\t DResBlock: id={}, out_shape={}, resize={}, is_extra={}"
                    .format(idx, h.shape.as_list(), resize, is_extra))

            h = tf.nn.relu(h)
            h = tf.reduce_sum(h, axis=[1,2]) # Global pooling
            last_feature_map = h
            adv_out = snlinear(h, 1, 'main_steam_out', update_collection=update_collection)

            # Projection Discriminator
            if y is not None:
                h_num_ch = self.ndf_base*ndf_mult_list[-1]
                y_emb = snlinear(y, h_num_ch, 'y_emb', update_collection=update_collection)
                proj_out = tf.reduce_sum(y_emb*h, axis=1, keepdims=True)
            else:
                proj_out = 0

            out = adv_out + proj_out
            return out, last_feature_map
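
The final step of the discriminator adds an unconditional score to a projection term, the inner product between the pooled features and an embedding of the coordinate condition. A minimal single-sample sketch of that arithmetic follows; `projection_logit` and its parameters are hypothetical, and the embedding is written without a bias for brevity, which may differ from `snlinear`.

```python
def projection_logit(h, w_out, b_out, y=None, w_emb=None):
    """Score one sample from its pooled feature vector h.

    w_out and b_out play the role of the final linear layer; w_emb (one row
    per coordinate dimension) maps the condition y into feature space.
    """
    adv = sum(hi * wi for hi, wi in zip(h, w_out)) + b_out
    if y is None:
        # Unconditional call: no projection term is added.
        return adv
    # y_emb[c] = sum_k y[k] * w_emb[k][c]
    y_emb = [sum(y[k] * w_emb[k][c] for k in range(len(y)))
             for c in range(len(h))]
    proj = sum(e * hi for e, hi in zip(y_emb, h))
    return adv + proj


h = [1.0, 2.0]
uncond = projection_logit(h, w_out=[0.5, -1.0], b_out=0.0)
cond = projection_logit(h, w_out=[0.5, -1.0], b_out=0.0,
                        y=[1.0, -1.0], w_emb=[[2.0, 1.0], [0.0, 0.0]])
```

Passing `y=None` corresponds to the call `self.d_builder(self.real_macro, None, is_training=False)` in the graph, where only the unconditional score is used.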

Spatial Prediction

class SpatialPredictorBuilder(Model):
    def __init__(self, config):
        self.aux_dim = config["model_params"]["aux_dim"]
        self.spatial_dim = config["model_params"]["spatial_dim"]
        self.update_collection = "GD_update_collection"

    def forward(self, h, is_training):
        print(" [Build] Spatial Predictor ; is_training: {}".format(is_training))
        update_collection = self._get_update_collection(is_training)
        with tf.variable_scope("GD_spatial_prediction_head", reuse=tf.AUTO_REUSE):
            h = snlinear(h, self.aux_dim, 'fc1', update_collection=update_collection)
            h = batch_norm(name='bn1')(h, is_training=is_training)
            h = lrelu(h)
            h = snlinear(h, self.spatial_dim, 'fc2', update_collection=update_collection)
            return tf.nn.tanh(h)

Content Prediction

class ContentPredictorBuilder(Model):
    def __init__(self, config):
        self.z_dim = config["model_params"]["z_dim"]
        self.aux_dim = config["model_params"]["aux_dim"]
        self.update_collection = "Q_update_collection"

    def forward(self, h, is_training):
        print(" [Build] Content Predictor ; is_training: {}".format(is_training))
        update_collection = self._get_update_collection(is_training)
        with tf.variable_scope("Q_content_prediction_head", reuse=tf.AUTO_REUSE):
            h = snlinear(h, self.aux_dim, 'fc1', update_collection=update_collection)
            h = batch_norm(name='bn1')(h, is_training=is_training)
            h = lrelu(h)
            h = snlinear(h, self.z_dim, 'fc2', update_collection=update_collection)
            return tf.nn.tanh(h)