# Source code for imaginaire.generators.unit

# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import warnings

from torch import nn
from torch.nn import Upsample as NearestUpsample

from imaginaire.layers import Conv2dBlock, Res2dBlock


class Generator(nn.Module):
    r"""Improved UNIT generator.

    Holds two autoencoders, one per image domain (A and B), both built
    from the same generator configuration.

    Args:
        gen_cfg (obj): Generator definition part of the yaml config file.
            Its attributes are forwarded as keyword arguments to
            ``AutoEncoder``.
        data_cfg (obj): Data definition part of the yaml config file.
            Not read by this constructor.
    """

    def __init__(self, gen_cfg, data_cfg):
        super().__init__()
        autoencoder_kwargs = vars(gen_cfg)
        self.autoencoder_a = AutoEncoder(**autoencoder_kwargs)
        self.autoencoder_b = AutoEncoder(**autoencoder_kwargs)

    def forward(self, data, image_recon=True, cycle_recon=True):
        r"""UNIT forward function.

        Args:
            data (dict): Training data at the current iteration.
              - images_a (tensor): Images from domain A.
              - images_b (tensor): Images from domain B.
            image_recon (bool): If ``True``, also decodes every content
                code with the decoder of its own domain
                (``images_aa``, ``images_bb``).
            cycle_recon (bool): If ``True``, also re-encodes the
                translated images and decodes them back to the source
                domain (``content_ba``, ``content_ab``, ``images_aba``,
                ``images_bab``).

        Returns:
            (dict): Always contains ``content_a``, ``content_b``,
            ``images_ba`` and ``images_ab``; the optional entries listed
            above are added when the matching flag is set.
        """
        net_a, net_b = self.autoencoder_a, self.autoencoder_b
        net_G_output = {}

        # Encode input images into latent codes.
        content_a = net_a.content_encoder(data['images_a'])
        content_b = net_b.content_encoder(data['images_b'])

        # Decode (within domain).
        if image_recon:
            net_G_output.update({
                'images_aa': net_a.decoder(content_a),
                'images_bb': net_b.decoder(content_b),
            })

        # Decode (cross domain): a code is rendered by the *other*
        # domain's decoder.
        images_ba = net_a.decoder(content_b)
        images_ab = net_b.decoder(content_a)

        # Cycle reconstruction.
        if cycle_recon:
            content_ba = net_a.content_encoder(images_ba)
            content_ab = net_b.content_encoder(images_ab)
            net_G_output.update({
                'content_ba': content_ba,
                'content_ab': content_ab,
                'images_aba': net_a.decoder(content_ab),
                'images_bab': net_b.decoder(content_ba),
            })

        # Required outputs.
        net_G_output.update({
            'content_a': content_a,
            'content_b': content_b,
            'images_ba': images_ba,
            'images_ab': images_ab,
        })
        return net_G_output

    def inference(self, data, a2b=True):
        r"""UNIT inference.

        Args:
            data (dict): Training data at the current iteration.
              - images_a (tensor): Images from domain A.
              - images_b (tensor): Images from domain B.
              - key (dict): For each image key, a dict providing
                ``sequence_name`` and ``filename`` sequences.
            a2b (bool): If ``True``, translates images from domain A to B,
                otherwise from B to A.

        Returns:
            (tuple):
              - output_images (tensor): Translated images.
              - filenames (list of str): A single
                ``'<sequence_name>/<filename>'`` entry built from the
                first element of each sequence in ``data['key']``.
        """
        if a2b:
            input_key = 'images_a'
            source, target = self.autoencoder_a, self.autoencoder_b
        else:
            input_key = 'images_b'
            source, target = self.autoencoder_b, self.autoencoder_a

        content = source.content_encoder(data[input_key])
        output_images = target.decoder(content)

        # Only the first sample's name is reported.
        key_info = data['key'][input_key]
        filename = '%s/%s' % (key_info['sequence_name'][0],
                              key_info['filename'][0])
        return output_images, [filename]
class AutoEncoder(nn.Module):
    r"""Improved UNIT autoencoder.

    Pairs a ``ContentEncoder`` with a ``Decoder`` whose input width is the
    encoder's ``output_dim``.

    Args:
        num_filters (int): Base filter numbers.
        max_num_filters (int): Maximum number of filters in the encoder.
        num_res_blocks (int): Number of residual blocks at the end of the
            content encoder.
        num_downsamples_content (int): Number of times we reduce
            resolution by 2x2 for the content image.
        num_image_channels (int): Number of input image channels.
        content_norm_type (str): Type of activation normalization in the
            content encoder.
        decoder_norm_type (str): Type of activation normalization in the
            decoder.
        weight_norm_type (str): Type of weight normalization.
        output_nonlinearity (str): Type of nonlinearity before final output,
            ``'tanh'`` or ``'none'``.
        pre_act (bool): If ``True``, uses pre-activation residual blocks.
        apply_noise (bool): If ``True``, injects Gaussian noise in the
            decoder.
        **kwargs: Extra configuration entries. They are ignored; every key
            other than ``'type'`` triggers a warning.
    """

    def __init__(self,
                 num_filters=64,
                 max_num_filters=256,
                 num_res_blocks=4,
                 num_downsamples_content=2,
                 num_image_channels=3,
                 content_norm_type='instance',
                 decoder_norm_type='instance',
                 weight_norm_type='',
                 output_nonlinearity='',
                 pre_act=False,
                 apply_noise=False,
                 **kwargs):
        super().__init__()
        # 'type' is tolerated silently; any other unknown key is reported.
        for key in kwargs:
            if key != 'type':
                warnings.warn(
                    "Generator argument '{}' is not used.".format(key))

        self.content_encoder = ContentEncoder(
            num_downsamples=num_downsamples_content,
            num_res_blocks=num_res_blocks,
            num_image_channels=num_image_channels,
            num_filters=num_filters,
            max_num_filters=max_num_filters,
            padding_mode='reflect',
            activation_norm_type=content_norm_type,
            weight_norm_type=weight_norm_type,
            nonlinearity='relu',
            pre_act=pre_act)
        # The decoder mirrors the encoder: same number of resolution steps
        # and residual blocks, starting from the encoder's output width.
        self.decoder = Decoder(
            num_upsamples=num_downsamples_content,
            num_res_blocks=num_res_blocks,
            num_filters=self.content_encoder.output_dim,
            num_image_channels=num_image_channels,
            padding_mode='reflect',
            activation_norm_type=decoder_norm_type,
            weight_norm_type=weight_norm_type,
            nonlinearity='relu',
            output_nonlinearity=output_nonlinearity,
            pre_act=pre_act,
            apply_noise=apply_noise)

    def forward(self, images):
        r"""Reconstruct an image.

        Args:
            images (Tensor): Input images.

        Returns:
            images_recon (Tensor): Reconstructed images.
        """
        return self.decoder(self.content_encoder(images))
class ContentEncoder(nn.Module):
    r"""Improved UNIT encoder. The network consists of:

    - input layers
    - $(num_downsamples) convolutional blocks
    - $(num_res_blocks) residual blocks.
    - output layer.

    After construction, ``output_dim`` holds the filter count reached
    after the downsampling blocks.

    Args:
        num_downsamples (int): Number of times we reduce resolution by 2x2.
        num_res_blocks (int): Number of residual blocks at the end of the
            content encoder.
        num_image_channels (int): Number of input image channels.
        num_filters (int): Base filter numbers.
        max_num_filters (int): Maximum number of filters in the encoder.
        padding_mode (string): Type of padding.
        activation_norm_type (str): Type of activation normalization.
        weight_norm_type (str): Type of weight normalization.
        nonlinearity (str): Type of nonlinear activation function.
        pre_act (bool): If ``True``, uses pre-activation residual blocks.
    """

    def __init__(self,
                 num_downsamples,
                 num_res_blocks,
                 num_image_channels,
                 num_filters,
                 max_num_filters,
                 padding_mode,
                 activation_norm_type,
                 weight_norm_type,
                 nonlinearity,
                 pre_act=False):
        super().__init__()
        conv_params = dict(padding_mode=padding_mode,
                           activation_norm_type=activation_norm_type,
                           weight_norm_type=weight_norm_type,
                           nonlinearity=nonlinearity)
        # Whether or not it is safe to use inplace nonlinear activation:
        # only enabled without pre-activation, or when an activation norm
        # is configured.
        if not pre_act or activation_norm_type not in ('', 'none'):
            conv_params['inplace_nonlinearity'] = True

        # The order of operations in residual blocks.
        order = 'pre_act' if pre_act else 'CNACNA'

        # Input block.
        layers = [Conv2dBlock(num_image_channels, num_filters, 7, 1, 3,
                              **conv_params)]

        # Downsampling blocks: the filter count doubles at every step
        # until it is capped at max_num_filters.
        for _ in range(num_downsamples):
            num_filters_prev = num_filters
            num_filters = min(num_filters * 2, max_num_filters)
            layers.append(Conv2dBlock(num_filters_prev, num_filters, 4, 2, 1,
                                      **conv_params))

        # Residual blocks.
        layers.extend(
            Res2dBlock(num_filters, num_filters, **conv_params, order=order)
            for _ in range(num_res_blocks))

        self.model = nn.Sequential(*layers)
        self.output_dim = num_filters

    def forward(self, x):
        r"""Encode an image by running it through the sequential model.

        Args:
            x (tensor): Input image.
        """
        return self.model(x)
class Decoder(nn.Module):
    r"""Improved UNIT decoder. The network consists of:

    - $(num_res_blocks) residual blocks.
    - $(num_upsamples) residual blocks or convolutional blocks
    - output layer.

    Args:
        num_upsamples (int): Number of times we increase resolution by 2x2.
        num_res_blocks (int): Number of residual blocks.
        num_filters (int): Number of filters of the incoming content
            embedding; it is halved after every upsampling step.
        num_image_channels (int): Number of image channels handed to the
            final block as its target size.
        padding_mode (string): Type of padding.
        activation_norm_type (str): Type of activation normalization.
        weight_norm_type (str): Type of weight normalization.
        nonlinearity (str): Type of nonlinear activation function.
        output_nonlinearity (str): Type of nonlinearity before final output,
            ``'tanh'`` or ``'none'``.
        pre_act (bool): If ``True``, uses pre-activation residual blocks.
        apply_noise (bool): If ``True``, injects Gaussian noise.
    """

    def __init__(self,
                 num_upsamples,
                 num_res_blocks,
                 num_filters,
                 num_image_channels,
                 padding_mode,
                 activation_norm_type,
                 weight_norm_type,
                 nonlinearity,
                 output_nonlinearity,
                 pre_act=False,
                 apply_noise=False):
        super().__init__()
        # NOTE(review): unlike ContentEncoder, in-place activation is
        # enabled unconditionally here, even when pre_act is set --
        # confirm this is intended.
        conv_params = dict(padding_mode=padding_mode,
                           nonlinearity=nonlinearity,
                           inplace_nonlinearity=True,
                           apply_noise=apply_noise,
                           weight_norm_type=weight_norm_type,
                           activation_norm_type=activation_norm_type)

        # The order of operations in residual blocks.
        order = 'pre_act' if pre_act else 'CNACNA'

        self.decoder = nn.ModuleList()

        # Residual blocks.
        for _ in range(num_res_blocks):
            self.decoder.append(
                Res2dBlock(num_filters, num_filters,
                           **conv_params, order=order))

        # Convolutional blocks with upsampling; each step halves the
        # filter count.
        for _ in range(num_upsamples):
            self.decoder.append(NearestUpsample(scale_factor=2))
            self.decoder.append(
                Conv2dBlock(num_filters, num_filters // 2, 5, 1, 2,
                            **conv_params))
            num_filters //= 2

        # Output block: it receives only the output nonlinearity and the
        # padding mode, none of the shared conv_params.
        self.decoder.append(
            Conv2dBlock(num_filters, num_image_channels, 7, 1, 3,
                        nonlinearity=output_nonlinearity,
                        padding_mode=padding_mode))

    def forward(self, x):
        r"""Apply every decoder block to the input, in order.

        Args:
            x (tensor): Content embedding of the content image.
        """
        for block in self.decoder:
            x = block(x)
        return x