imaginaire.trainers package

Submodules

imaginaire.trainers.base module

class imaginaire.trainers.base.BaseTrainer(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, train_data_loader, val_data_loader)

Bases: object

Base trainer. All trainers are expected to inherit from this class.

Parameters
  • cfg (obj) – Global configuration.

  • net_G (obj) – Generator network.

  • net_D (obj) – Discriminator network.

  • opt_G (obj) – Optimizer for the generator network.

  • opt_D (obj) – Optimizer for the discriminator network.

  • sch_G (obj) – Scheduler for the generator optimizer.

  • sch_D (obj) – Scheduler for the discriminator optimizer.

  • train_data_loader (obj) – Train data loader.

  • val_data_loader (obj) – Validation data loader.

dis_forward(data)

Every trainer should implement its own discriminator forward pass.

dis_update(data)

Update the discriminator.

Parameters

data (dict) – Data used for the current iteration.

end_of_epoch(data, current_epoch, current_iteration)

Things to do after an epoch.

Parameters
  • data (dict) – Data used for the current iteration.

  • current_epoch (int) – Current epoch number.

  • current_iteration (int) – Current iteration number.

end_of_iteration(data, current_epoch, current_iteration)

Things to do after an iteration.

Parameters
  • data (dict) – Data used for the current iteration.

  • current_epoch (int) – Current epoch number.

  • current_iteration (int) – Current iteration number.

gen_forward(data)

Every trainer should implement its own generator forward pass.

gen_update(data)

Update the generator.

Parameters

data (dict) – Data used for the current iteration.

load_checkpoint(cfg, checkpoint_path, resume=None, load_sch=True)

Load network weights, optimizer parameters, and scheduler parameters from a checkpoint.

Parameters
  • cfg (obj) – Global configuration.

  • checkpoint_path (str) – Path to the checkpoint.

  • resume (bool or None) – If not None, determines whether to load the optimizers in addition to the network weights.

pre_process(data)

Custom data pre-processing function. Use this function if your data needs to be pre-processed before it is sent to the generator and discriminator.

Parameters

data (dict) – Data used for the current iteration.

recalculate_batch_norm_statistics(data_loader, averaged=True)

Update the statistics in the moving average model.

Parameters
  • data_loader (torch.utils.data.DataLoader) – Data loader for estimating the statistics.

  • averaged (bool) – If True, recalculate the batch norm statistics of the EMA (moving average) model; if False, those of the regular model.
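Recalculating batch norm statistics amounts to re-estimating each layer's running mean and variance by passing the data loader through the model without updating any weights. The pure-Python sketch below shows only the accumulation step for a single feature channel, assuming an equal-weight cumulative average over batches (the behaviour PyTorch batch norm layers use when their `momentum` is `None`). The helper name `recalc_running_stats` and the list-of-floats batch format are hypothetical, not part of imaginaire.

```python
from typing import Iterable, List, Tuple


def recalc_running_stats(batches: Iterable[List[float]]) -> Tuple[float, float]:
    """Cumulatively average per-batch mean and unbiased variance (hypothetical helper).

    Mirrors a batch-norm layer updated with momentum=None: after k batches the
    running value equals the plain average of the k per-batch statistics.
    """
    running_mean, running_var = 0.0, 1.0  # batch-norm reset values
    num_batches = 0
    for batch in batches:
        n = len(batch)
        if n < 2:
            continue  # unbiased variance needs at least two samples
        mean = sum(batch) / n
        var = sum((x - mean) ** 2 for x in batch) / (n - 1)
        num_batches += 1
        factor = 1.0 / num_batches  # cumulative moving average weight
        running_mean = (1 - factor) * running_mean + factor * mean
        running_var = (1 - factor) * running_var + factor * var
    return running_mean, running_var
```

In PyTorch itself, one common way to get this behaviour is to call `reset_running_stats()` on each batch norm module, set its `momentum` to `None`, put the model in training mode, and run forward passes under `torch.no_grad()`; whether imaginaire does exactly this is not stated here.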

save_checkpoint(current_epoch, current_iteration)

Save network weights, optimizer parameters, and scheduler parameters to a checkpoint.

save_image(path, data)

Compute visualization images and save them to disk.

Parameters
  • path (str) – Location of the file.

  • data (dict) – Data used for the current iteration.

start_of_epoch(current_epoch)

Things to do before an epoch.

Parameters

current_epoch (int) – Current epoch number.

start_of_iteration(data, current_iteration)

Things to do before an iteration.

Parameters
  • data (dict) – Data used for the current iteration.

  • current_iteration (int) – Current iteration number.

test(data_loader, output_dir, inference_args)

Compute result images for a batch of input data and save them in the specified folder.

Parameters
  • data_loader (torch.utils.data.DataLoader) – PyTorch dataloader.

  • output_dir (str) – Target location for saving the output image.

write_metrics()

Write metrics to TensorBoard.
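The BaseTrainer hooks above are meant to be driven by an outer training loop. The sketch below is a self-contained mock of one plausible call order: per epoch, then per batch, a discriminator step followed by a generator step. `MockTrainer`, `run_training`, the exact ordering, and the assumption that `start_of_iteration` returns the (possibly modified) data are all illustrative assumptions, not imaginaire's actual training script.

```python
class MockTrainer:
    """Hypothetical stand-in exposing the BaseTrainer hook names."""

    def __init__(self):
        self.calls = []  # records the order in which hooks run

    def start_of_epoch(self, current_epoch):
        self.calls.append(("start_of_epoch", current_epoch))

    def start_of_iteration(self, data, current_iteration):
        self.calls.append(("start_of_iteration", current_iteration))
        return data  # assumed: the hook may return (possibly modified) data

    def dis_update(self, data):
        self.calls.append(("dis_update",))

    def gen_update(self, data):
        self.calls.append(("gen_update",))

    def end_of_iteration(self, data, current_epoch, current_iteration):
        self.calls.append(("end_of_iteration", current_iteration))

    def end_of_epoch(self, data, current_epoch, current_iteration):
        self.calls.append(("end_of_epoch", current_epoch))


def run_training(trainer, data_loader, num_epochs):
    """Drive the hooks: one D step then one G step per batch (assumed order)."""
    current_iteration = 0
    data = None
    for current_epoch in range(num_epochs):
        trainer.start_of_epoch(current_epoch)
        for data in data_loader:
            data = trainer.start_of_iteration(data, current_iteration)
            trainer.dis_update(data)
            trainer.gen_update(data)
            current_iteration += 1
            trainer.end_of_iteration(data, current_epoch, current_iteration)
        trainer.end_of_epoch(data, current_epoch, current_iteration)
    return current_iteration
```

A concrete trainer would replace the mock bodies with real work, typically by implementing `gen_forward` and `dis_forward` and letting the base update methods call them.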

imaginaire.trainers.fs_vid2vid module

class imaginaire.trainers.fs_vid2vid.Trainer(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, train_data_loader, val_data_loader)

Bases: imaginaire.trainers.vid2vid.Trainer

Initialize few-shot vid2vid trainer.

Parameters
  • cfg (obj) – Global configuration.

  • net_G (obj) – Generator network.

  • net_D (obj) – Discriminator network.

  • opt_G (obj) – Optimizer for the generator network.

  • opt_D (obj) – Optimizer for the discriminator network.

  • sch_G (obj) – Scheduler for the generator optimizer.

  • sch_D (obj) – Scheduler for the discriminator optimizer.

  • train_data_loader (obj) – Train data loader.

  • val_data_loader (obj) – Validation data loader.

finetune(data, inference_args)

Fine-tune the model for a few iterations on the inference data.

get_data_t(data, net_G_output, data_prev, t)

Get the data for the current time frame from the sequence of data.

Parameters
  • data (dict) – Training data for the current iteration.

  • net_G_output (dict) – Output of the generator (for the previous frame).

  • data_prev (dict) – Data for the previous frame.

  • t (int) – Current time step.

get_test_output_images(data)

Get the visualization output of the test function.

Parameters

data (dict) – Training data at the current iteration.

post_process(data, net_G_output)

Do any post-processing of the data and the generator output here.

Parameters
  • data (dict) – Training data at the current iteration.

  • net_G_output (dict) – Output of the generator.

pre_process(data)

Do any data pre-processing here.

Parameters

data (dict) – Data used for the current iteration.

save_image(path, data)

Save the output images to path. Note that when generate_raw_output is False, first_net_G_output['fake_raw_images'] is None and will not be displayed. In model average mode, the flow visualization is plotted twice.

Parameters
  • path (str) – Save path.

  • data (dict) – Training data for the current iteration.

test(test_data_loader, root_output_dir, inference_args)

Run inference on the specified sequence.

Parameters
  • test_data_loader (object) – Test data loader.

  • root_output_dir (str) – Location to dump outputs.

  • inference_args (optional) – Optional inference arguments.

imaginaire.trainers.funit module

class imaginaire.trainers.funit.Trainer(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, train_data_loader, val_data_loader)

Bases: imaginaire.trainers.base.BaseTrainer

Reimplementation of the FUNIT (https://arxiv.org/abs/1905.01723) algorithm.

Parameters
  • cfg (obj) – Global configuration.

  • net_G (obj) – Generator network.

  • net_D (obj) – Discriminator network.

  • opt_G (obj) – Optimizer for the generator network.

  • opt_D (obj) – Optimizer for the discriminator network.

  • sch_G (obj) – Scheduler for the generator optimizer.

  • sch_D (obj) – Scheduler for the discriminator optimizer.

  • train_data_loader (obj) – Train data loader.

  • val_data_loader (obj) – Validation data loader.

dis_forward(data)

Compute the loss for FUNIT discriminator.

Parameters

data (dict) – Training data at the current iteration.

gen_forward(data)

Compute the loss for FUNIT generator.

Parameters

data (dict) – Training data at the current iteration.

write_metrics()

Write metrics to TensorBoard.

imaginaire.trainers.gancraft module

class imaginaire.trainers.gancraft.GauGANLoader(gaugan_cfg)

Bases: object

Manages the SPADE/GauGAN model used to generate pseudo-GTs for training GANcraft.

Parameters

gaugan_cfg (Config) – SPADE configuration.

eval(label, z=None, style_img=None)

Produce an output given a segmentation mask and other conditioning inputs. A random style is used if neither z nor style_img is provided.

Parameters
  • label (tensor) – One-hot segmentation mask of shape N x C x H x W.

  • z – Style vector.

  • style_img – Style image.
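The fallback rule in `eval` (a random style when neither `z` nor `style_img` is given) can be sketched as a small selection function. This is a minimal illustration under assumptions: the precedence of `style_img` over `z`, the `encode_style` callable, the standard-normal sampling, and the default style dimension of 256 are hypothetical and not taken from the GauGANLoader source.

```python
import random
from typing import Callable, List, Optional, Sequence


def select_style(
    z: Optional[List[float]] = None,
    style_img: Optional[Sequence[float]] = None,
    encode_style: Optional[Callable[[Sequence[float]], List[float]]] = None,
    style_dim: int = 256,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Pick the style vector for one forward pass (hypothetical helper)."""
    if style_img is not None:
        if encode_style is None:
            raise ValueError("style_img given but no style encoder supplied")
        return encode_style(style_img)  # assumed: an image overrides z
    if z is not None:
        return list(z)
    rng = rng or random.Random()
    # Neither input given: sample a random style from a standard normal.
    return [rng.gauss(0.0, 1.0) for _ in range(style_dim)]
```

Passing a seeded `random.Random` makes the "random style" reproducible, which can help when comparing pseudo-GTs across runs.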

class imaginaire.trainers.gancraft.Trainer(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, train_data_loader, val_data_loader)

Bases: imaginaire.trainers.base.BaseTrainer

Initialize GANcraft trainer.

Parameters
  • cfg (Config) – Global configuration.

  • net_G (obj) – Generator network.

  • net_D (obj) – Discriminator network.

  • opt_G (obj) – Optimizer for the generator network.

  • opt_D (obj) – Optimizer for the discriminator network.

  • sch_G (obj) – Scheduler for the generator optimizer.

  • sch_D (obj) – Scheduler for the discriminator optimizer.

  • train_data_loader (obj) – Train data loader.

  • val_data_loader (obj) – Validation data loader.

dis_forward(data)

Compute the loss for GANcraft discriminator.

Parameters

data (dict) – Training data at the current iteration.

gen_forward(data)

Compute the loss for GANcraft generator.

Parameters

data (dict) – Training data at the current iteration.

load_checkpoint(cfg, checkpoint_path, resume=None, load_sch=True)

Load network weights, optimizer parameters, and scheduler parameters from a checkpoint.

Parameters
  • cfg (obj) – Global configuration.

  • checkpoint_path (str) – Path to the checkpoint.

  • resume (bool or None) – If not None, determines whether to load the optimizers in addition to the network weights.

test(data_loader, output_dir, inference_args)

Compute result images for a batch of input data and save them in the specified folder.

Parameters
  • data_loader (torch.utils.data.DataLoader) – PyTorch dataloader.

  • output_dir (str) – Target location for saving the output image.

imaginaire.trainers.munit module

class imaginaire.trainers.munit.Trainer(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, train_data_loader, val_data_loader)

Bases: imaginaire.trainers.base.BaseTrainer

Reimplementation of the MUNIT (https://arxiv.org/abs/1804.04732) algorithm.

Parameters
  • cfg (obj) – Global configuration.

  • net_G (obj) – Generator network.

  • net_D (obj) – Discriminator network.

  • opt_G (obj) – Optimizer for the generator network.

  • opt_D (obj) – Optimizer for the discriminator network.

  • sch_G (obj) – Scheduler for the generator optimizer.

  • sch_D (obj) – Scheduler for the discriminator optimizer.

  • train_data_loader (obj) – Train data loader.

  • val_data_loader (obj) – Validation data loader.

dis_forward(data)

Compute the loss for MUNIT discriminator.

Parameters

data (dict) – Training data at the current iteration.

gen_forward(data)

Compute the loss for MUNIT generator.

Parameters

data (dict) – Training data at the current iteration.

write_metrics()

Compute metrics and save them to TensorBoard.

imaginaire.trainers.pix2pixHD module

class imaginaire.trainers.pix2pixHD.Trainer(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, train_data_loader, val_data_loader)

Bases: imaginaire.trainers.spade.Trainer

Initialize pix2pixHD trainer.

Parameters
  • cfg (obj) – Global configuration.

  • net_G (obj) – Generator network.

  • net_D (obj) – Discriminator network.

  • opt_G (obj) – Optimizer for the generator network.

  • opt_D (obj) – Optimizer for the discriminator network.

  • sch_G (obj) – Scheduler for the generator optimizer.

  • sch_D (obj) – Scheduler for the discriminator optimizer.

  • train_data_loader (obj) – Train data loader.

  • val_data_loader (obj) – Validation data loader.

dis_forward(data)

Compute the loss for pix2pixHD discriminator.

Parameters

data (dict) – Training data at the current iteration.

gen_forward(data)

Compute the loss for pix2pixHD generator.

Parameters

data (dict) – Training data at the current iteration.

pre_process(data)

Data pre-processing step for the pix2pixHD method. It takes as input a dictionary containing a label field, which is the concatenation of the segmentation mask and the instance map. This function replaces the instance map with an edge map and also adds an 'instance_maps' field to the dictionary.

Parameters
  • data (dict) – Input dictionary.

  • data['label'] – Input label map where the last channel is the instance map.
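The instance-map-to-edge-map step can be sketched as follows. The boundary rule (a pixel is an edge if any 4-neighbour carries a different instance ID, so both sides of an ID change are marked) matches the one used in the original pix2pixHD code, but this is a pure-Python illustration on a 2-D list rather than the trainer's tensor implementation, and the function name `instance_to_edge_map` is hypothetical.

```python
from typing import List


def instance_to_edge_map(instance: List[List[int]]) -> List[List[int]]:
    """Mark pixels whose instance ID differs from a 4-neighbour (1 = edge)."""
    h = len(instance)
    w = len(instance[0]) if h else 0
    edge = [[0] * w for _ in range(h)]
    for y in range(h):
        for x in range(w):
            # Horizontal neighbour: mark both sides of an ID change.
            if x + 1 < w and instance[y][x] != instance[y][x + 1]:
                edge[y][x] = edge[y][x + 1] = 1
            # Vertical neighbour: same rule along the other axis.
            if y + 1 < h and instance[y][x] != instance[y + 1][x]:
                edge[y][x] = edge[y + 1][x] = 1
    return edge
```

The resulting binary map then takes the place of the instance channel at the end of data['label'], so the generator sees object boundaries even between adjacent instances of the same class.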

imaginaire.trainers.spade module

class imaginaire.trainers.spade.Trainer(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, train_data_loader, val_data_loader)

Bases: imaginaire.trainers.base.BaseTrainer

Initialize SPADE trainer.

Parameters
  • cfg (Config) – Global configuration.

  • net_G (obj) – Generator network.

  • net_D (obj) – Discriminator network.

  • opt_G (obj) – Optimizer for the generator network.

  • opt_D (obj) – Optimizer for the discriminator network.

  • sch_G (obj) – Scheduler for the generator optimizer.

  • sch_D (obj) – Scheduler for the discriminator optimizer.

  • train_data_loader (obj) – Train data loader.

  • val_data_loader (obj) – Validation data loader.

dis_forward(data)

Compute the loss for SPADE discriminator.

Parameters

data (dict) – Training data at the current iteration.

gen_forward(data)

Compute the loss for SPADE generator.

Parameters

data (dict) – Training data at the current iteration.

recalculate_batch_norm_statistics(data_loader)

Update the statistics in the moving average model.

Parameters

data_loader (torch.utils.data.DataLoader) – Data loader for estimating the statistics.

write_metrics()

If a moving average model is present, there are two meters: one for the regular FID and one for the average FID. If there is no moving average model, only the average FID is reported.

imaginaire.trainers.unit module

class imaginaire.trainers.unit.Trainer(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, train_data_loader, val_data_loader)

Bases: imaginaire.trainers.base.BaseTrainer

Reimplementation of the UNIT (https://arxiv.org/abs/1703.00848) algorithm.

Parameters
  • cfg (obj) – Global configuration.

  • net_G (obj) – Generator network.

  • net_D (obj) – Discriminator network.

  • opt_G (obj) – Optimizer for the generator network.

  • opt_D (obj) – Optimizer for the discriminator network.

  • sch_G (obj) – Scheduler for the generator optimizer.

  • sch_D (obj) – Scheduler for the discriminator optimizer.

  • train_data_loader (obj) – Train data loader.

  • val_data_loader (obj) – Validation data loader.

dis_forward(data)

Compute the loss for UNIT discriminator.

Parameters

data (dict) – Training data at the current iteration.

gen_forward(data)

Compute the loss for UNIT generator.

Parameters

data (dict) – Training data at the current iteration.

write_metrics()

Compute metrics and save them to TensorBoard.

imaginaire.trainers.vid2vid module

class imaginaire.trainers.vid2vid.Trainer(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, train_data_loader, val_data_loader)

Bases: imaginaire.trainers.base.BaseTrainer

Initialize vid2vid trainer.

Parameters
  • cfg (obj) – Global configuration.

  • net_G (obj) – Generator network.

  • net_D (obj) – Discriminator network.

  • opt_G (obj) – Optimizer for the generator network.

  • opt_D (obj) – Optimizer for the discriminator network.

  • sch_G (obj) – Scheduler for the generator optimizer.

  • sch_D (obj) – Scheduler for the discriminator optimizer.

  • train_data_loader (obj) – Train data loader.

  • val_data_loader (obj) – Validation data loader.

compute_gan_losses(net_D_output, dis_update)

Compute the GAN loss and the feature matching loss.

Parameters
  • net_D_output (dict) – Output of the discriminator.

  • dis_update (bool) – Whether to update the discriminator.
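As a sketch of the second term only: a feature matching loss is commonly computed as the mean L1 distance between the discriminator's intermediate features for real and generated inputs, averaged over layers, with the real features treated as fixed targets. The flat-list feature format and the function name `feature_matching_loss` are assumptions; the trainer works on tensors and may weight or normalise the per-layer terms differently.

```python
from typing import List


def feature_matching_loss(
    real_feats: List[List[float]], fake_feats: List[List[float]]
) -> float:
    """Mean over layers of the mean absolute difference (L1) per layer."""
    if len(real_feats) != len(fake_feats) or not real_feats:
        raise ValueError("need the same, non-zero number of layers")
    total = 0.0
    for real, fake in zip(real_feats, fake_feats):
        if len(real) != len(fake) or not real:
            raise ValueError("layer feature sizes must match and be non-empty")
        # Real features act as fixed targets (detached in a tensor framework).
        total += sum(abs(r - f) for r, f in zip(real, fake)) / len(real)
    return total / len(real_feats)
```

Averaging per layer first keeps a large early feature map from dominating the smaller, more semantic deeper layers.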

create_sequence_output_dir(output_dir, key)

Create output subdir for this sequence.

Parameters
  • output_dir (str) – Root output dir.

  • key (str) – LMDB key which contains sequence name and file name.

Returns
  • output_dir (str) – Output subdir for this sequence.

  • seq_name (str) – Name of this sequence.
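A minimal sketch of the behaviour described for `create_sequence_output_dir`, assuming the LMDB key has the form `<seq_name>/<file_name>`. The real key layout may carry more path components (the wc_vid2vid trainer overrides this method), so the parsing below and the helper name `make_sequence_output_dir` are hypothetical.

```python
import os
from typing import Tuple


def make_sequence_output_dir(output_dir: str, key: str) -> Tuple[str, str]:
    """Create <output_dir>/<seq_name> and return it together with seq_name."""
    if "/" not in key:
        raise ValueError("expected a key of the form '<seq_name>/<file_name>'")
    # Assumed key format: everything before the last '/' names the sequence.
    seq_name, _file_name = key.rsplit("/", 1)
    seq_output_dir = os.path.join(output_dir, seq_name)
    os.makedirs(seq_output_dir, exist_ok=True)  # safe to call once per frame
    return seq_output_dir, seq_name
```

Using `exist_ok=True` lets the function be called for every frame of a sequence without first checking whether the directory exists.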

dis_update(data)

The update is already done in gen_update.

Parameters

data (dict) – Training data at the current iteration.

gen_frames(data, use_model_average=False)

Generate a sequence of frames given a sequence of data.

Parameters
  • data (dict) – Training data at the current iteration.

  • use_model_average (bool) – Whether or not to use the model average for the update.

gen_update(data)

Update the vid2vid generator. In each iteration, the updates alternate frame by frame: dis_update (frame 1), gen_update (frame 1), dis_update (frame 2), gen_update (frame 2), …

Parameters

data (dict) – Training data at the current iteration.
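The frame-by-frame alternation described for `gen_update` can be sketched as a loop over time steps. Here the discriminator and generator steps are passed in as callables that only record the order in which they run; the function name `alternate_updates`, its arguments, and the assumption that each frame's generator output is threaded into the next frame are illustrative, not the trainer's code.

```python
from typing import Callable, Dict, List, Optional


def alternate_updates(
    frames: List[Dict],
    dis_step: Callable[[Dict, Optional[Dict]], None],
    gen_step: Callable[[Dict, Optional[Dict]], Dict],
) -> Optional[Dict]:
    """Run D then G on each frame in order, threading the previous G output."""
    prev_output: Optional[Dict] = None
    for data_t in frames:
        dis_step(data_t, prev_output)  # dis_update (frame t)
        prev_output = gen_step(data_t, prev_output)  # gen_update (frame t)
    return prev_output
```

Interleaving per frame, rather than running all discriminator steps first, keeps the discriminator in step with a generator whose later frames depend on its own earlier outputs.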

get_data_t(data, net_G_output, data_prev, t)

Get the data for the current time frame from the sequence of data.

Parameters
  • data (dict) – Training data for the current iteration.

  • net_G_output (dict) – Output of the generator (for the previous frame).

  • data_prev (dict) – Data for the previous frame.

  • t (int) – Current time step.
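`get_data_t` extracts frame `t` from the sequence and pairs it with context from earlier frames. The sketch below assumes each key holds a per-frame list and that `data_prev` keeps a sliding window of the most recent `num_prev` frames; the helper name `get_data_t_sketch`, the `num_prev` parameter, and the `prev_` key prefix are hypothetical, and the previous generator output (`net_G_output`) is omitted for brevity.

```python
from typing import Dict, List, Tuple


def get_data_t_sketch(
    data: Dict[str, List], data_prev: Dict[str, List], t: int, num_prev: int = 2
) -> Tuple[Dict[str, object], Dict[str, List]]:
    """Slice frame t from each sequence and attach up to num_prev earlier frames."""
    data_t: Dict[str, object] = {}
    new_prev: Dict[str, List] = {}
    for key, frames in data.items():
        history = data_prev.get(key, [])
        data_t[key] = frames[t]
        data_t["prev_" + key] = list(history)  # context from earlier frames
        # Slide the window forward so that it ends at frame t.
        window = history + [frames[t]]
        new_prev[key] = window[-num_prev:] if num_prev > 0 else []
    return data_t, new_prev
```

Returning the updated window instead of mutating `data_prev` keeps each step easy to test in isolation.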

get_dis_losses(net_D_output)

Compute discriminator losses.

Parameters

net_D_output (dict) – Output of the discriminator.

get_gen_losses(data_t, net_G_output, net_D_output)

Compute generator losses.

Parameters
  • data_t (dict) – Training data at the current time t.

  • net_G_output (dict) – Output of the generator.

  • net_D_output (dict) – Output of the discriminator.

get_test_output_images(data)

Get the visualization output of the test function.

Parameters

data (dict) – Training data at the current iteration.

init_temporal_network()

Initialize temporal training when training on multiple frames begins, and set the sequence length to the configured initial_sequence_length.

post_process(data, net_G_output)

Do any post-processing of the data and the generator output here.

Parameters
  • data (dict) – Training data at the current iteration.

  • net_G_output (dict) – Output of the generator.

pre_process(data)

Do any data pre-processing here.

Parameters

data (dict) – Data used for the current iteration.

reset()

Reset the trainer (for inference) at the beginning of a sequence.

save_image(path, data)

Save the output images to path. Note that when generate_raw_output is False, first_net_G_output['fake_raw_images'] is None and will not be displayed. In model average mode, the flow visualization is plotted twice.

Parameters
  • path (str) – Save path.

  • data (dict) – Training data for the current iteration.

test(test_data_loader, root_output_dir, inference_args)

Run inference on all sequences.

Parameters
  • test_data_loader (object) – Test data loader.

  • root_output_dir (str) – Location to dump outputs.

  • inference_args (optional) – Optional inference arguments.

test_single(data, output_dir=None, inference_args=None)

The inference function. If output_dir exists, also save the output image.

Parameters
  • data (dict) – Training data at the current iteration.

  • output_dir (str) – Save image directory.

  • inference_args (obj) – Inference args.

visualize_label(label)

Visualize the input label when saving to image.

Parameters

label (tensor) – Input label tensor.

write_metrics()

If a moving average model is present, there are two meters: one for the regular FID and one for the average FID. If there is no moving average model, only the average FID is reported.

imaginaire.trainers.wc_vid2vid module

class imaginaire.trainers.wc_vid2vid.Trainer(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, train_data_loader, val_data_loader)

Bases: imaginaire.trainers.vid2vid.Trainer

Initialize world-consistent vid2vid trainer.

Parameters
  • cfg (obj) – Global configuration.

  • net_G (obj) – Generator network.

  • net_D (obj) – Discriminator network.

  • opt_G (obj) – Optimizer for the generator network.

  • opt_D (obj) – Optimizer for the discriminator network.

  • sch_G (obj) – Scheduler for the generator optimizer.

  • sch_D (obj) – Scheduler for the discriminator optimizer.

  • train_data_loader (obj) – Train data loader.

  • val_data_loader (obj) – Validation data loader.

create_sequence_output_dir(output_dir, key)

Create output subdir for this sequence.

Parameters
  • output_dir (str) – Root output dir.

  • key (str) – LMDB key which contains sequence name and file name.

Returns
  • output_dir (str) – Output subdir for this sequence.

  • seq_name (str) – Name of this sequence.

gen_frames(data, use_model_average=False)

Generate a sequence of frames given a sequence of data.

Parameters
  • data (dict) – Training data at the current iteration.

  • use_model_average (bool) – Whether or not to use the model average for the update.

get_data_t(data, net_G_output, data_prev, t)

Get the data for the current time frame from the sequence of data.

Parameters
  • data (dict) – Training data for the current iteration.

  • net_G_output (dict) – Output of the generator (for the previous frame).

  • data_prev (dict) – Data for the previous frame.

  • t (int) – Current time step.

get_test_output_images(data)

Get the visualization output of the test function.

Parameters

data (dict) – Training data at the current iteration.

load_checkpoint(cfg, checkpoint_path, resume=None, load_sch=True)

Load network weights, optimizer parameters, and scheduler parameters from the checkpoint.

Parameters
  • cfg (obj) – Global configuration.

  • checkpoint_path (str) – Path to the checkpoint.

reset()

Reset the trainer (for inference) at the beginning of a sequence.

save_image(path, data)

Save the output images to path. Note that when generate_raw_output is False, first_net_G_output['fake_raw_images'] is None and will not be displayed. In model average mode, the flow visualization is plotted twice.

Parameters
  • path (str) – Save path.

  • data (dict) – Training data for the current iteration.

start_of_iteration(data, current_iteration)

Things to do before an iteration.

Parameters
  • data (dict) – Data used for the current iteration.

  • current_iteration (int) – Current iteration number.

test(test_data_loader, root_output_dir, inference_args)

Run inference on all sequences.

Parameters
  • test_data_loader (object) – Test data loader.

  • root_output_dir (str) – Location to dump outputs.

  • inference_args (optional) – Optional inference arguments.

test_single(data, output_dir=None, save_fake_only=False)

The inference function. If output_dir exists, also save the output image.

Parameters
  • data (dict) – Training data at the current iteration.

  • output_dir (str) – Save image directory.

  • save_fake_only (bool) – Only save the fake output image.

Module contents