Source code for imaginaire.layers.residual

# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import functools

import torch
from torch import nn
from torch.nn import Upsample as NearestUpsample
from torch.utils.checkpoint import checkpoint

from .conv import (Conv1dBlock, Conv2dBlock, Conv3dBlock, HyperConv2dBlock,
                   LinearBlock, MultiOutConv2dBlock, PartialConv2dBlock,
                   PartialConv3dBlock, ModulatedConv2dBlock)
from imaginaire.third_party.upfirdn2d.upfirdn2d import BlurUpsample


class _BaseResBlock(nn.Module):
    r"""An abstract class for residual blocks.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, padding, dilation, groups, bias, padding_mode,
                 weight_norm_type, weight_norm_params,
                 activation_norm_type, activation_norm_params,
                 skip_activation_norm, skip_nonlinearity,
                 nonlinearity, inplace_nonlinearity, apply_noise,
                 hidden_channels_equal_out_channels,
                 order, block, learn_shortcut, clamp, output_scale,
                 skip_block=None, blur=False, upsample_first=True, skip_weight_norm=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.output_scale = output_scale
        self.upsample_first = upsample_first
        self.stride = stride
        self.blur = blur
        if skip_block is None:
            skip_block = block

        if order == 'pre_act':
            order = 'NACNAC'
        if isinstance(bias, bool):
            # The bias for conv_block_0, conv_block_1, and conv_block_s.
            biases = [bias, bias, bias]
        elif isinstance(bias, list):
            if len(bias) == 3:
                biases = bias
            else:
                raise ValueError('Bias list must be 3.')
        else:
            raise ValueError('Bias must be either an integer or s list.')
        if learn_shortcut is None:
            self.learn_shortcut = (in_channels != out_channels)
        else:
            self.learn_shortcut = learn_shortcut
        if len(order) > 6 or len(order) < 5:
            raise ValueError('order must be either 5 or 6 characters')
        if hidden_channels_equal_out_channels:
            hidden_channels = out_channels
        else:
            hidden_channels = min(in_channels, out_channels)

        # Parameters.
        residual_params = {}
        shortcut_params = {}
        base_params = dict(dilation=dilation,
                           groups=groups,
                           padding_mode=padding_mode,
                           clamp=clamp)
        residual_params.update(base_params)
        residual_params.update(
            dict(activation_norm_type=activation_norm_type,
                 activation_norm_params=activation_norm_params,
                 weight_norm_type=weight_norm_type,
                 weight_norm_params=weight_norm_params,
                 padding=padding,
                 apply_noise=apply_noise))
        shortcut_params.update(base_params)
        shortcut_params.update(dict(kernel_size=1))
        if skip_activation_norm:
            shortcut_params.update(
                dict(activation_norm_type=activation_norm_type,
                     activation_norm_params=activation_norm_params,
                     apply_noise=False))
        if skip_weight_norm:
            shortcut_params.update(
                dict(weight_norm_type=weight_norm_type,
                     weight_norm_params=weight_norm_params))

        # Residual branch.
        if order.find('A') < order.find('C') and \
                (activation_norm_type == '' or activation_norm_type == 'none'):
            # Nonlinearity is the first operation in the residual path.
            # In-place nonlinearity will modify the input variable and cause
            # backward error.
            first_inplace = False
        else:
            first_inplace = inplace_nonlinearity

        (first_stride, second_stride, shortcut_stride,
         first_blur, second_blur, shortcut_blur) = self._get_stride_blur()
        self.conv_block_0 = block(
            in_channels, hidden_channels,
            kernel_size=kernel_size,
            bias=biases[0],
            nonlinearity=nonlinearity,
            order=order[0:3],
            inplace_nonlinearity=first_inplace,
            stride=first_stride,
            blur=first_blur,
            **residual_params
        )
        self.conv_block_1 = block(
            hidden_channels, out_channels,
            kernel_size=kernel_size,
            bias=biases[1],
            nonlinearity=nonlinearity,
            order=order[3:],
            inplace_nonlinearity=inplace_nonlinearity,
            stride=second_stride,
            blur=second_blur,
            **residual_params
        )

        # Shortcut branch.
        if self.learn_shortcut:
            if skip_nonlinearity:
                skip_nonlinearity_type = nonlinearity
            else:
                skip_nonlinearity_type = ''
            self.conv_block_s = skip_block(in_channels, out_channels,
                                           bias=biases[2],
                                           nonlinearity=skip_nonlinearity_type,
                                           order=order[0:3],
                                           stride=shortcut_stride,
                                           blur=shortcut_blur,
                                           **shortcut_params)
        elif in_channels < out_channels:
            if skip_nonlinearity:
                skip_nonlinearity_type = nonlinearity
            else:
                skip_nonlinearity_type = ''
            self.conv_block_s = skip_block(in_channels,
                                           out_channels - in_channels,
                                           bias=biases[2],
                                           nonlinearity=skip_nonlinearity_type,
                                           order=order[0:3],
                                           stride=shortcut_stride,
                                           blur=shortcut_blur,
                                           **shortcut_params)

        # Whether this block expects conditional inputs.
        self.conditional = \
            getattr(self.conv_block_0, 'conditional', False) or \
            getattr(self.conv_block_1, 'conditional', False)

    def _get_stride_blur(self):
        if self.stride > 1:
            # Downsampling.
            first_stride, second_stride = 1, self.stride
            first_blur, second_blur = False, self.blur
            shortcut_stride = self.stride
            shortcut_blur = self.blur
            self.upsample = None
        elif self.stride < 1:
            # Upsampling.
            first_stride, second_stride = self.stride, 1
            first_blur, second_blur = self.blur, False
            shortcut_blur = False
            shortcut_stride = 1
            if self.blur:
                # The shortcut branch uses blur_upsample + stride-1 conv
                self.upsample = BlurUpsample()
            else:
                shortcut_stride = self.stride
                self.upsample = nn.Upsample(scale_factor=2)
        else:
            first_stride = second_stride = 1
            first_blur = second_blur = False
            shortcut_stride = 1
            shortcut_blur = False
            self.upsample = None
        return (first_stride, second_stride, shortcut_stride,
                first_blur, second_blur, shortcut_blur)

    def conv_blocks(
            self, x, *cond_inputs, separate_cond=False, **kw_cond_inputs
    ):
        r"""Returns the output of the residual branch.

        Args:
            x (tensor): Input tensor.
            cond_inputs (list of tensors) : Conditional input tensors.
            kw_cond_inputs (dict) : Keyword conditional inputs.
        Returns:
            dx (tensor): Output tensor.
        """
        if separate_cond:
            dx = self.conv_block_0(x, cond_inputs[0],
                                   **kw_cond_inputs.get('kwargs_0', {}))
            dx = self.conv_block_1(dx, cond_inputs[1],
                                   **kw_cond_inputs.get('kwargs_1', {}))
        else:
            dx = self.conv_block_0(x, *cond_inputs, **kw_cond_inputs)
            dx = self.conv_block_1(dx, *cond_inputs, **kw_cond_inputs)
        return dx

    def forward(self, x, *cond_inputs, do_checkpoint=False, separate_cond=False,
                **kw_cond_inputs):
        r"""

        Args:
            x (tensor): Input tensor.
            cond_inputs (list of tensors) : Conditional input tensors.
            do_checkpoint (bool, optional, default=``False``) If ``True``,
                trade compute for memory by checkpointing the model.
            kw_cond_inputs (dict) : Keyword conditional inputs.
        Returns:
            output (tensor): Output tensor.
        """
        if do_checkpoint:
            dx = checkpoint(self.conv_blocks, x, *cond_inputs,
                            separate_cond=separate_cond, **kw_cond_inputs)
        else:
            dx = self.conv_blocks(x, *cond_inputs,
                                  separate_cond=separate_cond, **kw_cond_inputs)

        if self.upsample_first and self.upsample is not None:
            x = self.upsample(x)
        if self.learn_shortcut:
            if separate_cond:
                x_shortcut = self.conv_block_s(
                    x, cond_inputs[2], **kw_cond_inputs.get('kwargs_2', {})
                )
            else:
                x_shortcut = self.conv_block_s(
                    x, *cond_inputs, **kw_cond_inputs
                )
        elif self.in_channels < self.out_channels:
            if separate_cond:
                x_shortcut_pad = self.conv_block_s(
                    x, cond_inputs[2], **kw_cond_inputs.get('kwargs_2', {})
                )
            else:
                x_shortcut_pad = self.conv_block_s(
                    x, *cond_inputs, **kw_cond_inputs
                )
            x_shortcut = torch.cat((x, x_shortcut_pad), dim=1)
        elif self.in_channels > self.out_channels:
            x_shortcut = x[:, :self.out_channels, :, :]
        else:
            x_shortcut = x
        if not self.upsample_first and self.upsample is not None:
            x_shortcut = self.upsample(x_shortcut)

        output = x_shortcut + dx
        return self.output_scale * output

    def extra_repr(self):
        s = 'output_scale={output_scale}'
        return s.format(**self.__dict__)


[docs]class ModulatedRes2dBlock(_BaseResBlock): def __init__(self, in_channels, out_channels, style_dim, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True, padding_mode='zeros', weight_norm_type='none', weight_norm_params=None, activation_norm_type='none', activation_norm_params=None, skip_activation_norm=True, skip_nonlinearity=False, nonlinearity='leakyrelu', inplace_nonlinearity=False, apply_noise=True, hidden_channels_equal_out_channels=False, order='CNACNA', learn_shortcut=None, clamp=None, output_scale=1, demodulate=True, eps=1e-8): block = functools.partial(ModulatedConv2dBlock, style_dim=style_dim, demodulate=demodulate, eps=eps) skip_block = Conv2dBlock super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, skip_activation_norm, skip_nonlinearity, nonlinearity, inplace_nonlinearity, apply_noise, hidden_channels_equal_out_channels, order, block, learn_shortcut, clamp, output_scale, skip_block=skip_block)
[docs] def conv_blocks(self, x, *cond_inputs, **kw_cond_inputs): assert len(list(cond_inputs)) == 2 dx = self.conv_block_0(x, cond_inputs[0], **kw_cond_inputs) dx = self.conv_block_1(dx, cond_inputs[1], **kw_cond_inputs) return dx
[docs]class ResLinearBlock(_BaseResBlock): r"""Residual block with full-connected layers. Args: in_channels (int) : Number of channels in the input tensor. out_channels (int) : Number of channels in the output tensor. weight_norm_type (str, optional, default='none'): Type of weight normalization. ``'none'``, ``'spectral'``, ``'weight'`` or ``'weight_demod'``. weight_norm_params (obj, optional, default=None): Parameters of weight normalization. If not ``None``, ``weight_norm_params.__dict__`` will be used as keyword arguments when initializing weight normalization. activation_norm_type (str, optional, default='none'): Type of activation normalization. ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. activation_norm_params (obj, optional, default=None): Parameters of activation normalization. If not ``None``, ``activation_norm_params.__dict__`` will be used as keyword arguments when initializing activation normalization. skip_activation_norm (bool, optional, default=True): If ``True`` and ``learn_shortcut`` is also ``True``, applies activation norm to the learned shortcut connection. skip_nonlinearity (bool, optional, default=True): If ``True`` and ``learn_shortcut`` is also ``True``, applies nonlinearity to the learned shortcut connection. nonlinearity (str, optional, default='none'): Type of nonlinear activation function in the residual link. ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. inplace_nonlinearity (bool, optional, default=False): If ``True``, set ``inplace=True`` when initializing the nonlinearity layers. apply_noise (bool, optional, default=False): If ``True``, add Gaussian noise with learnable magnitude after the fully-connected layer. hidden_channels_equal_out_channels (bool, optional, default=False): If ``True``, set the hidden channel number to be equal to the output channel number. If ``False``, the hidden channel number equals to the smaller of the input channel number and the output channel number. order (str, optional, default='CNACNA'): Order of operations in the residual link. ``'C'``: fully-connected, ``'N'``: normalization, ``'A'``: nonlinear activation. learn_shortcut (bool, optional, default=False): If ``True``, always use a convolutional shortcut instead of an identity one, otherwise only use a convolutional one if input and output have different number of channels. """ def __init__(self, in_channels, out_channels, bias=True, weight_norm_type='none', weight_norm_params=None, activation_norm_type='none', activation_norm_params=None, skip_activation_norm=True, skip_nonlinearity=False, nonlinearity='leakyrelu', inplace_nonlinearity=False, apply_noise=False, hidden_channels_equal_out_channels=False, order='CNACNA', learn_shortcut=None, clamp=None, output_scale=1): super().__init__(in_channels, out_channels, None, 1, None, None, None, bias, None, weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, skip_activation_norm, skip_nonlinearity, nonlinearity, inplace_nonlinearity, apply_noise, hidden_channels_equal_out_channels, order, LinearBlock, learn_shortcut, clamp, output_scale)
[docs]class Res1dBlock(_BaseResBlock): r"""Residual block for 1D input. Args: in_channels (int) : Number of channels in the input tensor. out_channels (int) : Number of channels in the output tensor. kernel_size (int, optional, default=3): Kernel size for the convolutional filters in the residual link. padding (int, optional, default=1): Padding size. dilation (int, optional, default=1): Dilation factor. groups (int, optional, default=1): Number of convolutional/linear groups. padding_mode (string, optional, default='zeros'): Type of padding: ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. weight_norm_type (str, optional, default='none'): Type of weight normalization. ``'none'``, ``'spectral'``, ``'weight'`` or ``'weight_demod'``. weight_norm_params (obj, optional, default=None): Parameters of weight normalization. If not ``None``, ``weight_norm_params.__dict__`` will be used as keyword arguments when initializing weight normalization. activation_norm_type (str, optional, default='none'): Type of activation normalization. ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. activation_norm_params (obj, optional, default=None): Parameters of activation normalization. If not ``None``, ``activation_norm_params.__dict__`` will be used as keyword arguments when initializing activation normalization. skip_activation_norm (bool, optional, default=True): If ``True`` and ``learn_shortcut`` is also ``True``, applies activation norm to the learned shortcut connection. skip_nonlinearity (bool, optional, default=True): If ``True`` and ``learn_shortcut`` is also ``True``, applies nonlinearity to the learned shortcut connection. nonlinearity (str, optional, default='none'): Type of nonlinear activation function in the residual link. ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. inplace_nonlinearity (bool, optional, default=False): If ``True``, set ``inplace=True`` when initializing the nonlinearity layers. apply_noise (bool, optional, default=False): If ``True``, adds Gaussian noise with learnable magnitude to the convolution output. hidden_channels_equal_out_channels (bool, optional, default=False): If ``True``, set the hidden channel number to be equal to the output channel number. If ``False``, the hidden channel number equals to the smaller of the input channel number and the output channel number. order (str, optional, default='CNACNA'): Order of operations in the residual link. ``'C'``: convolution, ``'N'``: normalization, ``'A'``: nonlinear activation. learn_shortcut (bool, optional, default=False): If ``True``, always use a convolutional shortcut instead of an identity one, otherwise only use a convolutional one if input and output have different number of channels. """ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True, padding_mode='zeros', weight_norm_type='none', weight_norm_params=None, activation_norm_type='none', activation_norm_params=None, skip_activation_norm=True, skip_nonlinearity=False, nonlinearity='leakyrelu', inplace_nonlinearity=False, apply_noise=False, hidden_channels_equal_out_channels=False, order='CNACNA', learn_shortcut=None, clamp=None, output_scale=1): super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, skip_activation_norm, skip_nonlinearity, nonlinearity, inplace_nonlinearity, apply_noise, hidden_channels_equal_out_channels, order, Conv1dBlock, learn_shortcut, clamp, output_scale)
[docs]class Res2dBlock(_BaseResBlock): r"""Residual block for 2D input. Args: in_channels (int) : Number of channels in the input tensor. out_channels (int) : Number of channels in the output tensor. kernel_size (int, optional, default=3): Kernel size for the convolutional filters in the residual link. padding (int, optional, default=1): Padding size. dilation (int, optional, default=1): Dilation factor. groups (int, optional, default=1): Number of convolutional/linear groups. padding_mode (string, optional, default='zeros'): Type of padding: ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. weight_norm_type (str, optional, default='none'): Type of weight normalization. ``'none'``, ``'spectral'``, ``'weight'`` or ``'weight_demod'``. weight_norm_params (obj, optional, default=None): Parameters of weight normalization. If not ``None``, ``weight_norm_params.__dict__`` will be used as keyword arguments when initializing weight normalization. activation_norm_type (str, optional, default='none'): Type of activation normalization. ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. activation_norm_params (obj, optional, default=None): Parameters of activation normalization. If not ``None``, ``activation_norm_params.__dict__`` will be used as keyword arguments when initializing activation normalization. skip_activation_norm (bool, optional, default=True): If ``True`` and ``learn_shortcut`` is also ``True``, applies activation norm to the learned shortcut connection. skip_nonlinearity (bool, optional, default=True): If ``True`` and ``learn_shortcut`` is also ``True``, applies nonlinearity to the learned shortcut connection. nonlinearity (str, optional, default='none'): Type of nonlinear activation function in the residual link. ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. inplace_nonlinearity (bool, optional, default=False): If ``True``, set ``inplace=True`` when initializing the nonlinearity layers. apply_noise (bool, optional, default=False): If ``True``, adds Gaussian noise with learnable magnitude to the convolution output. hidden_channels_equal_out_channels (bool, optional, default=False): If ``True``, set the hidden channel number to be equal to the output channel number. If ``False``, the hidden channel number equals to the smaller of the input channel number and the output channel number. order (str, optional, default='CNACNA'): Order of operations in the residual link. ``'C'``: convolution, ``'N'``: normalization, ``'A'``: nonlinear activation. learn_shortcut (bool, optional, default=False): If ``True``, always use a convolutional shortcut instead of an identity one, otherwise only use a convolutional one if input and output have different number of channels. """ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True, padding_mode='zeros', weight_norm_type='none', weight_norm_params=None, activation_norm_type='none', activation_norm_params=None, skip_activation_norm=True, skip_nonlinearity=False, skip_weight_norm=True, nonlinearity='leakyrelu', inplace_nonlinearity=False, apply_noise=False, hidden_channels_equal_out_channels=False, order='CNACNA', learn_shortcut=None, clamp=None, output_scale=1, blur=False, upsample_first=True): super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, skip_activation_norm, skip_nonlinearity, nonlinearity, inplace_nonlinearity, apply_noise, hidden_channels_equal_out_channels, order, Conv2dBlock, learn_shortcut, clamp, output_scale, blur=blur, upsample_first=upsample_first, skip_weight_norm=skip_weight_norm)
[docs]class Res3dBlock(_BaseResBlock): r"""Residual block for 3D input. Args: in_channels (int) : Number of channels in the input tensor. out_channels (int) : Number of channels in the output tensor. kernel_size (int, optional, default=3): Kernel size for the convolutional filters in the residual link. padding (int, optional, default=1): Padding size. dilation (int, optional, default=1): Dilation factor. groups (int, optional, default=1): Number of convolutional/linear groups. padding_mode (string, optional, default='zeros'): Type of padding: ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. weight_norm_type (str, optional, default='none'): Type of weight normalization. ``'none'``, ``'spectral'``, ``'weight'`` or ``'weight_demod'``. weight_norm_params (obj, optional, default=None): Parameters of weight normalization. If not ``None``, ``weight_norm_params.__dict__`` will be used as keyword arguments when initializing weight normalization. activation_norm_type (str, optional, default='none'): Type of activation normalization. ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. activation_norm_params (obj, optional, default=None): Parameters of activation normalization. If not ``None``, ``activation_norm_params.__dict__`` will be used as keyword arguments when initializing activation normalization. skip_activation_norm (bool, optional, default=True): If ``True`` and ``learn_shortcut`` is also ``True``, applies activation norm to the learned shortcut connection. skip_nonlinearity (bool, optional, default=True): If ``True`` and ``learn_shortcut`` is also ``True``, applies nonlinearity to the learned shortcut connection. nonlinearity (str, optional, default='none'): Type of nonlinear activation function in the residual link. ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. inplace_nonlinearity (bool, optional, default=False): If ``True``, set ``inplace=True`` when initializing the nonlinearity layers. apply_noise (bool, optional, default=False): If ``True``, adds Gaussian noise with learnable magnitude to the convolution output. hidden_channels_equal_out_channels (bool, optional, default=False): If ``True``, set the hidden channel number to be equal to the output channel number. If ``False``, the hidden channel number equals to the smaller of the input channel number and the output channel number. order (str, optional, default='CNACNA'): Order of operations in the residual link. ``'C'``: convolution, ``'N'``: normalization, ``'A'``: nonlinear activation. learn_shortcut (bool, optional, default=False): If ``True``, always use a convolutional shortcut instead of an identity one, otherwise only use a convolutional one if input and output have different number of channels. """ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True, padding_mode='zeros', weight_norm_type='none', weight_norm_params=None, activation_norm_type='none', activation_norm_params=None, skip_activation_norm=True, skip_nonlinearity=False, nonlinearity='leakyrelu', inplace_nonlinearity=False, apply_noise=False, hidden_channels_equal_out_channels=False, order='CNACNA', learn_shortcut=None, clamp=None, output_scale=1): super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, skip_activation_norm, skip_nonlinearity, nonlinearity, inplace_nonlinearity, apply_noise, hidden_channels_equal_out_channels, order, Conv3dBlock, learn_shortcut, clamp, output_scale)
class _BaseHyperResBlock(_BaseResBlock): r"""An abstract class for hyper residual blocks. """ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, skip_activation_norm, skip_nonlinearity, nonlinearity, inplace_nonlinearity, apply_noise, hidden_channels_equal_out_channels, order, is_hyper_conv, is_hyper_norm, block, learn_shortcut, clamp=None, output_scale=1): block = functools.partial(block, is_hyper_conv=is_hyper_conv, is_hyper_norm=is_hyper_norm) super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, skip_activation_norm, skip_nonlinearity, nonlinearity, inplace_nonlinearity, apply_noise, hidden_channels_equal_out_channels, order, block, learn_shortcut, clamp, output_scale) def forward(self, x, *cond_inputs, conv_weights=(None,) * 3, norm_weights=(None,) * 3, **kw_cond_inputs): r""" Args: x (tensor): Input tensor. cond_inputs (list of tensors) : Conditional input tensors. conv_weights (list of tensors): Convolution weights for three convolutional layers respectively. norm_weights (list of tensors): Normalization weights for three convolutional layers respectively. kw_cond_inputs (dict) : Keyword conditional inputs. Returns: output (tensor): Output tensor. """ dx = self.conv_block_0(x, *cond_inputs, conv_weights=conv_weights[0], norm_weights=norm_weights[0]) dx = self.conv_block_1(dx, *cond_inputs, conv_weights=conv_weights[1], norm_weights=norm_weights[1]) if self.learn_shortcut: x_shortcut = self.conv_block_s(x, *cond_inputs, conv_weights=conv_weights[2], norm_weights=norm_weights[2]) else: x_shortcut = x output = x_shortcut + dx return self.output_scale * output
[docs]class HyperRes2dBlock(_BaseHyperResBlock): r"""Hyper residual block for 2D input. Args: in_channels (int) : Number of channels in the input tensor. out_channels (int) : Number of channels in the output tensor. kernel_size (int, optional, default=3): Kernel size for the convolutional filters in the residual link. padding (int, optional, default=1): Padding size. dilation (int, optional, default=1): Dilation factor. groups (int, optional, default=1): Number of convolutional/linear groups. padding_mode (string, optional, default='zeros'): Type of padding: ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. weight_norm_type (str, optional, default='none'): Type of weight normalization. ``'none'``, ``'spectral'``, ``'weight'`` or ``'weight_demod'``. weight_norm_params (obj, optional, default=None): Parameters of weight normalization. If not ``None``, ``weight_norm_params.__dict__`` will be used as keyword arguments when initializing weight normalization. activation_norm_type (str, optional, default='none'): Type of activation normalization. ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. activation_norm_params (obj, optional, default=None): Parameters of activation normalization. If not ``None``, ``activation_norm_params.__dict__`` will be used as keyword arguments when initializing activation normalization. skip_activation_norm (bool, optional, default=True): If ``True`` and ``learn_shortcut`` is also ``True``, applies activation norm to the learned shortcut connection. skip_nonlinearity (bool, optional, default=True): If ``True`` and ``learn_shortcut`` is also ``True``, applies nonlinearity to the learned shortcut connection. nonlinearity (str, optional, default='none'): Type of nonlinear activation function in the residual link. ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. inplace_nonlinearity (bool, optional, default=False): If ``True``, set ``inplace=True`` when initializing the nonlinearity layers. apply_noise (bool, optional, default=False): If ``True``, adds Gaussian noise with learnable magnitude to the convolution output. hidden_channels_equal_out_channels (bool, optional, default=False): If ``True``, set the hidden channel number to be equal to the output channel number. If ``False``, the hidden channel number equals to the smaller of the input channel number and the output channel number. order (str, optional, default='CNACNA'): Order of operations in the residual link. ``'C'``: convolution, ``'N'``: normalization, ``'A'``: nonlinear activation. is_hyper_conv (bool, optional, default=False): If ``True``, use ``HyperConv2d``, otherwise use ``torch.nn.Conv2d``. is_hyper_norm (bool, optional, default=False): If ``True``, use hyper normalizations. learn_shortcut (bool, optional, default=False): If ``True``, always use a convolutional shortcut instead of an identity one, otherwise only use a convolutional one if input and output have different number of channels. """ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True, padding_mode='zeros', weight_norm_type='', weight_norm_params=None, activation_norm_type='', activation_norm_params=None, skip_activation_norm=True, skip_nonlinearity=False, nonlinearity='leakyrelu', inplace_nonlinearity=False, apply_noise=False, hidden_channels_equal_out_channels=False, order='CNACNA', is_hyper_conv=False, is_hyper_norm=False, learn_shortcut=None, clamp=None, output_scale=1): super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, skip_activation_norm, skip_nonlinearity, nonlinearity, inplace_nonlinearity, apply_noise, hidden_channels_equal_out_channels, order, is_hyper_conv, is_hyper_norm, HyperConv2dBlock, learn_shortcut, clamp, output_scale)
class _BaseDownResBlock(_BaseResBlock): r"""An abstract class for residual blocks with downsampling. """ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, skip_activation_norm, skip_nonlinearity, nonlinearity, inplace_nonlinearity, apply_noise, hidden_channels_equal_out_channels, order, block, pooling, down_factor, learn_shortcut, clamp=None, output_scale=1): super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, skip_activation_norm, skip_nonlinearity, nonlinearity, inplace_nonlinearity, apply_noise, hidden_channels_equal_out_channels, order, block, learn_shortcut, clamp, output_scale) self.pooling = pooling(down_factor) def forward(self, x, *cond_inputs): r""" Args: x (tensor) : Input tensor. cond_inputs (list of tensors) : conditional input. Returns: output (tensor) : Output tensor. """ dx = self.conv_block_0(x, *cond_inputs) dx = self.conv_block_1(dx, *cond_inputs) dx = self.pooling(dx) if self.learn_shortcut: x_shortcut = self.conv_block_s(x, *cond_inputs) else: x_shortcut = x x_shortcut = self.pooling(x_shortcut) output = x_shortcut + dx return self.output_scale * output
[docs]class DownRes2dBlock(_BaseDownResBlock): r"""Residual block for 2D input with downsampling. Args: in_channels (int) : Number of channels in the input tensor. out_channels (int) : Number of channels in the output tensor. kernel_size (int, optional, default=3): Kernel size for the convolutional filters in the residual link. padding (int, optional, default=1): Padding size. dilation (int, optional, default=1): Dilation factor. groups (int, optional, default=1): Number of convolutional/linear groups. padding_mode (string, optional, default='zeros'): Type of padding: ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. weight_norm_type (str, optional, default='none'): Type of weight normalization. ``'none'``, ``'spectral'``, ``'weight'`` or ``'weight_demod'``. weight_norm_params (obj, optional, default=None): Parameters of weight normalization. If not ``None``, ``weight_norm_params.__dict__`` will be used as keyword arguments when initializing weight normalization. activation_norm_type (str, optional, default='none'): Type of activation normalization. ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. activation_norm_params (obj, optional, default=None): Parameters of activation normalization. If not ``None``, ``activation_norm_params.__dict__`` will be used as keyword arguments when initializing activation normalization. skip_activation_norm (bool, optional, default=True): If ``True`` and ``learn_shortcut`` is also ``True``, applies activation norm to the learned shortcut connection. skip_nonlinearity (bool, optional, default=True): If ``True`` and ``learn_shortcut`` is also ``True``, applies nonlinearity to the learned shortcut connection. nonlinearity (str, optional, default='none'): Type of nonlinear activation function in the residual link. ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. inplace_nonlinearity (bool, optional, default=False): If ``True``, set ``inplace=True`` when initializing the nonlinearity layers. apply_noise (bool, optional, default=False): If ``True``, adds Gaussian noise with learnable magnitude to the convolution output. hidden_channels_equal_out_channels (bool, optional, default=False): If ``True``, set the hidden channel number to be equal to the output channel number. If ``False``, the hidden channel number equals to the smaller of the input channel number and the output channel number. order (str, optional, default='CNACNA'): Order of operations in the residual link. ``'C'``: convolution, ``'N'``: normalization, ``'A'``: nonlinear activation. pooling (class, optional, default=nn.AvgPool2d): Pytorch pooling layer to be used. down_factor (int, optional, default=2): Downsampling factor. learn_shortcut (bool, optional, default=False): If ``True``, always use a convolutional shortcut instead of an identity one, otherwise only use a convolutional one if input and output have different number of channels. """ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True, padding_mode='zeros', weight_norm_type='none', weight_norm_params=None, activation_norm_type='none', activation_norm_params=None, skip_activation_norm=True, skip_nonlinearity=False, nonlinearity='leakyrelu', inplace_nonlinearity=False, apply_noise=False, hidden_channels_equal_out_channels=False, order='CNACNA', pooling=nn.AvgPool2d, down_factor=2, learn_shortcut=None, clamp=None, output_scale=1): super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, skip_activation_norm, skip_nonlinearity, nonlinearity, inplace_nonlinearity, apply_noise, hidden_channels_equal_out_channels, order, Conv2dBlock, pooling, down_factor, learn_shortcut, clamp, output_scale)
class _BaseUpResBlock(_BaseResBlock): r"""An abstract class for residual blocks with upsampling. """ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, skip_activation_norm, skip_nonlinearity, nonlinearity, inplace_nonlinearity, apply_noise, hidden_channels_equal_out_channels, order, block, upsample, up_factor, learn_shortcut, clamp=None, output_scale=1): super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, skip_activation_norm, skip_nonlinearity, nonlinearity, inplace_nonlinearity, apply_noise, hidden_channels_equal_out_channels, order, block, learn_shortcut, clamp, output_scale) self.order = order self.upsample = upsample(scale_factor=up_factor) def _get_stride_blur(self): # Upsampling. first_stride, second_stride = self.stride, 1 first_blur, second_blur = self.blur, False shortcut_blur = False shortcut_stride = 1 # if self.upsample == 'blur_deconv': if self.blur: # The shortcut branch uses blur_upsample + stride-1 conv self.upsample = BlurUpsample() else: shortcut_stride = self.stride self.upsample = nn.Upsample(scale_factor=2) return (first_stride, second_stride, shortcut_stride, first_blur, second_blur, shortcut_blur) def forward(self, x, *cond_inputs): r"""Implementation of the up residual block forward function. If the order is 'NAC' for the first residual block, we will first do the activation norm and nonlinearity, in the original resolution. We will then upsample the activation map to a higher resolution. We then do the convolution. It is is other orders, then we first do the whole processing and then upsample. Args: x (tensor) : Input tensor. cond_inputs (list of tensors) : Conditional input. Returns: output (tensor) : Output tensor. """ # In this particular upsample residual block operation, we first # upsample the skip connection. if self.learn_shortcut: x_shortcut = self.upsample(x) x_shortcut = self.conv_block_s(x_shortcut, *cond_inputs) else: x_shortcut = self.upsample(x) if self.order[0:3] == 'NAC': for ix, layer in enumerate(self.conv_block_0.layers.values()): if getattr(layer, 'conditional', False): x = layer(x, *cond_inputs) else: x = layer(x) if ix == 1: x = self.upsample(x) else: x = self.conv_block_0(x, *cond_inputs) x = self.upsample(x) x = self.conv_block_1(x, *cond_inputs) output = x_shortcut + x return self.output_scale * output
[docs]class UpRes2dBlock(_BaseUpResBlock): r"""Residual block for 2D input with downsampling. Args: in_channels (int) : Number of channels in the input tensor. out_channels (int) : Number of channels in the output tensor. kernel_size (int, optional, default=3): Kernel size for the convolutional filters in the residual link. padding (int, optional, default=1): Padding size. dilation (int, optional, default=1): Dilation factor. groups (int, optional, default=1): Number of convolutional/linear groups. padding_mode (string, optional, default='zeros'): Type of padding: ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. weight_norm_type (str, optional, default='none'): Type of weight normalization. ``'none'``, ``'spectral'``, ``'weight'`` or ``'weight_demod'``. weight_norm_params (obj, optional, default=None): Parameters of weight normalization. If not ``None``, ``weight_norm_params.__dict__`` will be used as keyword arguments when initializing weight normalization. activation_norm_type (str, optional, default='none'): Type of activation normalization. ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. activation_norm_params (obj, optional, default=None): Parameters of activation normalization. If not ``None``, ``activation_norm_params.__dict__`` will be used as keyword arguments when initializing activation normalization. skip_activation_norm (bool, optional, default=True): If ``True`` and ``learn_shortcut`` is also ``True``, applies activation norm to the learned shortcut connection. skip_nonlinearity (bool, optional, default=True): If ``True`` and ``learn_shortcut`` is also ``True``, applies nonlinearity to the learned shortcut connection. nonlinearity (str, optional, default='none'): Type of nonlinear activation function in the residual link. ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. inplace_nonlinearity (bool, optional, default=False): If ``True``, set ``inplace=True`` when initializing the nonlinearity layers. apply_noise (bool, optional, default=False): If ``True``, adds Gaussian noise with learnable magnitude to the convolution output. hidden_channels_equal_out_channels (bool, optional, default=False): If ``True``, set the hidden channel number to be equal to the output channel number. If ``False``, the hidden channel number equals to the smaller of the input channel number and the output channel number. order (str, optional, default='CNACNA'): Order of operations in the residual link. ``'C'``: convolution, ``'N'``: normalization, ``'A'``: nonlinear activation. upsample (class, optional, default=NearestUpsample): PPytorch upsampling layer to be used. up_factor (int, optional, default=2): Upsampling factor. learn_shortcut (bool, optional, default=False): If ``True``, always use a convolutional shortcut instead of an identity one, otherwise only use a convolutional one if input and output have different number of channels. """ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True, padding_mode='zeros', weight_norm_type='none', weight_norm_params=None, activation_norm_type='none', activation_norm_params=None, skip_activation_norm=True, skip_nonlinearity=False, nonlinearity='leakyrelu', inplace_nonlinearity=False, apply_noise=False, hidden_channels_equal_out_channels=False, order='CNACNA', upsample=NearestUpsample, up_factor=2, learn_shortcut=None, clamp=None, output_scale=1): super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, skip_activation_norm, skip_nonlinearity, nonlinearity, inplace_nonlinearity, apply_noise, hidden_channels_equal_out_channels, order, Conv2dBlock, upsample, up_factor, learn_shortcut, clamp, output_scale)
class _BasePartialResBlock(_BaseResBlock): r"""An abstract class for residual blocks with partial convolution. """ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, skip_activation_norm, skip_nonlinearity, nonlinearity, inplace_nonlinearity, multi_channel, return_mask, apply_noise, hidden_channels_equal_out_channels, order, block, learn_shortcut, clamp=None, output_scale=1): block = functools.partial(block, multi_channel=multi_channel, return_mask=return_mask) self.partial_conv = True super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, skip_activation_norm, skip_nonlinearity, nonlinearity, inplace_nonlinearity, apply_noise, hidden_channels_equal_out_channels, order, block, learn_shortcut, clamp, output_scale) def forward(self, x, *cond_inputs, mask_in=None, **kw_cond_inputs): r""" Args: x (tensor): Input tensor. cond_inputs (list of tensors) : Conditional input tensors. mask_in (tensor, optional, default=``None``) If not ``None``, it masks the valid input region. kw_cond_inputs (dict) : Keyword conditional inputs. Returns: (tuple): - output (tensor): Output tensor. - mask_out (tensor, optional): Masks the valid output region. """ if self.conv_block_0.layers.conv.return_mask: dx, mask_out = self.conv_block_0(x, *cond_inputs, mask_in=mask_in, **kw_cond_inputs) dx, mask_out = self.conv_block_1(dx, *cond_inputs, mask_in=mask_out, **kw_cond_inputs) else: dx = self.conv_block_0(x, *cond_inputs, mask_in=mask_in, **kw_cond_inputs) dx = self.conv_block_1(dx, *cond_inputs, mask_in=mask_in, **kw_cond_inputs) mask_out = None if self.learn_shortcut: x_shortcut = self.conv_block_s(x, mask_in=mask_in, *cond_inputs, **kw_cond_inputs) if type(x_shortcut) == tuple: x_shortcut, _ = x_shortcut else: x_shortcut = x output = x_shortcut + dx if mask_out is not None: return output, mask_out return self.output_scale * output
[docs]class PartialRes2dBlock(_BasePartialResBlock): r"""Residual block for 2D input with partial convolution. Args: in_channels (int) : Number of channels in the input tensor. out_channels (int) : Number of channels in the output tensor. kernel_size (int, optional, default=3): Kernel size for the convolutional filters in the residual link. padding (int, optional, default=1): Padding size. dilation (int, optional, default=1): Dilation factor. groups (int, optional, default=1): Number of convolutional/linear groups. padding_mode (string, optional, default='zeros'): Type of padding: ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. weight_norm_type (str, optional, default='none'): Type of weight normalization. ``'none'``, ``'spectral'``, ``'weight'`` or ``'weight_demod'``. weight_norm_params (obj, optional, default=None): Parameters of weight normalization. If not ``None``, ``weight_norm_params.__dict__`` will be used as keyword arguments when initializing weight normalization. activation_norm_type (str, optional, default='none'): Type of activation normalization. ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. activation_norm_params (obj, optional, default=None): Parameters of activation normalization. If not ``None``, ``activation_norm_params.__dict__`` will be used as keyword arguments when initializing activation normalization. skip_activation_norm (bool, optional, default=True): If ``True`` and ``learn_shortcut`` is also ``True``, applies activation norm to the learned shortcut connection. skip_nonlinearity (bool, optional, default=True): If ``True`` and ``learn_shortcut`` is also ``True``, applies nonlinearity to the learned shortcut connection. nonlinearity (str, optional, default='none'): Type of nonlinear activation function in the residual link. ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. inplace_nonlinearity (bool, optional, default=False): If ``True``, set ``inplace=True`` when initializing the nonlinearity layers. apply_noise (bool, optional, default=False): If ``True``, adds Gaussian noise with learnable magnitude to the convolution output. hidden_channels_equal_out_channels (bool, optional, default=False): If ``True``, set the hidden channel number to be equal to the output channel number. If ``False``, the hidden channel number equals to the smaller of the input channel number and the output channel number. order (str, optional, default='CNACNA'): Order of operations in the residual link. ``'C'``: convolution, ``'N'``: normalization, ``'A'``: nonlinear activation. learn_shortcut (bool, optional, default=False): If ``True``, always use a convolutional shortcut instead of an identity one, otherwise only use a convolutional one if input and output have different number of channels. """ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True, padding_mode='zeros', weight_norm_type='none', weight_norm_params=None, activation_norm_type='none', activation_norm_params=None, skip_activation_norm=True, skip_nonlinearity=False, nonlinearity='leakyrelu', inplace_nonlinearity=False, multi_channel=False, return_mask=True, apply_noise=False, hidden_channels_equal_out_channels=False, order='CNACNA', learn_shortcut=None, clamp=None, output_scale=1): super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, skip_activation_norm, skip_nonlinearity, nonlinearity, inplace_nonlinearity, multi_channel, return_mask, apply_noise, hidden_channels_equal_out_channels, order, PartialConv2dBlock, learn_shortcut, clamp, output_scale)
[docs]class PartialRes3dBlock(_BasePartialResBlock): r"""Residual block for 3D input with partial convolution. Args: in_channels (int) : Number of channels in the input tensor. out_channels (int) : Number of channels in the output tensor. kernel_size (int, optional, default=3): Kernel size for the convolutional filters in the residual link. padding (int, optional, default=1): Padding size. dilation (int, optional, default=1): Dilation factor. groups (int, optional, default=1): Number of convolutional/linear groups. padding_mode (string, optional, default='zeros'): Type of padding: ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. weight_norm_type (str, optional, default='none'): Type of weight normalization. ``'none'``, ``'spectral'``, ``'weight'`` or ``'weight_demod'``. weight_norm_params (obj, optional, default=None): Parameters of weight normalization. If not ``None``, ``weight_norm_params.__dict__`` will be used as keyword arguments when initializing weight normalization. activation_norm_type (str, optional, default='none'): Type of activation normalization. ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. activation_norm_params (obj, optional, default=None): Parameters of activation normalization. If not ``None``, ``activation_norm_params.__dict__`` will be used as keyword arguments when initializing activation normalization. skip_activation_norm (bool, optional, default=True): If ``True`` and ``learn_shortcut`` is also ``True``, applies activation norm to the learned shortcut connection. skip_nonlinearity (bool, optional, default=True): If ``True`` and ``learn_shortcut`` is also ``True``, applies nonlinearity to the learned shortcut connection. nonlinearity (str, optional, default='none'): Type of nonlinear activation function in the residual link. ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. inplace_nonlinearity (bool, optional, default=False): If ``True``, set ``inplace=True`` when initializing the nonlinearity layers. apply_noise (bool, optional, default=False): If ``True``, adds Gaussian noise with learnable magnitude to the convolution output. hidden_channels_equal_out_channels (bool, optional, default=False): If ``True``, set the hidden channel number to be equal to the output channel number. If ``False``, the hidden channel number equals to the smaller of the input channel number and the output channel number. order (str, optional, default='CNACNA'): Order of operations in the residual link. ``'C'``: convolution, ``'N'``: normalization, ``'A'``: nonlinear activation. learn_shortcut (bool, optional, default=False): If ``True``, always use a convolutional shortcut instead of an identity one, otherwise only use a convolutional one if input and output have different number of channels. """ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True, padding_mode='zeros', weight_norm_type='none', weight_norm_params=None, activation_norm_type='none', activation_norm_params=None, skip_activation_norm=True, skip_nonlinearity=False, nonlinearity='leakyrelu', inplace_nonlinearity=False, multi_channel=False, return_mask=True, apply_noise=False, hidden_channels_equal_out_channels=False, order='CNACNA', learn_shortcut=None, clamp=None, output_scale=1): super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, skip_activation_norm, skip_nonlinearity, nonlinearity, inplace_nonlinearity, multi_channel, return_mask, apply_noise, hidden_channels_equal_out_channels, order, PartialConv3dBlock, learn_shortcut, clamp, output_scale)
class _BaseMultiOutResBlock(_BaseResBlock): r"""An abstract class for residual blocks that can returns multiple outputs. """ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, skip_activation_norm, skip_nonlinearity, nonlinearity, inplace_nonlinearity, apply_noise, hidden_channels_equal_out_channels, order, block, learn_shortcut, clamp=None, output_scale=1, blur=False, upsample_first=True): self.multiple_outputs = True super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, skip_activation_norm, skip_nonlinearity, nonlinearity, inplace_nonlinearity, apply_noise, hidden_channels_equal_out_channels, order, block, learn_shortcut, clamp, output_scale, blur=blur, upsample_first=upsample_first) def forward(self, x, *cond_inputs): r""" Args: x (tensor): Input tensor. cond_inputs (list of tensors) : Conditional input tensors. Returns: (tuple): - output (tensor): Output tensor. - aux_outputs_0 (tensor): Auxiliary output of the first block. - aux_outputs_1 (tensor): Auxiliary output of the second block. """ dx, aux_outputs_0 = self.conv_block_0(x, *cond_inputs) dx, aux_outputs_1 = self.conv_block_1(dx, *cond_inputs) if self.learn_shortcut: # We are not using the auxiliary outputs of self.conv_block_s. x_shortcut, _ = self.conv_block_s(x, *cond_inputs) else: x_shortcut = x output = x_shortcut + dx return self.output_scale * output, aux_outputs_0, aux_outputs_1
[docs]class MultiOutRes2dBlock(_BaseMultiOutResBlock): r"""Residual block for 2D input. It can return multiple outputs, if some layers in the block return more than one output. Args: in_channels (int) : Number of channels in the input tensor. out_channels (int) : Number of channels in the output tensor. kernel_size (int, optional, default=3): Kernel size for the convolutional filters in the residual link. padding (int, optional, default=1): Padding size. dilation (int, optional, default=1): Dilation factor. groups (int, optional, default=1): Number of convolutional/linear groups. padding_mode (string, optional, default='zeros'): Type of padding: ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. weight_norm_type (str, optional, default='none'): Type of weight normalization. ``'none'``, ``'spectral'``, ``'weight'`` or ``'weight_demod'``. weight_norm_params (obj, optional, default=None): Parameters of weight normalization. If not ``None``, ``weight_norm_params.__dict__`` will be used as keyword arguments when initializing weight normalization. activation_norm_type (str, optional, default='none'): Type of activation normalization. ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. activation_norm_params (obj, optional, default=None): Parameters of activation normalization. If not ``None``, ``activation_norm_params.__dict__`` will be used as keyword arguments when initializing activation normalization. skip_activation_norm (bool, optional, default=True): If ``True`` and ``learn_shortcut`` is also ``True``, applies activation norm to the learned shortcut connection. skip_nonlinearity (bool, optional, default=True): If ``True`` and ``learn_shortcut`` is also ``True``, applies nonlinearity to the learned shortcut connection. nonlinearity (str, optional, default='none'): Type of nonlinear activation function in the residual link. ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. inplace_nonlinearity (bool, optional, default=False): If ``True``, set ``inplace=True`` when initializing the nonlinearity layers. apply_noise (bool, optional, default=False): If ``True``, adds Gaussian noise with learnable magnitude to the convolution output. hidden_channels_equal_out_channels (bool, optional, default=False): If ``True``, set the hidden channel number to be equal to the output channel number. If ``False``, the hidden channel number equals to the smaller of the input channel number and the output channel number. order (str, optional, default='CNACNA'): Order of operations in the residual link. ``'C'``: convolution, ``'N'``: normalization, ``'A'``: nonlinear activation. learn_shortcut (bool, optional, default=False): If ``True``, always use a convolutional shortcut instead of an identity one, otherwise only use a convolutional one if input and output have different number of channels. """ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True, padding_mode='zeros', weight_norm_type='none', weight_norm_params=None, activation_norm_type='none', activation_norm_params=None, skip_activation_norm=True, skip_nonlinearity=False, nonlinearity='leakyrelu', inplace_nonlinearity=False, apply_noise=False, hidden_channels_equal_out_channels=False, order='CNACNA', learn_shortcut=None, clamp=None, output_scale=1, blur=False, upsample_first=True): super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, skip_activation_norm, skip_nonlinearity, nonlinearity, inplace_nonlinearity, apply_noise, hidden_channels_equal_out_channels, order, MultiOutConv2dBlock, learn_shortcut, clamp, output_scale, blur=blur, upsample_first=upsample_first)