# Source code for imaginaire.evaluation.lpips

# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
from collections import namedtuple

import torch
from torch import nn, distributed as dist
import torchvision.models as tv
from torch.distributed import barrier

from imaginaire.utils.distributed import is_local_master


def get_lpips_model():
    r"""Build the LPIPS network on the current CUDA device.

    In distributed runs, the local master constructs the model first, which
    downloads the pretrained weights. The other processes wait on a barrier
    and then build from the cache. Each process calls ``barrier()`` exactly
    once, so the two calls below pair up.

    Returns:
        LPIPSNet: The network, moved to CUDA with ``.cuda()``.
    """
    in_distributed_run = dist.is_initialized()

    if in_distributed_run and not is_local_master():
        # Non-master ranks: wait until the master has populated the cache.
        barrier()

    model = LPIPSNet().cuda()

    if in_distributed_run and is_local_master():
        # Master rank: download finished, release the waiting ranks.
        barrier()

    return model
# Learned perceptual network, modified from https://github.com/richzhang/PerceptualSimilarity
def normalize_tensor(in_feat, eps=1e-5):
    r"""Scale each feature vector along dim 1 to (approximately) unit L2 norm.

    Args:
        in_feat (Tensor): Features; the norm is taken over dim 1 (channels).
        eps (float): Stabiliser.

    Returns:
        Tensor: Same shape as ``in_feat``.
    """
    # NOTE(review): eps is applied twice, once inside the square root and
    # once on the denominator. This is kept unchanged so that LPIPS scores
    # stay comparable with previously reported numbers.
    squared_sum = in_feat.pow(2).sum(dim=1, keepdim=True)
    magnitude = (squared_sum + eps).sqrt()
    return in_feat / (magnitude + eps)
class NetLinLayer(nn.Module):
    r"""Placeholder holding one layer of the learnt LPIPS weights.

    The "linear layer" is a per-channel multiplier of shape
    ``(1, dim, 1, 1)``, broadcast against the input. It is initialised to
    zeros and is expected to be filled in by the owner after construction.

    Args:
        dim (int): Number of channels.
    """

    def __init__(self, dim):
        super(NetLinLayer, self).__init__()
        initial = torch.zeros(1, dim, 1, 1)
        self.weight = nn.Parameter(initial)

    def forward(self, inp):
        r"""Multiply ``inp`` channel-wise by the stored weights."""
        scaled = self.weight * inp
        return scaled
class ScalingLayer(nn.Module):
    r"""Fixed per-channel affine rescaling applied to VGG-16 inputs.

    Computes ``(inp - shift) / scale``, with three-channel ``shift`` and
    ``scale`` buffers shaped ``(1, 3, 1, 1)`` for broadcasting.
    """

    def __init__(self):
        super(ScalingLayer, self).__init__()
        shift = torch.Tensor([-.030, -.088, -.188]).reshape(1, -1, 1, 1)
        scale = torch.Tensor([.458, .448, .450]).reshape(1, -1, 1, 1)
        # Buffers: moved with .cuda()/.to() and saved in state_dict, not trained.
        self.register_buffer('shift', shift)
        self.register_buffer('scale', scale)

    def forward(self, inp):
        r"""Return the shifted-and-scaled input."""
        centered = inp - self.shift
        return centered / self.scale
class LPIPSNet(nn.Module):
    r"""LPIPS perceptual distance between two batches of images."""

    def __init__(self):
        super().__init__()
        self.model = LPNet()

    @torch.no_grad()
    def forward(self, fake_images, fake_images_another, align_corners=True):
        r"""Compute the LPIPS distance for each pair of images.

        Args:
            fake_images (Tensor): First batch of images. Values are clamped
                to [0, 1] before being fed to the network.
            fake_images_another (Tensor): Second batch, compared pairwise
                with ``fake_images``.
            align_corners (bool): Unused. Kept only for API compatibility.

        Returns:
            Tensor: One distance per sample. The shape is presumably
            ``(batch,)``, since each layer's features arrive flattened as
            ``(batch, C*H*W)`` and are summed over dim 1.
        """
        features, _ = self._forward_single(fake_images)
        features_another, _ = self._forward_single(fake_images_another)
        result = 0
        for lin, g_feat, g_feat_another in zip(self.model.lins, features, features_another):
            # g_feat is (batch, C*H*W). The LPIPS score for a layer is the
            # spatial mean of the channel-summed squared difference, so divide
            # by H*W. The old code divided by width**2, which is wrong
            # whenever H != W. C comes from the layer's per-channel weight
            # (shape (1, C, 1, 1)).
            num_channels = lin.weight.shape[1]
            num_pixels = g_feat.shape[1] // num_channels
            cur_diff = torch.sum((g_feat - g_feat_another) ** 2, dim=1) / num_pixels
            result += cur_diff
        return result

    def _forward_single(self, images):
        r"""Run the LPIPS trunk on ``images`` clamped to [0, 1]."""
        return self.model(torch.clamp(images, 0, 1))
class LPNet(nn.Module):
    r"""VGG-16 feature trunk combined with the learnt LPIPS layer weights.

    On construction, downloads a weight file from the URL below and stores the
    element-wise square root of each ``lin%d.model.1.weight`` entry in
    ``self.lins``. The square root matters because the caller squares the
    difference of the weighted features:
    ``(sqrt(w) * a - sqrt(w) * b) ** 2 == w * (a - b) ** 2``.
    """

    def __init__(self):
        super(LPNet, self).__init__()
        self.scaling_layer = ScalingLayer()
        self.net = vgg16(pretrained=True, requires_grad=False)
        self.L = 5
        dims = [64, 128, 256, 512, 512]
        self.lins = nn.ModuleList([NetLinLayer(channels) for channels in dims])
        weights = torch.hub.load_state_dict_from_url(
            'https://github.com/niopeng/CAM-Net/raw/main/code/models/weights/v0.1/vgg.pth'
        )
        for idx, lin in enumerate(self.lins):
            lin.weight.data = torch.sqrt(weights["lin%d.model.1.weight" % idx])

    def forward(self, in0, avg=False):
        r"""Extract normalised, weighted features from the five VGG slices.

        Args:
            in0 (Tensor): Images. Assumed to be in [0, 1] (the caller clamps
                them); they are remapped to [-1, 1] before scaling.
            avg (bool): If True, average each layer's output over dims 2 and 3.

        Returns:
            tuple: ``(features, widths)``.
                If ``avg`` is False, ``features`` holds one ``(batch, -1)``
                flattened tensor per layer and ``widths`` holds each layer's
                last-dim size.
                If ``avg`` is True, ``features`` holds the averaged outputs
                and ``widths`` is an empty list.
        """
        scaled = self.scaling_layer(2 * in0 - 1)
        vgg_feats = self.net.forward(scaled)
        normalized = [normalize_tensor(vgg_feats[k]) for k in range(self.L)]

        if avg:
            pooled = [self.lins[k](normalized[k]).mean([2, 3], keepdim=False)
                      for k in range(self.L)]
            return pooled, []

        flat_feats, widths = [], []
        for lin, feat in zip(self.lins, normalized):
            weighted = lin(feat)
            widths.append(weighted.shape[-1])
            flat_feats.append(weighted.reshape(weighted.shape[0], -1))
        return flat_feats, widths
class vgg16(torch.nn.Module):
    r"""Torchvision VGG-16 ``features`` stack cut into five sequential slices.

    The layer index ranges [0, 4), [4, 9), [9, 16), [16, 23) and [23, 30)
    are copied into ``slice1`` .. ``slice5``. Each layer keeps its original
    torchvision index as its submodule name, so state-dict keys are unchanged.
    The output field names (relu1_2 ... relu5_3) suggest that each slice
    ends at that ReLU activation.

    Args:
        requires_grad (bool): If False, freeze every parameter.
        pretrained (bool): Forwarded to ``torchvision.models.vgg16``.
    """

    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        backbone = tv.vgg16(pretrained=pretrained).features

        def build_slice(start, stop):
            block = torch.nn.Sequential()
            for idx in range(start, stop):
                block.add_module(str(idx), backbone[idx])
            return block

        self.slice1 = build_slice(0, 4)
        self.slice2 = build_slice(4, 9)
        self.slice3 = build_slice(9, 16)
        self.slice4 = build_slice(16, 23)
        self.slice5 = build_slice(23, 30)
        self.N_slices = 5
        if not requires_grad:
            self.requires_grad_(False)

    def forward(self, x):
        r"""Run ``x`` through the slices in order and return every slice's output."""
        activations = []
        h = x
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = stage(h)
            activations.append(h)
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        return vgg_outputs(*activations)