# Source code for imaginaire.losses.perceptual

# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch
import torch.nn.functional as F
import torchvision
from torch import nn, distributed as dist

from imaginaire.losses.info_nce import InfoNCELoss
from imaginaire.utils.distributed import master_only_print as print, \
    is_local_master
from imaginaire.utils.misc import apply_imagenet_normalization, to_float


class PerceptualLoss(nn.Module):
    r"""Perceptual loss computed on features of a pretrained network.

    Args:
        network (str) : The name of the loss network: 'vgg19' | 'vgg16' |
            'alexnet' | 'inception_v3' | 'resnet50' | 'robust_resnet50' |
            'vgg_face_dag'.
        layers (str or list of str) : The layers used to compute the loss.
        weights (float or list of float) : The loss weights of each layer.
            If ``None``, every layer gets a weight of 1. A single number is
            treated as a one-element list.
        criterion (str): The type of distance function:
            'l1' | 'l2' | 'mse' | 'info_nce'.
        resize (bool) : If ``True``, resize the input images to 224x224.
        resize_mode (str): Algorithm used for resizing. Any mode accepted by
            ``F.interpolate`` can be used.
        num_scales (int): Number of scales the loss is evaluated at. The
            first scale is the (optionally resized) input resolution and
            every following scale is downsampled by a factor of 2.
        per_sample_weight (bool): Output loss for individual samples in the
            batch instead of mean loss. When ``True`` the 'l1' / 'l2' / 'mse'
            criterion is built with ``reduction='none'``.
        info_nce_temperature (float): Passed to ``InfoNCELoss`` as
            ``temperature`` (only used when ``criterion == 'info_nce'``).
        info_nce_gather_distributed (bool): Passed to ``InfoNCELoss`` as
            ``gather_distributed``.
        info_nce_learn_temperature (bool): Passed to ``InfoNCELoss`` as
            ``learn_temperature``.
        info_nce_flatten (bool): Passed to ``InfoNCELoss`` as ``flatten``.

    Raises:
        AssertionError: If the number of layers and weights differ.
        ValueError: If ``network`` or ``criterion`` is not recognized.
    """

    # Interpolation modes for which F.interpolate accepts ``align_corners``;
    # passing the argument with any other mode raises a ValueError.
    _ALIGN_CORNERS_MODES = ('linear', 'bilinear', 'bicubic', 'trilinear')

    def __init__(self, network='vgg19', layers='relu_4_1', weights=None,
                 criterion='l1', resize=False, resize_mode='bilinear',
                 num_scales=1, per_sample_weight=False,
                 info_nce_temperature=0.07,
                 info_nce_gather_distributed=True,
                 info_nce_learn_temperature=True, info_nce_flatten=True):
        super().__init__()
        if isinstance(layers, str):
            layers = [layers]
        if weights is None:
            weights = [1.] * len(layers)
        elif isinstance(weights, (float, int)):
            # A scalar weight is shorthand for a single-layer weight list.
            weights = [weights]

        if dist.is_initialized() and not is_local_master():
            # Make sure only the first process in distributed training
            # downloads the model, and the others will use the cache.
            # noinspection PyUnresolvedReferences
            torch.distributed.barrier()

        assert len(layers) == len(weights), \
            'The number of layers (%s) must be equal to ' \
            'the number of weights (%s).' % (len(layers), len(weights))

        # Kept as an explicit chain so that only the selected factory is
        # looked up.
        if network == 'vgg19':
            self.model = _vgg19(layers)
        elif network == 'vgg16':
            self.model = _vgg16(layers)
        elif network == 'alexnet':
            self.model = _alexnet(layers)
        elif network == 'inception_v3':
            self.model = _inception_v3(layers)
        elif network == 'resnet50':
            self.model = _resnet50(layers)
        elif network == 'robust_resnet50':
            self.model = _robust_resnet50(layers)
        elif network == 'vgg_face_dag':
            self.model = _vgg_face_dag(layers)
        else:
            raise ValueError('Network %s is not recognized' % network)

        if dist.is_initialized() and is_local_master():
            # The local master reaches the barrier only after its model has
            # been built, releasing the processes that waited above.
            # noinspection PyUnresolvedReferences
            torch.distributed.barrier()

        self.num_scales = num_scales
        self.layers = layers
        self.weights = weights
        reduction = 'mean' if not per_sample_weight else 'none'
        if criterion == 'l1':
            self.criterion = nn.L1Loss(reduction=reduction)
        elif criterion == 'l2' or criterion == 'mse':
            self.criterion = nn.MSELoss(reduction=reduction)
        elif criterion == 'info_nce':
            self.criterion = InfoNCELoss(
                temperature=info_nce_temperature,
                gather_distributed=info_nce_gather_distributed,
                learn_temperature=info_nce_learn_temperature,
                flatten=info_nce_flatten,
                single_direction=True
            )
        else:
            raise ValueError('Criterion %s is not recognized' % criterion)

        self.resize = resize
        self.resize_mode = resize_mode
        print('Perceptual loss:')
        print('\tMode: {}'.format(network))

    @staticmethod
    def _interpolate(x, mode, **kwargs):
        r"""Call ``F.interpolate``, passing ``align_corners=False`` only for
        the modes that accept it.

        Args:
            x (tensor) : Tensor to resample.
            mode (str): Interpolation mode.
            **kwargs: Remaining ``F.interpolate`` arguments (``size``,
                ``scale_factor``, ``recompute_scale_factor``).
        Returns:
            (tensor) : The resampled tensor.
        """
        align_corners = \
            False if mode in PerceptualLoss._ALIGN_CORNERS_MODES else None
        return F.interpolate(
            x, mode=mode, align_corners=align_corners, **kwargs)

    def forward(self, inp, target, per_sample_weights=None):
        r"""Perceptual loss forward.

        Args:
            inp (4D tensor) : Input tensor.
            target (4D tensor) : Ground truth tensor, same shape as the
                input.
            per_sample_weights: If not ``None``, every per-layer loss is
                averaged over dimensions 1, 2 and 3 before it is
                accumulated. Only the presence of this argument is checked
                here; its values are not used.
        Returns:
            (tensor) : The perceptual loss, converted with ``.float()``.
        """
        if not torch.is_autocast_enabled():
            inp, target = to_float([inp, target])

        # Perceptual loss should operate in eval mode by default.
        self.model.eval()
        inp = apply_imagenet_normalization(inp)
        target = apply_imagenet_normalization(target)
        if self.resize:
            inp = self._interpolate(inp, self.resize_mode, size=(224, 224))
            target = self._interpolate(
                target, self.resize_mode, size=(224, 224))

        # Evaluate perceptual loss at each scale.
        loss = 0
        for scale in range(self.num_scales):
            input_features = self.model(inp)
            target_features = self.model(target)

            for layer, weight in zip(self.layers, self.weights):
                # Example per-layer VGG19 loss values after applying
                # [0.03125, 0.0625, 0.125, 0.25, 1.0] weighting.
                # relu_1_1, 0.014698
                # relu_2_1, 0.085817
                # relu_3_1, 0.349977
                # relu_4_1, 0.544188
                # relu_5_1, 0.906261
                # The target features are detached so that no gradient
                # flows into the target branch.
                l_tmp = self.criterion(input_features[layer],
                                       target_features[layer].detach())
                if per_sample_weights is not None:
                    l_tmp = l_tmp.mean(1).mean(1).mean(1)
                loss += weight * l_tmp

            # Downsample the input and target.
            if scale != self.num_scales - 1:
                inp = self._interpolate(
                    inp, self.resize_mode, scale_factor=0.5,
                    recompute_scale_factor=True)
                target = self._interpolate(
                    target, self.resize_mode, scale_factor=0.5,
                    recompute_scale_factor=True)

        return loss.float()
class _PerceptualNetwork(nn.Module): r"""The network that extracts features to compute the perceptual loss. Args: network (nn.Sequential) : The network that extracts features. layer_name_mapping (dict) : The dictionary that maps a layer's index to its name. layers (list of str): The list of layer names that we are using. """ def __init__(self, network, layer_name_mapping, layers): super().__init__() assert isinstance(network, nn.Sequential), \ 'The network needs to be of type "nn.Sequential".' self.network = network self.layer_name_mapping = layer_name_mapping self.layers = layers for param in self.parameters(): param.requires_grad = False def forward(self, x): r"""Extract perceptual features.""" output = {} for i, layer in enumerate(self.network): x = layer(x) layer_name = self.layer_name_mapping.get(i, None) if layer_name in self.layers: # If the current layer is used by the perceptual loss. output[layer_name] = x return output def _vgg19(layers): r"""Get vgg19 layers""" vgg = torchvision.models.vgg19(pretrained=True) # network = vgg.features network = torch.nn.Sequential(*(list(vgg.features) + [vgg.avgpool] + [nn.Flatten()] + list(vgg.classifier))) layer_name_mapping = {1: 'relu_1_1', 3: 'relu_1_2', 6: 'relu_2_1', 8: 'relu_2_2', 11: 'relu_3_1', 13: 'relu_3_2', 15: 'relu_3_3', 17: 'relu_3_4', 20: 'relu_4_1', 22: 'relu_4_2', 24: 'relu_4_3', 26: 'relu_4_4', 29: 'relu_5_1', 31: 'relu_5_2', 33: 'relu_5_3', 35: 'relu_5_4', 36: 'pool_5', 42: 'fc_2'} return _PerceptualNetwork(network, layer_name_mapping, layers) def _vgg16(layers): r"""Get vgg16 layers""" network = torchvision.models.vgg16(pretrained=True).features layer_name_mapping = {1: 'relu_1_1', 3: 'relu_1_2', 6: 'relu_2_1', 8: 'relu_2_2', 11: 'relu_3_1', 13: 'relu_3_2', 15: 'relu_3_3', 18: 'relu_4_1', 20: 'relu_4_2', 22: 'relu_4_3', 25: 'relu_5_1'} return _PerceptualNetwork(network, layer_name_mapping, layers) def _alexnet(layers): r"""Get alexnet layers""" network = 
torchvision.models.alexnet(pretrained=True).features layer_name_mapping = {0: 'conv_1', 1: 'relu_1', 3: 'conv_2', 4: 'relu_2', 6: 'conv_3', 7: 'relu_3', 8: 'conv_4', 9: 'relu_4', 10: 'conv_5', 11: 'relu_5'} return _PerceptualNetwork(network, layer_name_mapping, layers) def _inception_v3(layers): r"""Get inception v3 layers""" inception = torchvision.models.inception_v3(pretrained=True) network = nn.Sequential(inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, inception.Conv2d_2b_3x3, nn.MaxPool2d(kernel_size=3, stride=2), inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, nn.MaxPool2d(kernel_size=3, stride=2), inception.Mixed_5b, inception.Mixed_5c, inception.Mixed_5d, inception.Mixed_6a, inception.Mixed_6b, inception.Mixed_6c, inception.Mixed_6d, inception.Mixed_6e, inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c, nn.AdaptiveAvgPool2d(output_size=(1, 1))) layer_name_mapping = {3: 'pool_1', 6: 'pool_2', 14: 'mixed_6e', 18: 'pool_3'} return _PerceptualNetwork(network, layer_name_mapping, layers) def _resnet50(layers): r"""Get resnet50 layers""" resnet50 = torchvision.models.resnet50(pretrained=True) network = nn.Sequential(resnet50.conv1, resnet50.bn1, resnet50.relu, resnet50.maxpool, resnet50.layer1, resnet50.layer2, resnet50.layer3, resnet50.layer4, resnet50.avgpool) layer_name_mapping = {4: 'layer_1', 5: 'layer_2', 6: 'layer_3', 7: 'layer_4'} return _PerceptualNetwork(network, layer_name_mapping, layers) def _robust_resnet50(layers): r"""Get robust resnet50 layers""" resnet50 = torchvision.models.resnet50(pretrained=False) state_dict = torch.utils.model_zoo.load_url( 'https://andrewilyas.com/ImageNet.pt') new_state_dict = {} for k, v in state_dict['model'].items(): if k.startswith('module.model.'): new_state_dict[k[13:]] = v resnet50.load_state_dict(new_state_dict) network = nn.Sequential(resnet50.conv1, resnet50.bn1, resnet50.relu, resnet50.maxpool, resnet50.layer1, resnet50.layer2, resnet50.layer3, resnet50.layer4, resnet50.avgpool) layer_name_mapping 
= {4: 'layer_1', 5: 'layer_2', 6: 'layer_3', 7: 'layer_4'} return _PerceptualNetwork(network, layer_name_mapping, layers) def _vgg_face_dag(layers): network = torchvision.models.vgg16(num_classes=2622) state_dict = torch.utils.model_zoo.load_url( 'https://www.robots.ox.ac.uk/~albanie/models/pytorch-mcn/' 'vgg_face_dag.pth') feature_layer_name_mapping = { 0: 'conv1_1', 2: 'conv1_2', 5: 'conv2_1', 7: 'conv2_2', 10: 'conv3_1', 12: 'conv3_2', 14: 'conv3_3', 17: 'conv4_1', 19: 'conv4_2', 21: 'conv4_3', 24: 'conv5_1', 26: 'conv5_2', 28: 'conv5_3'} new_state_dict = {} for k, v in feature_layer_name_mapping.items(): new_state_dict['features.' + str(k) + '.weight'] = \ state_dict[v + '.weight'] new_state_dict['features.' + str(k) + '.bias'] = \ state_dict[v + '.bias'] classifier_layer_name_mapping = { 0: 'fc6', 3: 'fc7', 6: 'fc8'} for k, v in classifier_layer_name_mapping.items(): new_state_dict['classifier.' + str(k) + '.weight'] = \ state_dict[v + '.weight'] new_state_dict['classifier.' + str(k) + '.bias'] = \ state_dict[v + '.bias'] network.load_state_dict(new_state_dict) class Flatten(nn.Module): def forward(self, x): return x.view(x.shape[0], -1) layer_name_mapping = { 0: 'conv_1_1', 1: 'relu_1_1', 2: 'conv_1_2', 5: 'conv_2_1', # 1/2 6: 'relu_2_1', 7: 'conv_2_2', 10: 'conv_3_1', # 1/4 11: 'relu_3_1', 12: 'conv_3_2', 14: 'conv_3_3', 17: 'conv_4_1', # 1/8 18: 'relu_4_1', 19: 'conv_4_2', 21: 'conv_4_3', 24: 'conv_5_1', # 1/16 25: 'relu_5_1', 26: 'conv_5_2', 28: 'conv_5_3', 33: 'fc6', 36: 'fc7', 39: 'fc8' } seq_layers = [] for feature in network.features: seq_layers += [feature] seq_layers += [network.avgpool, Flatten()] for classifier in network.classifier: seq_layers += [classifier] network = nn.Sequential(*seq_layers) return _PerceptualNetwork(network, layer_name_mapping, layers)