Source code for imaginaire.utils.trainer

# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import importlib
import random
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.optim import SGD, Adam, RMSprop, lr_scheduler

from imaginaire.optimizers import Fromage, Madam
from imaginaire.utils.distributed import get_rank, get_world_size
from imaginaire.utils.distributed import master_only_print as print
from imaginaire.utils.init_weight import weights_init, weights_rescale
from imaginaire.utils.model_average import ModelAverage


def set_random_seed(seed, by_rank=False):
    r"""Seed every random number generator used during training.

    The Python ``random`` module, NumPy, the PyTorch CPU generator and the
    PyTorch CUDA generators (current device and all devices) are seeded
    with the same value.

    Args:
        seed (int): Base random seed.
        by_rank (bool): If ``True``, the rank returned by ``get_rank()`` is
            added to ``seed`` so that each process uses a different seed.
    """
    effective_seed = seed + get_rank() if by_rank else seed
    print(f"Using random seed {effective_seed}")
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(effective_seed)
def get_trainer(cfg, net_G, net_D=None,
                opt_G=None, opt_D=None,
                sch_G=None, sch_D=None,
                train_data_loader=None, val_data_loader=None):
    """Instantiate the trainer class named by the config.

    The module path in ``cfg.trainer.type`` is imported and its ``Trainer``
    class is constructed with all of the arguments below, positionally.

    Args:
        cfg (Config): Loaded config object.
        net_G (obj): Generator network object.
        net_D (obj): Discriminator network object.
        opt_G (obj): Generator optimizer object.
        opt_D (obj): Discriminator optimizer object.
        sch_G (obj): Generator optimizer scheduler object.
        sch_D (obj): Discriminator optimizer scheduler object.
        train_data_loader (obj): Train data loader.
        val_data_loader (obj): Validation data loader.

    Returns:
        (obj): Trainer object.
    """
    trainer_module = importlib.import_module(cfg.trainer.type)
    trainer_cls = trainer_module.Trainer
    return trainer_cls(cfg, net_G, net_D,
                       opt_G, opt_D,
                       sch_G, sch_D,
                       train_data_loader, val_data_loader)
def get_model_optimizer_and_scheduler(cfg, seed=0):
    r"""Return the networks, the optimizers, and the schedulers.

    The random seed is first set to the same value on every process
    (``by_rank=False``) so that each GPU copy of the networks is initialized
    with exactly the same weights. After the networks are moved to the GPU,
    the seed is set again with ``by_rank=True`` so that each process uses a
    different seed from then on. The optimizers are created on the
    unwrapped networks. The networks are then wrapped by
    ``wrap_model_and_optimizer`` (optional model average for the generator,
    then distributed data parallel when applicable). Finally, the schedulers
    are built.

    Args:
        cfg (obj): Global configuration.
        seed (int): Random seed.

    Returns:
        (tuple): ``(net_G, net_D, opt_G, opt_D, sch_G, sch_D)`` where
            - net_G (obj): Generator network object.
            - net_D (obj): Discriminator network object.
            - opt_G (obj): Generator optimizer object.
            - opt_D (obj): Discriminator optimizer object.
            - sch_G (obj): Generator optimizer scheduler object.
            - sch_D (obj): Discriminator optimizer scheduler object.
    """
    # Use the same seed on every process so that every copy of the networks
    # is initialized with exactly the same weights and other parameters.
    # The per-process seed is applied further below.
    set_random_seed(seed, by_rank=False)
    # Construct networks from the module paths given in the config.
    lib_G = importlib.import_module(cfg.gen.type)
    lib_D = importlib.import_module(cfg.dis.type)
    net_G = lib_G.Generator(cfg.gen, cfg.data)
    net_D = lib_D.Discriminator(cfg.dis, cfg.data)
    print('Initialize net_G and net_D weights using '
          'type: {} gain: {}'.format(cfg.trainer.init.type,
                                     cfg.trainer.init.gain))
    # The bias init value is optional in the config (None when absent).
    init_bias = getattr(cfg.trainer.init, 'bias', None)
    net_G.apply(weights_init(
        cfg.trainer.init.type, cfg.trainer.init.gain, init_bias))
    net_D.apply(weights_init(
        cfg.trainer.init.type, cfg.trainer.init.gain, init_bias))
    net_G.apply(weights_rescale())
    net_D.apply(weights_rescale())
    # NOTE(review): disabled alternative init, kept for reference:
    # for name, p in net_G.named_parameters():
    #     if 'modulation' in name and 'bias' in name:
    #         nn.init.constant_(p.data, 1.)
    # A CUDA device is required: both networks are moved to 'cuda'.
    net_G = net_G.to('cuda')
    net_D = net_D.to('cuda')
    # From here on, each process uses seed + rank (see set_random_seed with
    # by_rank=True), so e.g. sampled noise differs across GPU copies of the
    # same model.
    set_random_seed(seed, by_rank=True)
    print('net_G parameter count: {:,}'.format(_calculate_model_size(net_G)))
    print('net_D parameter count: {:,}'.format(_calculate_model_size(net_D)))

    # Optimizers are built on the networks before they are wrapped.
    opt_G = get_optimizer(cfg.gen_opt, net_G)
    opt_D = get_optimizer(cfg.dis_opt, net_D)

    net_G, net_D, opt_G, opt_D = \
        wrap_model_and_optimizer(cfg, net_G, net_D, opt_G, opt_D)

    # Scheduler
    sch_G = get_scheduler(cfg.gen_opt, opt_G)
    sch_D = get_scheduler(cfg.dis_opt, opt_D)

    return net_G, net_D, opt_G, opt_D, sch_G, sch_D
def wrap_model_and_optimizer(cfg, net_G, net_D, opt_G, opt_D):
    r"""Wrap the networks with (optional) model average and DDP.

    If ``cfg.trainer.model_average_config.enabled`` is set, the generator is
    wrapped in ``ModelAverage``. When that config also has ``g_smooth_img``,
    its ``beta`` is recomputed from it and written back into ``cfg`` (side
    effect on the config object). The generator's ``custom_init()`` hook is
    then called if the (unwrapped) generator defines one. Both networks are
    finally wrapped by ``_wrap_model``. The optimizers are returned
    unchanged.

    Args:
        cfg (obj): Global configuration.
        net_G (obj): Generator network object.
        net_D (obj): Discriminator network object.
        opt_G (obj): Generator optimizer object.
        opt_D (obj): Discriminator optimizer object.

    Returns:
        (tuple): ``(net_G, net_D, opt_G, opt_D)`` where
            - net_G (obj): Wrapped generator network object.
            - net_D (obj): Wrapped discriminator network object.
            - opt_G (obj): Generator optimizer object (unchanged).
            - opt_D (obj): Discriminator optimizer object (unchanged).
    """
    ma_cfg = cfg.trainer.model_average_config
    if ma_cfg.enabled:
        if hasattr(ma_cfg, 'g_smooth_img'):
            # g_smooth_img is the half-life of the running average of the
            # generator weights, measured in training images across all
            # processes. Each step consumes batch_size * world_size images,
            # so the per-step decay is 0.5 ** (images_per_step / half_life).
            ma_cfg.beta = \
                0.5 ** (cfg.data.train.batch_size *
                        get_world_size() / ma_cfg.g_smooth_img)
            print(f"EMA Decay Factor: {cfg.trainer.model_average_config.beta}")
        net_G = ModelAverage(net_G, ma_cfg.beta,
                             ma_cfg.start_iteration,
                             ma_cfg.remove_sn)
        # The model-average wrapper exposes the raw generator as `.module`.
        net_G_module = net_G.module
    else:
        net_G_module = net_G

    # Optional network-specific initialization hook.
    if hasattr(net_G_module, 'custom_init'):
        net_G_module.custom_init()

    net_G = _wrap_model(cfg, net_G)
    net_D = _wrap_model(cfg, net_D)
    return net_G, net_D, opt_G, opt_D
def _calculate_model_size(model): r"""Calculate number of parameters in a PyTorch network. Args: model (obj): PyTorch network. Returns: (int): Number of parameters. """ return sum(p.numel() for p in model.parameters() if p.requires_grad)
class WrappedModel(nn.Module):
    r"""Dummy wrapper that stores a network under the ``module`` attribute.

    It is used as the non-distributed alternative to
    ``DistributedDataParallel``, mirroring DDP's ``.module`` attribute. This
    means wrapped networks are accessed, and their state-dict keys prefixed
    (``module.``), the same way in both cases.
    """

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        r"""Forward all arguments to the wrapped module and return its output."""
        wrapped = self.module
        return wrapped(*args, **kwargs)
def _wrap_model(cfg, model):
    r"""Wrap a model for (optionally distributed) data parallel training.

    When ``torch.distributed`` is available and a process group is
    initialized, the model is wrapped in PyTorch's
    ``DistributedDataParallel`` on device ``cfg.local_rank`` with buffer
    broadcasting disabled. Otherwise it is wrapped in ``WrappedModel`` so
    callers can always reach the raw network through ``.module``.

    Args:
        cfg (obj): Global configuration.
        model (obj): PyTorch network model.

    Returns:
        (obj): Wrapped PyTorch network model.
    """
    distributed = torch.distributed.is_available() and dist.is_initialized()
    if not distributed:
        return WrappedModel(model)

    ddp_params = cfg.trainer.distributed_data_parallel_params
    find_unused = ddp_params.find_unused_parameters
    local_rank = cfg.local_rank
    return torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=find_unused,
        broadcast_buffers=False,
    )
def get_scheduler(cfg_opt, opt):
    """Return the learning-rate scheduler object.

    Supported ``cfg_opt.lr_policy.type`` values:

    - ``'step'``: ``StepLR`` with ``step_size`` and ``gamma``.
    - ``'constant'``: learning-rate multiplier is always 1.
    - ``'linear'``: multiplier is 1 until ``decay_start``, then decays
      linearly to ``decay_target`` at ``decay_end`` and stays there.

    Args:
        cfg_opt (obj): Config for the specific optimization module (gen/dis).
        opt (obj): PyTorch optimizer object.

    Returns:
        (obj): Scheduler

    Raises:
        NotImplementedError: If the learning-rate policy type is unknown.
    """
    policy = cfg_opt.lr_policy
    if policy.type == 'step':
        return lr_scheduler.StepLR(
            opt,
            step_size=policy.step_size,
            gamma=policy.gamma)
    if policy.type == 'constant':
        return lr_scheduler.LambdaLR(opt, lambda x: 1)
    if policy.type == 'linear':
        # Start linear decay from here.
        decay_start = policy.decay_start
        # End linear decay here.
        # Continue to train using the lowest learning rate till the end.
        decay_end = policy.decay_end
        # Lowest learning rate multiplier.
        decay_target = policy.decay_target

        def sch(x):
            # Line through (decay_start, 1) and (decay_end, decay_target),
            # clamped to [decay_target, 1] outside that interval.
            return min(
                max(((x - decay_start) * decay_target + decay_end - x) / (
                    decay_end - decay_start
                ), decay_target),
                1.
            )
        return lr_scheduler.LambdaLR(opt, sch)
    # Previously this exception was *returned*, so callers silently got an
    # exception instance instead of a scheduler.
    raise NotImplementedError('Learning rate policy {} not implemented.'.
                              format(policy.type))
def get_optimizer(cfg_opt, net):
    r"""Return the optimizer for a network.

    If the network defines ``get_param_groups``, it is called with
    ``cfg_opt`` to obtain parameter groups. This lets the network use
    different hyper-parameters (e.g. learning rates) for different
    parameters. Otherwise all of ``net.parameters()`` are optimized with the
    same settings.

    Args:
        cfg_opt (obj): Config for the specific optimization module (gen/dis).
        net (obj): PyTorch network object.

    Returns:
        (obj): Pytorch optimizer
    """
    if hasattr(net, 'get_param_groups'):
        # Allow the network to use different hyper-parameters (e.g., learning
        # rate) for different parameters.
        params = net.get_param_groups(cfg_opt)
    else:
        params = net.parameters()
    return get_optimizer_for_params(cfg_opt, params)
def _import_apex_optimizer(name):
    r"""Return ``apex.optimizers.<name>``, or ``None`` if it is unavailable."""
    try:
        from apex import optimizers as apex_optimizers
    except Exception:  # noqa
        # apex is optional. Importing it can fail with errors other than
        # ImportError (e.g. a compiled-extension mismatch), so stay
        # best-effort, but do not swallow KeyboardInterrupt/SystemExit the
        # way a bare `except:` would.
        return None
    return getattr(apex_optimizers, name, None)


def get_optimizer_for_params(cfg_opt, params):
    r"""Return the optimizer for the given parameters.

    Supported ``cfg_opt.type`` values: ``'adam'``, ``'madam'``,
    ``'fromage'``, ``'rmsprop'`` and ``'sgd'``. For ``'adam'`` and ``'sgd'``,
    apex's fused implementation is used when ``cfg_opt.fused_opt`` is true
    and apex provides it. Otherwise the PyTorch optimizer is used.

    Args:
        cfg_opt (obj): Config for the specific optimization module (gen/dis).
        params (obj): Parameters (or parameter groups) to be optimized.

    Returns:
        (obj): Optimizer

    Raises:
        NotImplementedError: If ``cfg_opt.type`` is not supported.
    """
    # Fused optimizers are requested via the config; availability of apex is
    # checked lazily, only for the optimizer types that have a fused variant.
    fused_opt = cfg_opt.fused_opt
    if cfg_opt.type == 'adam':
        fused_cls = _import_apex_optimizer('FusedAdam') if fused_opt else None
        adam_cls = Adam if fused_cls is None else fused_cls
        opt = adam_cls(params,
                       lr=cfg_opt.lr, eps=cfg_opt.eps,
                       betas=(cfg_opt.adam_beta1, cfg_opt.adam_beta2))
    elif cfg_opt.type == 'madam':
        # g_bound is optional in the config.
        g_bound = getattr(cfg_opt, 'g_bound', None)
        opt = Madam(params, lr=cfg_opt.lr,
                    scale=cfg_opt.scale, g_bound=g_bound)
    elif cfg_opt.type == 'fromage':
        opt = Fromage(params, lr=cfg_opt.lr)
    elif cfg_opt.type == 'rmsprop':
        opt = RMSprop(params, lr=cfg_opt.lr,
                      eps=cfg_opt.eps, weight_decay=cfg_opt.weight_decay)
    elif cfg_opt.type == 'sgd':
        # Fall back to plain SGD instead of crashing when FusedSGD is missing.
        fused_cls = _import_apex_optimizer('FusedSGD') if fused_opt else None
        sgd_cls = SGD if fused_cls is None else fused_cls
        opt = sgd_cls(params,
                      lr=cfg_opt.lr,
                      momentum=cfg_opt.momentum,
                      weight_decay=cfg_opt.weight_decay)
    else:
        raise NotImplementedError(
            'Optimizer {} is not yet implemented.'.format(cfg_opt.type))
    return opt