Source code for imaginaire.trainers.unit

# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch
from torch import nn

from imaginaire.evaluation import compute_fid
from imaginaire.losses import GANLoss, PerceptualLoss  # GaussianKLLoss
from imaginaire.trainers.base import BaseTrainer


[docs]class Trainer(BaseTrainer): r"""Reimplementation of the UNIT (https://arxiv.org/abs/1703.00848) algorithm. Args: cfg (obj): Global configuration. net_G (obj): Generator network. net_D (obj): Discriminator network. opt_G (obj): Optimizer for the generator network. opt_D (obj): Optimizer for the discriminator network. sch_G (obj): Scheduler for the generator optimizer. sch_D (obj): Scheduler for the discriminator optimizer. train_data_loader (obj): Train data loader. val_data_loader (obj): Validation data loader. """ def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, train_data_loader, val_data_loader): super().__init__(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, train_data_loader, val_data_loader) self.best_fid_a = None self.best_fid_b = None def _init_loss(self, cfg): r"""Initialize loss terms. In UNIT, we have several loss terms including the GAN loss, the image reconstruction loss, the cycle reconstruction loss, and the gaussian kl loss. We also have an optional perceptual loss. A user can choose to have the gradient penalty loss too. Args: cfg (obj): Global configuration. """ self.criteria['gan'] = GANLoss(cfg.trainer.gan_mode) # self.criteria['gaussian_kl'] = GaussianKLLoss() self.criteria['image_recon'] = nn.L1Loss() self.criteria['cycle_recon'] = nn.L1Loss() if getattr(cfg.trainer.loss_weight, 'perceptual', 0) > 0: self.criteria['perceptual'] = \ PerceptualLoss(network=cfg.trainer.perceptual_mode, layers=cfg.trainer.perceptual_layers) for loss_name, loss_weight in cfg.trainer.loss_weight.__dict__.items(): if loss_weight > 0: self.weights[loss_name] = loss_weight
[docs] def gen_forward(self, data): r"""Compute the loss for UNIT generator. Args: data (dict): Training data at the current iteration. """ cycle_recon = 'cycle_recon' in self.weights perceptual = 'perceptual' in self.weights net_G_output = self.net_G(data, cycle_recon=cycle_recon) net_D_output = self.net_D(data, net_G_output, real=False) self._time_before_loss() # GAN loss self.gen_losses['gan_a'] = self.criteria['gan']( net_D_output['out_ba'], True, dis_update=False) self.gen_losses['gan_b'] = self.criteria['gan']( net_D_output['out_ab'], True, dis_update=False) self.gen_losses['gan'] = \ self.gen_losses['gan_a'] + self.gen_losses['gan_b'] # Perceptual loss if perceptual: self.gen_losses['perceptual_a'] = \ self.criteria['perceptual'](net_G_output['images_ab'], data['images_a']) self.gen_losses['perceptual_b'] = \ self.criteria['perceptual'](net_G_output['images_ba'], data['images_b']) self.gen_losses['perceptual'] = \ self.gen_losses['perceptual_a'] + \ self.gen_losses['perceptual_b'] # Image reconstruction loss self.gen_losses['image_recon'] = \ self.criteria['image_recon'](net_G_output['images_aa'], data['images_a']) + \ self.criteria['image_recon'](net_G_output['images_bb'], data['images_b']) """ # KL loss self.gen_losses['gaussian_kl'] = \ self.criteria['gaussian_kl'](net_G_output['content_mu_a']) + \ self.criteria['gaussian_kl'](net_G_output['content_mu_b']) + \ self.criteria['gaussian_kl'](net_G_output['content_mu_a_recon']) + \ self.criteria['gaussian_kl'](net_G_output['content_mu_b_recon']) """ # Cycle reconstruction loss if cycle_recon: self.gen_losses['cycle_recon_aba'] = \ self.criteria['cycle_recon'](net_G_output['images_aba'], data['images_a']) self.gen_losses['cycle_recon_bab'] = \ self.criteria['cycle_recon'](net_G_output['images_bab'], data['images_b']) self.gen_losses['cycle_recon'] = \ self.gen_losses['cycle_recon_aba'] + \ self.gen_losses['cycle_recon_bab'] # Compute total loss total_loss = self._get_total_loss(gen_forward=True) return total_loss
[docs] def dis_forward(self, data): r"""Compute the loss for UNIT discriminator. Args: data (dict): Training data at the current iteration. """ with torch.no_grad(): net_G_output = self.net_G(data, image_recon=False, cycle_recon=False) net_G_output['images_ba'].requires_grad = True net_G_output['images_ab'].requires_grad = True net_D_output = self.net_D(data, net_G_output) self._time_before_loss() # GAN loss. self.dis_losses['gan_a'] = \ self.criteria['gan'](net_D_output['out_a'], True) + \ self.criteria['gan'](net_D_output['out_ba'], False) self.dis_losses['gan_b'] = \ self.criteria['gan'](net_D_output['out_b'], True) + \ self.criteria['gan'](net_D_output['out_ab'], False) self.dis_losses['gan'] = \ self.dis_losses['gan_a'] + self.dis_losses['gan_b'] # Compute total loss total_loss = self._get_total_loss(gen_forward=False) return total_loss
def _get_visualizations(self, data): r"""Compute visualization image. Args: data (dict): The current batch. """ if self.cfg.trainer.model_average_config.enabled: net_G_for_evaluation = self.net_G.module.averaged_model else: net_G_for_evaluation = self.net_G with torch.no_grad(): net_G_output = net_G_for_evaluation(data) vis_images = [data['images_a'], data['images_b'], net_G_output['images_aa'], net_G_output['images_bb'], net_G_output['images_ab'], net_G_output['images_ba'], net_G_output['images_aba'], net_G_output['images_bab']] return vis_images
[docs] def write_metrics(self): r"""Compute metrics and save them to tensorboard""" cur_fid_a, cur_fid_b = self._compute_fid() if self.best_fid_a is not None: self.best_fid_a = min(self.best_fid_a, cur_fid_a) else: self.best_fid_a = cur_fid_a if self.best_fid_b is not None: self.best_fid_b = min(self.best_fid_b, cur_fid_b) else: self.best_fid_b = cur_fid_b self._write_to_meters({'FID_a': cur_fid_a, 'best_FID_a': self.best_fid_a, 'FID_b': cur_fid_b, 'best_FID_b': self.best_fid_b}, self.metric_meters) self._flush_meters(self.metric_meters)
def _compute_fid(self): r"""Compute FID for both domains. """ self.net_G.eval() if self.cfg.trainer.model_average_config.enabled: net_G_for_evaluation = self.net_G.module.averaged_model else: net_G_for_evaluation = self.net_G fid_a_path = self._get_save_path('fid_a', 'npy') fid_b_path = self._get_save_path('fid_b', 'npy') fid_value_a = compute_fid(fid_a_path, self.val_data_loader, net_G_for_evaluation, 'images_a', 'images_ba') fid_value_b = compute_fid(fid_b_path, self.val_data_loader, net_G_for_evaluation, 'images_b', 'images_ab') print('Epoch {:05}, Iteration {:09}, FID a {}, FID b {}'.format( self.current_epoch, self.current_iteration, fid_value_a, fid_value_b)) return fid_value_a, fid_value_b