Source code for imaginaire.layers.nonlinearity

# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out
import torch
from torch import nn
import torch.nn.functional as F
from imaginaire.third_party.bias_act.bias_act import FusedNonlinearity
[docs]class ScaledLeakyReLU(nn.Module): def __init__(self, negative_slope=0.2, scale=2 ** 0.5, inplace=False): super().__init__() self.negative_slope = negative_slope self.scale = scale self.inplace = inplace
[docs] def forward(self, x): return F.leaky_relu(x, self.negative_slope, inplace=self.inplace) * self.scale
# return _fused_scaled_leakyrelu(x, self.negative_slope, self.inplace, self.scale) # @torch.jit.script # def _fused_scaled_leakyrelu(x: torch.Tensor, negative_slope: float, inplace: bool, scale: float): # return F.leaky_relu(x, negative_slope, inplace=inplace) * scale
[docs]def get_nonlinearity_layer(nonlinearity_type, inplace, **kwargs): r"""Return a nonlinearity layer. Args: nonlinearity_type (str): Type of nonlinear activation function. ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. inplace (bool): If ``True``, set ``inplace=True`` when initializing the nonlinearity layer. """ if nonlinearity_type.startswith('fused'): nonlinearity = FusedNonlinearity(nonlinearity=nonlinearity_type[6:], **kwargs) elif nonlinearity_type == 'relu': nonlinearity = nn.ReLU(inplace=inplace) elif nonlinearity_type == 'leakyrelu': nonlinearity = nn.LeakyReLU(0.2, inplace=inplace) elif nonlinearity_type == 'scaled_leakyrelu': nonlinearity = ScaledLeakyReLU(0.2, inplace=inplace) import imaginaire.config if imaginaire.config.USE_JIT: nonlinearity = torch.jit.script(nonlinearity) elif nonlinearity_type == 'prelu': nonlinearity = nn.PReLU() elif nonlinearity_type == 'tanh': nonlinearity = nn.Tanh() elif nonlinearity_type == 'sigmoid': nonlinearity = nn.Sigmoid() elif nonlinearity_type.startswith('softmax'): dim = nonlinearity_type.split(',')[1] if ',' in nonlinearity_type else 1 nonlinearity = nn.Softmax(dim=int(dim)) elif nonlinearity_type == 'none' or nonlinearity_type == '': nonlinearity = None else: raise ValueError('Nonlinearity %s is not recognized' % nonlinearity_type) return nonlinearity