# Source code for imaginaire.layers.residual_deep

# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

from imaginaire.third_party.upfirdn2d import BlurDownsample, BlurUpsample
from .conv import Conv2dBlock


class _BaseDeepResBlock(nn.Module):
    r"""Base class for deep residual blocks with a narrow hidden path.

    The residual branch chains four ``block`` instances: a 1x1 block mapping
    ``in_channels`` to ``hidden_channels``, two ``kernel_size`` blocks that
    keep ``hidden_channels`` (one of them carries the stride when the block
    resamples), and a 1x1 block mapping to ``out_channels``, where
    ``hidden_channels = in_channels // hidden_channel_ratio``.

    The shortcut branch is the (optionally resampled) input. It is passed
    through a learned ``skip_block`` when ``learn_shortcut`` is set, extended
    with learned channels when ``in_channels < out_channels``, or truncated
    to its first ``out_channels`` channels when
    ``in_channels > out_channels``. The block returns
    ``output_scale * (shortcut + residual)``.

    Args:
        in_channels (int): Number of channels in the input tensor.
        out_channels (int): Number of channels in the output tensor.
        kernel_size (int): Kernel size of the two inner blocks of the
            residual branch.
        stride (int or float): ``> 1`` selects the downsampling
            configuration, ``< 1`` the upsampling one, anything else
            disables resampling.
        padding (int): Padding of the two inner blocks.
        dilation (int): Forwarded to every block.
        groups (int): Forwarded to every block.
        bias (bool or list): Either one flag used for all blocks or a list
            of three flags ``[b0, b1, bs]``. ``b0`` is used by the input
            1x1 block and the first inner block, ``b1`` by the second inner
            block and the output 1x1 block, ``bs`` by the shortcut block.
        padding_mode (str): Forwarded to every block.
        weight_norm_type (str): Forwarded to the residual blocks, and to
            the shortcut block when ``skip_weight_norm`` is set.
        weight_norm_params (obj): Forwarded alongside ``weight_norm_type``.
        activation_norm_type (str): Forwarded to the residual blocks, and to
            the shortcut block when ``skip_activation_norm`` is set.
        activation_norm_params (obj): Forwarded alongside
            ``activation_norm_type``.
        skip_activation_norm (bool): If ``True``, the shortcut block
            receives the activation norm settings (with ``apply_noise``
            forced to ``False``).
        skip_nonlinearity (bool): If ``True``, the shortcut block uses
            ``nonlinearity``; otherwise it is given ``''``.
        nonlinearity (str): Nonlinearity type forwarded to the residual
            blocks.
        inplace_nonlinearity (bool): Forwarded to the residual blocks. The
            first block gets ``False`` instead when its nonlinearity would
            be the first operation applied to the block input.
        apply_noise (bool): Forwarded to the residual blocks.
        hidden_channels_equal_out_channels (bool): Accepted for interface
            compatibility; this class does not use it.
        order (str): Operation order of five or six characters, or
            ``'pre_act'`` as an alias for ``'NACNAC'``. The first three
            characters configure the input 1x1 block, the first inner
            block, the output 1x1 block and the shortcut block; the
            remainder configures the second inner block.
        block (callable): Factory used to build the residual blocks.
        learn_shortcut (bool): If ``True``, the shortcut is always a learned
            ``skip_block`` from ``in_channels`` to ``out_channels``.
        output_scale (float): Factor multiplied with the block output.
        skip_block (callable, optional, default=None): Factory for the
            shortcut block. Defaults to ``block``.
        blur (bool, optional, default=True): If ``True``, resampling blocks
            are asked to blur and the shortcut is resampled with
            ``BlurDownsample`` / ``BlurUpsample``.
        border_free (bool, optional, default=True): Must be falsy; the
            constructor asserts this, so callers have to pass ``False``
            explicitly.
        resample_first (bool, optional, default=True): If ``True``, the
            shortcut input is resampled before the shortcut block,
            otherwise the shortcut output is resampled afterwards.
        skip_weight_norm (bool, optional, default=True): If ``True``, the
            shortcut block receives the weight norm settings.
        hidden_channel_ratio (int, optional, default=4): Divisor applied to
            ``in_channels`` to obtain ``hidden_channels``. Note that
            ``in_channels < hidden_channel_ratio`` yields zero hidden
            channels.

    Raises:
        ValueError: If ``bias`` is neither a bool nor a list of length 3,
            or if ``order`` does not have 5 or 6 characters.
        AssertionError: If ``border_free`` is truthy.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, padding, dilation, groups, bias, padding_mode,
                 weight_norm_type, weight_norm_params,
                 activation_norm_type, activation_norm_params,
                 skip_activation_norm, skip_nonlinearity,
                 nonlinearity, inplace_nonlinearity, apply_noise,
                 hidden_channels_equal_out_channels,
                 order, block, learn_shortcut, output_scale, skip_block=None,
                 blur=True, border_free=True, resample_first=True,
                 skip_weight_norm=True, hidden_channel_ratio=4):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.output_scale = output_scale
        self.resample_first = resample_first
        self.stride = stride
        self.blur = blur
        self.border_free = border_free
        # The border-free variant is rejected here, so the border_free
        # branches further down only run if assertions are disabled.
        assert not border_free
        if skip_block is None:
            skip_block = block

        if order == 'pre_act':
            order = 'NACNAC'
        if isinstance(bias, bool):
            # The bias for conv_block_0, conv_block_1, and conv_block_s.
            biases = [bias, bias, bias]
        elif isinstance(bias, list):
            if len(bias) == 3:
                biases = bias
            else:
                raise ValueError('Bias list must be 3.')
        else:
            raise ValueError('Bias must be either a boolean or a list.')
        self.learn_shortcut = learn_shortcut
        if len(order) > 6 or len(order) < 5:
            raise ValueError('order must be either 5 or 6 characters')
        hidden_channels = in_channels // hidden_channel_ratio

        # Keyword arguments shared by the blocks of each branch.
        base_params = dict(dilation=dilation,
                           groups=groups,
                           padding_mode=padding_mode)
        residual_params = dict(base_params)
        residual_params.update(
            dict(activation_norm_type=activation_norm_type,
                 activation_norm_params=activation_norm_params,
                 weight_norm_type=weight_norm_type,
                 weight_norm_params=weight_norm_params,
                 apply_noise=apply_noise)
        )
        shortcut_params = dict(base_params)
        shortcut_params.update(dict(kernel_size=1))
        if skip_activation_norm:
            shortcut_params.update(
                dict(activation_norm_type=activation_norm_type,
                     activation_norm_params=activation_norm_params,
                     apply_noise=False))
        if skip_weight_norm:
            shortcut_params.update(
                dict(weight_norm_type=weight_norm_type,
                     weight_norm_params=weight_norm_params))

        # Residual branch.
        if order.find('A') < order.find('C') and \
                (activation_norm_type == '' or activation_norm_type == 'none'):
            # Nonlinearity is the first operation in the residual path.
            # In-place nonlinearity will modify the input variable and cause
            # backward error.
            first_inplace = False
        else:
            first_inplace = inplace_nonlinearity

        # Also sets self.resample as a side effect.
        (first_stride, second_stride, shortcut_stride,
         first_blur, second_blur, shortcut_blur) = self._get_stride_blur()

        # 1x1 block: in_channels -> hidden_channels.
        self.conv_block_1x1_in = block(
            in_channels, hidden_channels,
            1, 1, 0,
            bias=biases[0],
            nonlinearity=nonlinearity,
            order=order[0:3],
            inplace_nonlinearity=first_inplace,
            **residual_params
        )

        # First inner block; carries the stride when upsampling.
        self.conv_block_0 = block(
            hidden_channels, hidden_channels,
            kernel_size=2 if self.border_free and first_stride < 1 else
            kernel_size,
            padding=padding,
            bias=biases[0],
            nonlinearity=nonlinearity,
            order=order[0:3],
            inplace_nonlinearity=inplace_nonlinearity,
            stride=first_stride,
            blur=first_blur,
            **residual_params
        )
        # Second inner block; carries the stride when downsampling.
        self.conv_block_1 = block(
            hidden_channels, hidden_channels,
            kernel_size=kernel_size,
            padding=padding,
            bias=biases[1],
            nonlinearity=nonlinearity,
            order=order[3:],
            inplace_nonlinearity=inplace_nonlinearity,
            stride=second_stride,
            blur=second_blur,
            **residual_params
        )

        # 1x1 block: hidden_channels -> out_channels.
        self.conv_block_1x1_out = block(
            hidden_channels, out_channels,
            1, 1, 0,
            bias=biases[1],
            nonlinearity=nonlinearity,
            order=order[0:3],
            inplace_nonlinearity=inplace_nonlinearity,
            **residual_params
        )

        # Shortcut branch. A shortcut block is needed either because it was
        # requested explicitly, or to synthesize the channels the input
        # lacks when the block widens.
        if self.learn_shortcut or in_channels < out_channels:
            if self.learn_shortcut:
                shortcut_channels = out_channels
            else:
                # Only the missing channels are learned; forward()
                # concatenates them after the input channels.
                shortcut_channels = out_channels - in_channels
            skip_nonlinearity_type = nonlinearity if skip_nonlinearity else ''
            self.conv_block_s = skip_block(in_channels, shortcut_channels,
                                           bias=biases[2],
                                           nonlinearity=skip_nonlinearity_type,
                                           order=order[0:3],
                                           stride=shortcut_stride,
                                           blur=shortcut_blur,
                                           **shortcut_params)

        # Whether this block expects conditional inputs.
        self.conditional = \
            getattr(self.conv_block_0, 'conditional', False) or \
            getattr(self.conv_block_1, 'conditional', False) or \
            getattr(self.conv_block_1x1_in, 'conditional', False) or \
            getattr(self.conv_block_1x1_out, 'conditional', False)

    def _get_stride_blur(self):
        r"""Derive per-block stride/blur settings from ``self.stride``.

        As a side effect, sets ``self.resample`` to the module used to
        resample the shortcut branch (``None`` when ``self.stride`` is
        neither above nor below 1).

        Returns:
            (tuple): ``(first_stride, second_stride, shortcut_stride,
            first_blur, second_blur, shortcut_blur)``.
        """
        if self.stride > 1:
            # Downsampling.
            first_stride, second_stride = 1, self.stride
            first_blur, second_blur = False, self.blur
            shortcut_blur = False
            shortcut_stride = 1
            if self.blur:
                # The shortcut branch uses blur_downsample + stride-1 conv
                if self.border_free:
                    self.resample = nn.AvgPool2d(2)
                else:
                    self.resample = BlurDownsample()
            else:
                shortcut_stride = self.stride
                self.resample = nn.AvgPool2d(2)
        elif self.stride < 1:
            # Upsampling.
            first_stride, second_stride = self.stride, 1
            first_blur, second_blur = self.blur, False
            shortcut_blur = False
            shortcut_stride = 1
            if self.blur:
                # The shortcut branch uses blur_upsample + stride-1 conv
                if self.border_free:
                    self.resample = nn.Upsample(scale_factor=2,
                                                mode='bilinear')
                else:
                    self.resample = BlurUpsample()
            else:
                shortcut_stride = self.stride
                self.resample = nn.Upsample(scale_factor=2)
        else:
            # No resampling.
            first_stride = second_stride = 1
            first_blur = second_blur = False
            shortcut_stride = 1
            shortcut_blur = False
            self.resample = None
        return (first_stride, second_stride, shortcut_stride,
                first_blur, second_blur, shortcut_blur)

    def conv_blocks(
            self, x, *cond_inputs, separate_cond=False, **kw_cond_inputs
    ):
        r"""Run the residual branch.

        Args:
            x (tensor): Input tensor.
            cond_inputs (list of tensors): Conditional inputs. When
                ``separate_cond`` is ``True`` exactly four are required and
                the i-th one is passed to the i-th block only; otherwise all
                of them are passed to every block.
            separate_cond (bool, optional, default=False): Whether each
                block gets its own conditional input.
            kw_cond_inputs (dict): Keyword conditional inputs. When
                ``separate_cond`` is ``True`` the entries ``'kwargs_0'`` to
                ``'kwargs_3'`` (each a dict, default empty) are unpacked
                into the respective block; otherwise everything is passed
                to every block.
        Returns:
            (tensor): Output of the residual branch.
        """
        if separate_cond:
            assert len(list(cond_inputs)) == 4
            dx = self.conv_block_1x1_in(x, cond_inputs[0],
                                        **kw_cond_inputs.get('kwargs_0', {}))
            dx = self.conv_block_0(dx, cond_inputs[1],
                                   **kw_cond_inputs.get('kwargs_1', {}))
            dx = self.conv_block_1(dx, cond_inputs[2],
                                   **kw_cond_inputs.get('kwargs_2', {}))
            dx = self.conv_block_1x1_out(dx, cond_inputs[3],
                                         **kw_cond_inputs.get('kwargs_3', {}))
        else:
            dx = self.conv_block_1x1_in(x, *cond_inputs, **kw_cond_inputs)
            dx = self.conv_block_0(dx, *cond_inputs, **kw_cond_inputs)
            dx = self.conv_block_1(dx, *cond_inputs, **kw_cond_inputs)
            dx = self.conv_block_1x1_out(dx, *cond_inputs, **kw_cond_inputs)
        return dx

    def forward(self, x, *cond_inputs, do_checkpoint=False, **kw_cond_inputs):
        r"""Compute ``output_scale * (shortcut(x) + residual(x))``.

        Args:
            x (tensor): Input tensor.
            cond_inputs (list of tensors): Conditional inputs, passed to
                both the residual branch and the shortcut block.
            do_checkpoint (bool, optional, default=False): If ``True``, the
                residual branch is run through
                ``torch.utils.checkpoint.checkpoint``.
            kw_cond_inputs (dict): Keyword conditional inputs, passed to
                both the residual branch and the shortcut block.
        Returns:
            (tensor): Output tensor.
        """
        if do_checkpoint:
            if kw_cond_inputs:
                # checkpoint() rejects extra keyword arguments in its
                # reentrant mode, so bind them in a closure and hand it
                # positional inputs only.
                def run_blocks(*inputs):
                    return self.conv_blocks(*inputs, **kw_cond_inputs)
            else:
                run_blocks = self.conv_blocks
            dx = checkpoint(run_blocks, x, *cond_inputs)
        else:
            dx = self.conv_blocks(x, *cond_inputs, **kw_cond_inputs)

        if self.resample_first and self.resample is not None:
            x = self.resample(x)
        # NOTE(review): kw_cond_inputs is forwarded unchanged to the
        # shortcut block, including 'separate_cond' if the caller passed
        # it - confirm the shortcut block accepts that.
        if self.learn_shortcut:
            x_shortcut = self.conv_block_s(
                x, *cond_inputs, **kw_cond_inputs
            )
        elif self.in_channels < self.out_channels:
            # Keep the input channels and append the learned extra ones.
            x_shortcut_pad = self.conv_block_s(
                x, *cond_inputs, **kw_cond_inputs
            )
            x_shortcut = torch.cat((x, x_shortcut_pad), dim=1)
        elif self.in_channels > self.out_channels:
            # Drop the surplus input channels.
            x_shortcut = x[:, :self.out_channels, :, :]
        else:
            x_shortcut = x
        if not self.resample_first and self.resample is not None:
            x_shortcut = self.resample(x_shortcut)

        output = x_shortcut + dx
        return self.output_scale * output

    def extra_repr(self):
        r"""Return the extra information shown in the module's repr."""
        s = 'output_scale={output_scale}'
        return s.format(**self.__dict__)


class DeepRes2dBlock(_BaseDeepResBlock):
    r"""Residual block for 2D input.

    Args:
        in_channels (int) : Number of channels in the input tensor.
        out_channels (int) : Number of channels in the output tensor.
        kernel_size (int, optional, default=3): Kernel size for the
            convolutional filters in the residual link.
        stride (int or float, optional, default=1): A value larger than 1
            makes the block downsample, a value smaller than 1 makes it
            upsample, and 1 keeps the resolution.
        padding (int, optional, default=1): Padding size.
        dilation (int, optional, default=1): Dilation factor.
        groups (int, optional, default=1): Number of convolutional/linear
            groups.
        bias (bool or list, optional, default=True): Either a single flag
            used by all convolutions, or a list of three flags for the
            first half of the residual link, the second half of the
            residual link and the shortcut, respectively.
        padding_mode (string, optional, default='zeros'): Type of padding:
            ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
        weight_norm_type (str, optional, default='none'):
            Type of weight normalization.
            ``'none'``, ``'spectral'``, ``'weight'``
            or ``'weight_demod'``.
        weight_norm_params (obj, optional, default=None):
            Parameters of weight normalization.
            If not ``None``, ``weight_norm_params.__dict__`` will be used as
            keyword arguments when initializing weight normalization.
        activation_norm_type (str, optional, default='none'):
            Type of activation normalization.
            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
            ``'layer'``,  ``'layer_2d'``, ``'group'``, ``'adaptive'``,
            ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
        activation_norm_params (obj, optional, default=None):
            Parameters of activation normalization.
            If not ``None``, ``activation_norm_params.__dict__`` will be used
            as keyword arguments when initializing activation normalization.
        skip_activation_norm (bool, optional, default=True): If ``True``,
            applies activation norm to the convolutional shortcut
            connection, when one is used.
        skip_nonlinearity (bool, optional, default=False): If ``True``,
            applies nonlinearity to the convolutional shortcut connection,
            when one is used.
        skip_weight_norm (bool, optional, default=True): If ``True``,
            applies weight normalization to the convolutional shortcut
            connection, when one is used.
        nonlinearity (str, optional, default='leakyrelu'): Type of nonlinear
            activation function in the residual link.
            ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
            ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
        inplace_nonlinearity (bool, optional, default=False): If ``True``,
            set ``inplace=True`` when initializing the nonlinearity layers.
        apply_noise (bool, optional, default=False): If ``True``, adds
            Gaussian noise with learnable magnitude to the convolution output.
        hidden_channels_equal_out_channels (bool, optional, default=False):
            Accepted for interface compatibility. The base class does not
            use it; the hidden channel number is ``in_channels // 4``.
        order (str, optional, default='CNACNA'): Order of operations
            in the residual link.
            ``'C'``: convolution,
            ``'N'``: normalization,
            ``'A'``: nonlinear activation.
        learn_shortcut (bool, optional, default=False): If ``True``, always
            use a convolutional shortcut instead of an identity one.
            Otherwise, when the output has more channels than the input,
            the extra channels are produced by a convolution and
            concatenated to the input; when it has fewer, the input is
            truncated to its first ``out_channels`` channels.
        output_scale (float, optional, default=1): Factor multiplied with
            the sum of the shortcut and the residual link.
        blur (bool, optional, default=True): If ``True``, blurring is
            requested from the resampling convolutions and the shortcut is
            resampled with a blur filter.
        resample_first (bool, optional, default=True): If ``True``, the
            shortcut input is resampled before the shortcut convolution,
            otherwise the shortcut output is resampled afterwards.
        border_free (bool, optional, default=False): Must be ``False``;
            the base class asserts this.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding=1, dilation=1, groups=1, bias=True,
                 padding_mode='zeros',
                 weight_norm_type='none', weight_norm_params=None,
                 activation_norm_type='none', activation_norm_params=None,
                 skip_activation_norm=True, skip_nonlinearity=False,
                 skip_weight_norm=True,
                 nonlinearity='leakyrelu', inplace_nonlinearity=False,
                 apply_noise=False, hidden_channels_equal_out_channels=False,
                 order='CNACNA', learn_shortcut=False, output_scale=1,
                 blur=True, resample_first=True, border_free=False):
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, dilation, groups, bias, padding_mode,
                         weight_norm_type, weight_norm_params,
                         activation_norm_type, activation_norm_params,
                         skip_activation_norm, skip_nonlinearity,
                         nonlinearity, inplace_nonlinearity, apply_noise,
                         hidden_channels_equal_out_channels,
                         order, Conv2dBlock, learn_shortcut, output_scale,
                         blur=blur, resample_first=resample_first,
                         border_free=border_free,
                         skip_weight_norm=skip_weight_norm)