# Source code for imaginaire.losses.dict

# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch.nn as nn
class DictLoss(nn.Module):
    """Sum an L1 or L2 (MSE) criterion over paired entries of two containers.

    ``fake`` and ``real`` may each be a dict, a list/tuple, or a single
    tensor. The criterion is applied entry by entry, and the per-entry values
    are summed. ``real`` is always detached, so no gradient flows back
    through it.

    Args:
        criterion (str): ``'l1'`` selects ``nn.L1Loss``; ``'l2'`` or ``'mse'``
            selects ``nn.MSELoss``. Each per-entry term uses that module's
            default reduction.

    Raises:
        ValueError: If ``criterion`` is not one of the names above.
    """

    def __init__(self, criterion='l1'):
        super().__init__()
        if criterion == 'l1':
            self.criterion = nn.L1Loss()
        elif criterion in ('l2', 'mse'):
            self.criterion = nn.MSELoss()
        else:
            raise ValueError('Criterion %s is not recognized' % criterion)

    def forward(self, fake, real):
        """Return the summed l1/l2 loss between ``fake`` and ``real``.

        Args:
            fake (dict, list, tuple or tensor): Discriminator features of
                fake images.
            real (dict, list, tuple or tensor): Discriminator features of
                real images, with the same layout as ``fake``.

        Returns:
            loss (tensor): Sum of the criterion over all paired entries.
                This is the integer ``0`` if ``fake`` is an empty container.

        Notes:
            - Dicts are paired by the keys of ``fake``. A key missing from
              ``real`` raises ``KeyError``; extra keys in ``real`` are
              ignored.
            - Lists and tuples are paired positionally with ``zip``, so any
              trailing elements of the longer sequence are ignored.
            - Containers are not traversed recursively. Each entry must be
              something the criterion accepts, and each ``real`` entry must
              support ``.detach()``.
        """
        loss = 0
        # isinstance (rather than an exact type() check) also accepts
        # subclasses such as OrderedDict, defaultdict and namedtuples.
        # Previously these fell through to the single-tensor branch and
        # failed on .detach().
        if isinstance(fake, dict):
            for key in fake:
                loss += self.criterion(fake[key], real[key].detach())
        elif isinstance(fake, (list, tuple)):
            for fake_item, real_item in zip(fake, real):
                loss += self.criterion(fake_item, real_item.detach())
        else:
            loss += self.criterion(fake, real.detach())
        return loss