Source code for imaginaire.generators.spade

# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import functools
import math
import types

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Upsample as NearestUpsample

from imaginaire.layers import Conv2dBlock, LinearBlock, Res2dBlock
from imaginaire.utils.data import (get_crop_h_w,
                                   get_paired_input_image_channel_number,
                                   get_paired_input_label_channel_number)
from imaginaire.utils.distributed import master_only_print as print


class Generator(nn.Module):
    r"""SPADE generator constructor.

    Builds a :class:`SPADEGenerator` and, when ``gen_cfg.style_dims`` is set,
    a :class:`StyleEncoder` that produces the style code fed to it.

    Args:
        gen_cfg (obj): Generator definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
    """

    def __init__(self, gen_cfg, data_cfg):
        super(Generator, self).__init__()
        print('SPADE generator initialization.')
        # We assume the first datum is the ground truth image.
        image_channels = getattr(gen_cfg, 'image_channels', None)
        if image_channels is None:
            image_channels = get_paired_input_image_channel_number(data_cfg)
        num_labels = getattr(gen_cfg, 'num_labels', None)
        if num_labels is None:
            # Calculate number of channels in the input label when not
            # specified.
            num_labels = get_paired_input_label_channel_number(data_cfg)
        crop_h, crop_w = get_crop_h_w(data_cfg.train.augmentations)

        # Build the generator.
        out_image_small_side_size = min(crop_h, crop_w)
        num_filters = getattr(gen_cfg, 'num_filters', 128)
        kernel_size = getattr(gen_cfg, 'kernel_size', 3)
        weight_norm_type = getattr(gen_cfg, 'weight_norm_type', 'spectral')

        # cond_dims accumulates the size of every global conditioning code.
        cond_dims = 0
        # Check whether we use the style code.
        style_dims = getattr(gen_cfg, 'style_dims', None)
        self.style_dims = style_dims
        self.use_style = style_dims is not None
        if self.use_style:
            print('\tStyle code dimensions: %d' % style_dims)
            cond_dims += style_dims
        # Check whether we use the attribute code.
        self.use_attribute = hasattr(gen_cfg, 'attribute_dims')
        if self.use_attribute:
            self.attribute_dims = gen_cfg.attribute_dims
            cond_dims += gen_cfg.attribute_dims
        # Global conditioning is enabled as soon as either code is present.
        self.use_style_encoder = self.use_style or self.use_attribute

        print('\tBase filter number: %d' % num_filters)
        print('\tConvolution kernel size: %d' % kernel_size)
        print('\tWeight norm type: %s' % weight_norm_type)
        skip_activation_norm = getattr(gen_cfg, 'skip_activation_norm', True)

        # Fill in every spatially-adaptive norm setting the config left out.
        # Note that the namespace taken from ``gen_cfg`` is updated in place.
        activation_norm_params = getattr(gen_cfg, 'activation_norm_params',
                                         None)
        if activation_norm_params is None:
            activation_norm_params = types.SimpleNamespace()
        self._set_default(activation_norm_params, 'num_filters', 128)
        self._set_default(activation_norm_params, 'kernel_size', 3)
        self._set_default(activation_norm_params, 'activation_norm_type',
                          'sync_batch')
        self._set_default(activation_norm_params, 'separate_projection',
                          False)
        if not hasattr(activation_norm_params, 'activation_norm_params'):
            activation_norm_params.activation_norm_params = \
                types.SimpleNamespace()
            activation_norm_params.activation_norm_params.affine = True
        # The conditioning input of the norm is the label map, so this one is
        # always overwritten rather than defaulted.
        setattr(activation_norm_params, 'cond_dims', num_labels)
        self._set_default(activation_norm_params, 'weight_norm_type',
                          weight_norm_type)

        global_adaptive_norm_type = getattr(gen_cfg,
                                            'global_adaptive_norm_type',
                                            'sync_batch')
        use_posenc_in_input_layer = getattr(gen_cfg,
                                            'use_posenc_in_input_layer',
                                            True)
        output_multiplier = getattr(gen_cfg, 'output_multiplier', 1.0)
        print(activation_norm_params)
        self.spade_generator = SPADEGenerator(num_labels,
                                              out_image_small_side_size,
                                              image_channels,
                                              num_filters,
                                              kernel_size,
                                              cond_dims,
                                              activation_norm_params,
                                              weight_norm_type,
                                              global_adaptive_norm_type,
                                              skip_activation_norm,
                                              use_posenc_in_input_layer,
                                              self.use_style_encoder,
                                              output_multiplier)
        if self.use_style:
            # Build the encoder.
            style_enc_cfg = getattr(gen_cfg, 'style_enc', None)
            if style_enc_cfg is None:
                style_enc_cfg = types.SimpleNamespace()
            self._set_default(style_enc_cfg, 'num_filters', 128)
            self._set_default(style_enc_cfg, 'kernel_size', 3)
            self._set_default(style_enc_cfg, 'weight_norm_type',
                              weight_norm_type)
            setattr(style_enc_cfg, 'input_image_channels', image_channels)
            setattr(style_enc_cfg, 'style_dims', style_dims)
            self.style_encoder = StyleEncoder(style_enc_cfg)

        # Cached random style code used by ``inference`` when
        # ``use_fixed_random_style`` is requested.
        self.z = None
        print('Done with the SPADE generator initialization.')

    @staticmethod
    def _set_default(cfg, name, value):
        r"""Set ``cfg.<name>`` to ``value`` unless it is already defined.

        Args:
            cfg (obj): Namespace-like config object, modified in place.
            name (str): Attribute name.
            value: Value used when the attribute is missing.
        """
        if not hasattr(cfg, name):
            setattr(cfg, name, value)

    def _sample_random_style(self, label):
        r"""Draw a Gaussian style code matching the batch of ``label``.

        The code is sampled in float32 on the CPU (so the CPU random stream
        is consumed exactly as before) and then moved to the device of
        ``label`` instead of unconditionally to CUDA. It is cast to half
        precision when ``label`` is float16.

        Args:
            label (tensor): Label tensor; only its batch size, device and
                dtype are used.
        Returns:
            (N x style_dims tensor): Random style code.
        """
        batch_size = label.size(0)
        z = torch.randn(batch_size, self.style_dims, dtype=torch.float32)
        z = z.to(label.device)
        if label.dtype == torch.float16:
            z = z.half()
        return z

    def forward(self, data, random_style=False):
        r"""SPADE Generator forward.

        ``data`` is modified in place: the style code that is used is stored
        under ``data['z']``.

        Args:
            data (dict):
              - images (N x C1 x H x W tensor) : Ground truth images
              - label (N x C2 x H x W tensor) : Semantic representations
              - attributes (tensor): Attribute code, only read when the
                generator was configured with ``attribute_dims``.
            random_style (bool): Whether to sample a random style vector
                instead of encoding ``data['images']``.
        Returns:
            (dict):
              - fake_images (N x 3 x H x W tensor): fake images
              - mu (N x C1 tensor): mean vectors, ``None`` for random style
              - logvar (N x C1 tensor): log-variance vectors, ``None`` for
                random style
        """
        if self.use_style_encoder:
            if random_style:
                z = self._sample_random_style(data['label'])
                mu = None
                logvar = None
            else:
                mu, logvar, z = self.style_encoder(data['images'])
            if self.use_attribute:
                data['z'] = torch.cat((z, data['attributes'].squeeze(1)),
                                      dim=1)
            else:
                data['z'] = z
        output = self.spade_generator(data)
        if self.use_style_encoder:
            output['mu'] = mu
            output['logvar'] = logvar
        return output

    def inference(self,
                  data,
                  random_style=False,
                  use_fixed_random_style=False,
                  keep_original_size=False):
        r"""Compute result images for a batch of input data.

        Switches the module to evaluation mode as a side effect and stores
        the style code that is used under ``data['z']``.

        Args:
            data (dict):
              - images (N x C1 x H x W tensor) : Ground truth images
              - label (N x C2 x H x W tensor) : Semantic representations
              - key (dict): Maps data type names to file name lists.
              - original_h_w: Original sizes, read only when
                ``keep_original_size`` is set.
            random_style (bool): Whether to sample a random style vector.
            use_fixed_random_style (bool): Sample random style once and
                use it for all the remaining inference.
            keep_original_size (bool): Keep original size of the input.
        Returns:
            (tuple):
              - output_images (N x 3 x H x W tensor): fake images
              - file_names: File names of the edge maps if present, else of
                the segmentation maps, else ``None``.
        """
        self.eval()
        self.spade_generator.eval()

        if self.use_style_encoder:
            if random_style:
                if self.z is None or not use_fixed_random_style:
                    self.z = self._sample_random_style(data['label'])
                z = self.z
            else:
                _, _, z = self.style_encoder(data['images'])
            if self.use_attribute:
                # Mirror ``forward``: the generator was built to consume the
                # style code concatenated with the attribute code.
                z = torch.cat((z, data['attributes'].squeeze(1)), dim=1)
            data['z'] = z
        output = self.spade_generator(data)
        output_images = output['fake_images']

        if keep_original_size:
            height = data['original_h_w'][0][0]
            width = data['original_h_w'][0][1]
            output_images = torch.nn.functional.interpolate(
                output_images, size=[height, width])

        # Edge map names take precedence over segmentation map names, as
        # before; ``None`` is returned when neither kind is present instead
        # of failing on an unbound local.
        file_names = None
        for patterns in (('segmaps', 'seg_maps'), ('edgemaps', 'edge_maps')):
            for key in data['key']:
                if patterns[0] in key or patterns[1] in key:
                    file_names = data['key'][key][0]
                    break
        return output_images, file_names
class SPADEGenerator(nn.Module):
    r"""SPADE Image Generator constructor.

    Args:
        num_labels (int): Number of different labels.
        out_image_small_side_size (int): min(width, height). Must be 256,
            512 or 1024.
        image_channels (int): Num. of channels of the output image.
        num_filters (int): Base filter numbers.
        kernel_size (int): Convolution kernel size.
        style_dims (int): Dimensions of the style code.
        activation_norm_params (obj): Spatially adaptive normalization param.
        weight_norm_type (str): Type of weight normalization.
            ``'none'``, ``'spectral'``, or ``'weight'``.
        global_adaptive_norm_type (str): Type of normalization in SPADE.
        skip_activation_norm (bool): If ``True``, applies activation norm to
            the shortcut connection in residual blocks.
        use_posenc_in_input_layer (bool): If ``True``, two coordinate
            channels are concatenated to the label map fed to the first
            layer.
        use_style_encoder (bool): Whether to use global adaptive norm like
            conditional batch norm or adaptive instance norm.
        output_multiplier (float): A positive number multiplied to the output
    Raises:
        ValueError: If ``out_image_small_side_size`` is not supported.
    """

    def __init__(self,
                 num_labels,
                 out_image_small_side_size,
                 image_channels,
                 num_filters,
                 kernel_size,
                 style_dims,
                 activation_norm_params,
                 weight_norm_type,
                 global_adaptive_norm_type,
                 skip_activation_norm,
                 use_posenc_in_input_layer,
                 use_style_encoder,
                 output_multiplier):
        super(SPADEGenerator, self).__init__()
        # Fail fast: reject unsupported resolutions before any layer is
        # built.
        if out_image_small_side_size not in (256, 512, 1024):
            raise ValueError('Generation image size (%d, %d) not supported' %
                             (out_image_small_side_size,
                              out_image_small_side_size))
        self.output_multiplier = output_multiplier
        self.use_style_encoder = use_style_encoder
        self.use_posenc_in_input_layer = use_posenc_in_input_layer
        self.out_image_small_side_size = out_image_small_side_size
        self.num_filters = num_filters
        padding = int(np.ceil((kernel_size - 1.0) / 2))
        nonlinearity = 'leakyrelu'
        activation_norm_type = 'spatially_adaptive'
        base_res2d_block = \
            functools.partial(Res2dBlock,
                              kernel_size=kernel_size,
                              padding=padding,
                              bias=[True, True, False],
                              weight_norm_type=weight_norm_type,
                              activation_norm_type=activation_norm_type,
                              activation_norm_params=activation_norm_params,
                              skip_activation_norm=skip_activation_norm,
                              nonlinearity=nonlinearity,
                              order='NACNAC')
        if self.use_style_encoder:
            # Two linear layers that map the style code to the conditioning
            # vector consumed by the globally adaptive norm layers.
            self.fc_0 = LinearBlock(style_dims, 2 * style_dims,
                                    weight_norm_type=weight_norm_type,
                                    nonlinearity='relu',
                                    order='CAN')
            self.fc_1 = LinearBlock(2 * style_dims, 2 * style_dims,
                                    weight_norm_type=weight_norm_type,
                                    nonlinearity='relu',
                                    order='CAN')

            adaptive_norm_params = types.SimpleNamespace()
            adaptive_norm_params.cond_dims = 2 * style_dims
            adaptive_norm_params.activation_norm_type = \
                global_adaptive_norm_type
            adaptive_norm_params.weight_norm_type = \
                activation_norm_params.weight_norm_type
            adaptive_norm_params.separate_projection = \
                activation_norm_params.separate_projection
            adaptive_norm_params.activation_norm_params = \
                types.SimpleNamespace()
            adaptive_norm_params.activation_norm_params.affine = \
                activation_norm_params.activation_norm_params.affine
            base_cbn2d_block = \
                functools.partial(Conv2dBlock,
                                  kernel_size=kernel_size,
                                  stride=1,
                                  padding=padding,
                                  bias=True,
                                  weight_norm_type=weight_norm_type,
                                  activation_norm_type='adaptive',
                                  activation_norm_params=adaptive_norm_params,
                                  nonlinearity=nonlinearity,
                                  order='NAC')
        else:
            base_conv2d_block = \
                functools.partial(Conv2dBlock,
                                  kernel_size=kernel_size,
                                  stride=1,
                                  padding=padding,
                                  bias=True,
                                  weight_norm_type=weight_norm_type,
                                  nonlinearity=nonlinearity,
                                  order='NAC')

        def make_image_head(in_channels):
            # 5x5 convolution that projects features to image channels.
            return Conv2dBlock(in_channels, image_channels, 5,
                               stride=1,
                               padding=2,
                               weight_norm_type=weight_norm_type,
                               activation_norm_type='none',
                               nonlinearity=nonlinearity,
                               order='ANC')

        # NOTE: layers are created in the same order as before so that
        # parameter ordering and seeded initialization are unchanged.
        in_num_labels = num_labels
        in_num_labels += 2 if self.use_posenc_in_input_layer else 0
        self.head_0 = Conv2dBlock(in_num_labels, 8 * num_filters,
                                  kernel_size=kernel_size, stride=1,
                                  padding=padding,
                                  weight_norm_type=weight_norm_type,
                                  activation_norm_type='none',
                                  nonlinearity=nonlinearity)
        if self.use_style_encoder:
            self.cbn_head_0 = base_cbn2d_block(
                8 * num_filters, 16 * num_filters)
        else:
            self.conv_head_0 = base_conv2d_block(
                8 * num_filters, 16 * num_filters)
        self.head_1 = base_res2d_block(16 * num_filters, 16 * num_filters)
        self.head_2 = base_res2d_block(16 * num_filters, 16 * num_filters)

        self.up_0a = base_res2d_block(16 * num_filters, 8 * num_filters)
        if self.use_style_encoder:
            self.cbn_up_0a = base_cbn2d_block(
                8 * num_filters, 8 * num_filters)
        else:
            self.conv_up_0a = base_conv2d_block(
                8 * num_filters, 8 * num_filters)
        self.up_0b = base_res2d_block(8 * num_filters, 8 * num_filters)

        self.up_1a = base_res2d_block(8 * num_filters, 4 * num_filters)
        if self.use_style_encoder:
            self.cbn_up_1a = base_cbn2d_block(
                4 * num_filters, 4 * num_filters)
        else:
            self.conv_up_1a = base_conv2d_block(
                4 * num_filters, 4 * num_filters)
        self.up_1b = base_res2d_block(4 * num_filters, 4 * num_filters)

        self.up_2a = base_res2d_block(4 * num_filters, 4 * num_filters)
        if self.use_style_encoder:
            self.cbn_up_2a = base_cbn2d_block(
                4 * num_filters, 4 * num_filters)
        else:
            self.conv_up_2a = base_conv2d_block(
                4 * num_filters, 4 * num_filters)
        self.up_2b = base_res2d_block(4 * num_filters, 2 * num_filters)
        self.conv_img256 = make_image_head(2 * num_filters)

        # ``base`` is the total upsampling factor between the first feature
        # map and the output image.
        self.base = 16
        if self.out_image_small_side_size in (512, 1024):
            self.up_3a = base_res2d_block(2 * num_filters, 1 * num_filters)
            self.up_3b = base_res2d_block(1 * num_filters, 1 * num_filters)
            self.conv_img512 = make_image_head(1 * num_filters)
            self.base = 32
        if self.out_image_small_side_size == 1024:
            self.up_4a = base_res2d_block(num_filters, num_filters // 2)
            self.up_4b = base_res2d_block(num_filters // 2, num_filters // 2)
            self.conv_img1024 = make_image_head(num_filters // 2)
            self.nearest_upsample4x = NearestUpsample(scale_factor=4,
                                                      mode='nearest')
            self.base = 64
        self.nearest_upsample2x = NearestUpsample(scale_factor=2,
                                                  mode='nearest')

        # Coordinate grid in [-1, 1] used as a positional encoding.
        xv, yv = torch.meshgrid(
            [torch.arange(-1, 1.1, 2. / 15), torch.arange(-1, 1.1, 2. / 15)])
        xy = torch.cat((xv.unsqueeze(0), yv.unsqueeze(0)), 0).unsqueeze(0)
        # Keep the grid on the GPU when one exists (as before), but do not
        # crash on CPU-only hosts; ``forward`` moves it to the input device.
        self.xy = xy.cuda() if torch.cuda.is_available() else xy

    def forward(self, data):
        r"""SPADE Generator forward.

        Args:
            data (dict):
              - label (N x C2 x H x W tensor) : Semantic representations.
              - z (N x style_dims tensor): Style code, read only when the
                generator uses global adaptive normalization.
        Returns:
            output (dict):
              - fake_images (N x 3 x H x W tensor): Fake images.
        """
        seg = data['label']

        if self.use_style_encoder:
            z = data['z']
            z = self.fc_0(z)
            z = self.fc_1(z)

        # The first feature map is the label size divided by the total
        # upsampling factor (16x16 for a square input of the configured
        # size).
        sy = math.floor(seg.size()[2] * 1.0 / self.base)
        sx = math.floor(seg.size()[3] * 1.0 / self.base)

        in_seg = F.interpolate(seg, size=[sy, sx], mode='nearest')
        if self.use_posenc_in_input_layer:
            # Follow the input device so the grid never ends up on another
            # device than the label map.
            xy = self.xy.to(seg.device)
            in_xy = F.interpolate(xy, size=[sy, sx], mode='bicubic',
                                  align_corners=False)
            in_seg_xy = torch.cat(
                (in_seg, in_xy.expand(in_seg.size()[0], 2, sy, sx)), 1)
        else:
            in_seg_xy = in_seg

        # 16x16
        x = self.head_0(in_seg_xy)
        if self.use_style_encoder:
            x = self.cbn_head_0(x, z)
        else:
            x = self.conv_head_0(x)
        x = self.head_1(x, seg)
        x = self.head_2(x, seg)
        x = self.nearest_upsample2x(x)

        # 32x32
        x = self.up_0a(x, seg)
        if self.use_style_encoder:
            x = self.cbn_up_0a(x, z)
        else:
            x = self.conv_up_0a(x)
        x = self.up_0b(x, seg)
        x = self.nearest_upsample2x(x)

        # 64x64
        x = self.up_1a(x, seg)
        if self.use_style_encoder:
            x = self.cbn_up_1a(x, z)
        else:
            x = self.conv_up_1a(x)
        x = self.up_1b(x, seg)
        x = self.nearest_upsample2x(x)

        # 128x128
        x = self.up_2a(x, seg)
        if self.use_style_encoder:
            x = self.cbn_up_2a(x, z)
        else:
            x = self.conv_up_2a(x)
        x = self.up_2b(x, seg)
        x = self.nearest_upsample2x(x)

        # The image heads of the lower resolutions are upsampled and summed
        # with the highest-resolution head before the final tanh.
        # 256x256
        if self.out_image_small_side_size == 256:
            x256 = self.conv_img256(x)
            x = torch.tanh(self.output_multiplier * x256)
        # 512x512
        elif self.out_image_small_side_size == 512:
            x256 = self.conv_img256(x)
            x256 = self.nearest_upsample2x(x256)
            x = self.up_3a(x, seg)
            x = self.up_3b(x, seg)
            x = self.nearest_upsample2x(x)
            x512 = self.conv_img512(x)
            x = torch.tanh(self.output_multiplier * (x256 + x512))
        # 1024x1024
        elif self.out_image_small_side_size == 1024:
            x256 = self.conv_img256(x)
            x256 = self.nearest_upsample4x(x256)
            x = self.up_3a(x, seg)
            x = self.up_3b(x, seg)
            x = self.nearest_upsample2x(x)
            x512 = self.conv_img512(x)
            x512 = self.nearest_upsample2x(x512)
            x = self.up_4a(x, seg)
            x = self.up_4b(x, seg)
            x = self.nearest_upsample2x(x)
            x1024 = self.conv_img1024(x)
            x = torch.tanh(self.output_multiplier * (x256 + x512 + x1024))
        output = dict()
        output['fake_images'] = x
        return output
class StyleEncoder(nn.Module):
    r"""Style Encode constructor.

    Six stride-2 convolutions followed by two linear heads that predict the
    mean and the log-variance of the style code.

    Args:
        style_enc_cfg (obj): Style encoder definition file. The attributes
            ``input_image_channels``, ``num_filters``, ``kernel_size``,
            ``style_dims`` and ``weight_norm_type`` are read.
    """

    def __init__(self, style_enc_cfg):
        super(StyleEncoder, self).__init__()
        input_image_channels = style_enc_cfg.input_image_channels
        num_filters = style_enc_cfg.num_filters
        kernel_size = style_enc_cfg.kernel_size
        padding = int(np.ceil((kernel_size - 1.0) / 2))
        style_dims = style_enc_cfg.style_dims
        weight_norm_type = style_enc_cfg.weight_norm_type
        activation_norm_type = 'none'
        nonlinearity = 'leakyrelu'
        base_conv2d_block = \
            functools.partial(Conv2dBlock,
                              kernel_size=kernel_size,
                              stride=2,
                              padding=padding,
                              weight_norm_type=weight_norm_type,
                              activation_norm_type=activation_norm_type,
                              nonlinearity=nonlinearity)
        self.layer1 = base_conv2d_block(input_image_channels, num_filters)
        self.layer2 = base_conv2d_block(num_filters * 1, num_filters * 2)
        self.layer3 = base_conv2d_block(num_filters * 2, num_filters * 4)
        self.layer4 = base_conv2d_block(num_filters * 4, num_filters * 8)
        self.layer5 = base_conv2d_block(num_filters * 8, num_filters * 8)
        self.layer6 = base_conv2d_block(num_filters * 8, num_filters * 8)
        # The linear heads are sized for a 4x4 final feature map, i.e. the
        # fixed 256x256 input halved by each of the six stride-2 layers.
        flat_features = num_filters * 8 * 4 * 4
        self.fc_mu = LinearBlock(flat_features, style_dims)
        self.fc_var = LinearBlock(flat_features, style_dims)

    def forward(self, input_x):
        r"""SPADE Style Encoder forward.

        Inputs that are not 256x256 are bilinearly resized to 256x256 first.
        The style code is sampled with the reparameterization
        ``z = mu + eps * exp(0.5 * logvar)`` where ``eps`` is standard
        Gaussian noise.

        Args:
            input_x (N x 3 x H x W tensor): input images.
        Returns:
            (tuple):
              - mu (N x C tensor): Mean vectors.
              - logvar (N x C tensor): Log-variance vectors.
              - z (N x C tensor): Style code vectors.
        """
        if input_x.size(2) != 256 or input_x.size(3) != 256:
            # align_corners=False is what the implicit default resolves to;
            # spelling it out keeps the result and avoids relying on it.
            input_x = F.interpolate(input_x, size=(256, 256),
                                    mode='bilinear', align_corners=False)
        x = self.layer1(input_x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        # reshape (not view) so non-contiguous features, e.g. channels_last,
        # are flattened in the same logical order instead of raising.
        x = x.reshape(x.size(0), -1)
        mu = self.fc_mu(x)
        logvar = self.fc_var(x)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = eps.mul(std) + mu
        return mu, logvar, z