imaginaire.trainers package¶
Submodules¶
imaginaire.trainers.base module¶
-
class
imaginaire.trainers.base.
BaseTrainer
(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, train_data_loader, val_data_loader)[source]¶ Bases:
object
Base trainer. We expect that all trainers inherit this class.
- Parameters
cfg (obj) – Global configuration.
net_G (obj) – Generator network.
net_D (obj) – Discriminator network.
opt_G (obj) – Optimizer for the generator network.
opt_D (obj) – Optimizer for the discriminator network.
sch_G (obj) – Scheduler for the generator optimizer.
sch_D (obj) – Scheduler for the discriminator optimizer.
train_data_loader (obj) – Train data loader.
val_data_loader (obj) – Validation data loader.
-
dis_update
(data)[source]¶ Update the discriminator.
- Parameters
data (dict) – Data used for the current iteration.
-
end_of_epoch
(data, current_epoch, current_iteration)[source]¶ Things to do after an epoch.
- Parameters
data (dict) – Data used for the current iteration.
current_epoch (int) – Current epoch number.
current_iteration (int) – Current iteration number.
-
end_of_iteration
(data, current_epoch, current_iteration)[source]¶ Things to do after an iteration.
- Parameters
data (dict) – Data used for the current iteration.
current_epoch (int) – Current epoch number.
current_iteration (int) – Current iteration number.
-
gen_update
(data)[source]¶ Update the generator.
- Parameters
data (dict) – Data used for the current iteration.
-
load_checkpoint
(cfg, checkpoint_path, resume=None, load_sch=True)[source]¶ Load network weights, optimizer parameters, scheduler parameters from a checkpoint.
- Parameters
cfg (obj) – Global configuration.
checkpoint_path (str) – Path to the checkpoint.
resume (bool or None) – If not None, will determine whether or not to load optimizers in addition to network weights.
-
pre_process
(data)[source]¶ Custom data pre-processing function. Utilize this function if you need to preprocess your data before sending it to the generator and discriminator.
- Parameters
data (dict) – Data used for the current iteration.
-
recalculate_batch_norm_statistics
(data_loader, averaged=True)[source]¶ Update the statistics in the moving average model.
- Parameters
data_loader (torch.utils.data.DataLoader) – Data loader for estimating the statistics.
averaged (bool) – If True, recalculate batch norm statistics for the EMA (moving average) model; if False, for the regular model.
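One common way to re-estimate normalization statistics from a data loader is a cumulative average over batches, where every batch contributes equally. The sketch below illustrates that idea on plain Python lists. It is an assumption about the general technique, not a description of what this method does internally, and `cumulative_batch_mean` is a hypothetical helper.

```python
def cumulative_batch_mean(batches):
    """Average the per-batch means with equal weight per batch."""
    running_mean = 0.0
    for n, batch in enumerate(batches, start=1):
        batch_mean = sum(batch) / len(batch)
        # Incremental form of the arithmetic mean over the batch means seen so far.
        running_mean += (batch_mean - running_mean) / n
    return running_mean


# Two toy "batches" of scalar activations.
stats = cumulative_batch_mean([[1.0, 3.0], [5.0, 7.0]])
```

Unlike an exponential moving average, this estimate does not favor the most recent batches.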
-
save_checkpoint
(current_epoch, current_iteration)[source]¶ Save network weights, optimizer parameters, scheduler parameters to a checkpoint.
-
save_image
(path, data)[source]¶ Compute visualization images and save them to the disk.
- Parameters
path (str) – Location of the file.
data (dict) – Data used for the current iteration.
-
start_of_epoch
(current_epoch)[source]¶ Things to do before an epoch.
- Parameters
current_epoch (int) – Current epoch number.
-
start_of_iteration
(data, current_iteration)[source]¶ Things to do before an iteration.
- Parameters
data (dict) – Data used for the current iteration.
current_iteration (int) – Current iteration number.
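The hook names in BaseTrainer suggest a call order inside a training loop. The following is a minimal sketch under that assumption. `RecordingTrainer` and `run_training` are hypothetical stand-ins (not part of imaginaire) that only log which hook fires when, and the actual driver script may order or repeat the calls differently.

```python
class RecordingTrainer:
    """Hypothetical stand-in exposing the same hook names as BaseTrainer."""

    def __init__(self):
        self.calls = []  # Ordered log of hook names, for inspection.

    def start_of_epoch(self, current_epoch):
        self.calls.append("start_of_epoch")

    def start_of_iteration(self, data, current_iteration):
        self.calls.append("start_of_iteration")
        return data  # Assumption: the hook hands back the (possibly processed) data.

    def dis_update(self, data):
        self.calls.append("dis_update")

    def gen_update(self, data):
        self.calls.append("gen_update")

    def end_of_iteration(self, data, current_epoch, current_iteration):
        self.calls.append("end_of_iteration")

    def end_of_epoch(self, data, current_epoch, current_iteration):
        self.calls.append("end_of_epoch")


def run_training(trainer, batches, max_epoch):
    """Drive the hooks in an assumed order; returns the final iteration count."""
    current_iteration = 0
    for current_epoch in range(max_epoch):
        trainer.start_of_epoch(current_epoch)
        data = None
        for data in batches:
            data = trainer.start_of_iteration(data, current_iteration)
            trainer.dis_update(data)
            trainer.gen_update(data)
            current_iteration += 1
            trainer.end_of_iteration(data, current_epoch, current_iteration)
        trainer.end_of_epoch(data, current_epoch, current_iteration)
    return current_iteration


trainer = RecordingTrainer()
total_iterations = run_training(
    trainer, batches=[{"images": 0}, {"images": 1}], max_epoch=1
)
```

Inspecting `trainer.calls` afterwards shows the per-iteration hooks nested inside the per-epoch hooks.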
imaginaire.trainers.fs_vid2vid module¶
-
class
imaginaire.trainers.fs_vid2vid.
Trainer
(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, train_data_loader, val_data_loader)[source]¶ Bases:
imaginaire.trainers.vid2vid.Trainer
Initialize few-shot vid2vid trainer.
- Parameters
cfg (obj) – Global configuration.
net_G (obj) – Generator network.
net_D (obj) – Discriminator network.
opt_G (obj) – Optimizer for the generator network.
opt_D (obj) – Optimizer for the discriminator network.
sch_G (obj) – Scheduler for the generator optimizer.
sch_D (obj) – Scheduler for the discriminator optimizer.
train_data_loader (obj) – Train data loader.
val_data_loader (obj) – Validation data loader.
-
finetune
(data, inference_args)[source]¶ Finetune the model for a few iterations on the inference data.
-
get_data_t
(data, net_G_output, data_prev, t)[source]¶ Get data at current time frame given the sequence of data.
- Parameters
data (dict) – Training data for current iteration.
net_G_output (dict) – Output of the generator (for previous frame).
data_prev (dict) – Data for previous frame.
t (int) – Current time.
-
get_test_output_images
(data)[source]¶ Get the visualization output of test function.
- Parameters
data (dict) – Training data at the current iteration.
-
post_process
(data, net_G_output)[source]¶ Do any postprocessing of the data / output here.
- Parameters
data (dict) – Training data at the current iteration.
net_G_output (dict) – Output of the generator.
-
pre_process
(data)[source]¶ Do any data pre-processing here.
- Parameters
data (dict) – Data used for the current iteration.
-
save_image
(path, data)[source]¶ Save the output images to path. Note that when generate_raw_output is False, first_net_G_output['fake_raw_images'] is None and will not be displayed. In model average mode, the flow visualization is plotted twice.
- Parameters
path (str) – Save path.
data (dict) – Training data for current iteration.
imaginaire.trainers.funit module¶
-
class
imaginaire.trainers.funit.
Trainer
(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, train_data_loader, val_data_loader)[source]¶ Bases:
imaginaire.trainers.base.BaseTrainer
Reimplementation of the FUNIT (https://arxiv.org/abs/1905.01723) algorithm.
- Parameters
cfg (obj) – Global configuration.
net_G (obj) – Generator network.
net_D (obj) – Discriminator network.
opt_G (obj) – Optimizer for the generator network.
opt_D (obj) – Optimizer for the discriminator network.
sch_G (obj) – Scheduler for the generator optimizer.
sch_D (obj) – Scheduler for the discriminator optimizer.
train_data_loader (obj) – Train data loader.
val_data_loader (obj) – Validation data loader.
-
dis_forward
(data)[source]¶ Compute the loss for FUNIT discriminator.
- Parameters
data (dict) – Training data at the current iteration.
imaginaire.trainers.gancraft module¶
-
class
imaginaire.trainers.gancraft.
GauGANLoader
(gaugan_cfg)[source]¶ Bases:
object
Manages the SPADE/GauGAN model used to generate pseudo-GTs for training GANcraft.
- Parameters
gaugan_cfg (Config) – SPADE configuration.
-
class
imaginaire.trainers.gancraft.
Trainer
(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, train_data_loader, val_data_loader)[source]¶ Bases:
imaginaire.trainers.base.BaseTrainer
Initialize GANcraft trainer.
- Parameters
cfg (Config) – Global configuration.
net_G (obj) – Generator network.
net_D (obj) – Discriminator network.
opt_G (obj) – Optimizer for the generator network.
opt_D (obj) – Optimizer for the discriminator network.
sch_G (obj) – Scheduler for the generator optimizer.
sch_D (obj) – Scheduler for the discriminator optimizer.
train_data_loader (obj) – Train data loader.
val_data_loader (obj) – Validation data loader.
-
dis_forward
(data)[source]¶ Compute the loss for GANcraft discriminator.
- Parameters
data (dict) – Training data at the current iteration.
-
gen_forward
(data)[source]¶ Compute the loss for GANcraft generator.
- Parameters
data (dict) – Training data at the current iteration.
-
load_checkpoint
(cfg, checkpoint_path, resume=None, load_sch=True)[source]¶ Load network weights, optimizer parameters, scheduler parameters from a checkpoint.
- Parameters
cfg (obj) – Global configuration.
checkpoint_path (str) – Path to the checkpoint.
resume (bool or None) – If not None, will determine whether or not to load optimizers in addition to network weights.
imaginaire.trainers.munit module¶
-
class
imaginaire.trainers.munit.
Trainer
(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, train_data_loader, val_data_loader)[source]¶ Bases:
imaginaire.trainers.base.BaseTrainer
Reimplementation of the MUNIT (https://arxiv.org/abs/1804.04732) algorithm.
- Parameters
cfg (obj) – Global configuration.
net_G (obj) – Generator network.
net_D (obj) – Discriminator network.
opt_G (obj) – Optimizer for the generator network.
opt_D (obj) – Optimizer for the discriminator network.
sch_G (obj) – Scheduler for the generator optimizer.
sch_D (obj) – Scheduler for the discriminator optimizer.
train_data_loader (obj) – Train data loader.
val_data_loader (obj) – Validation data loader.
-
dis_forward
(data)[source]¶ Compute the loss for MUNIT discriminator.
- Parameters
data (dict) – Training data at the current iteration.
imaginaire.trainers.pix2pixHD module¶
-
class
imaginaire.trainers.pix2pixHD.
Trainer
(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, train_data_loader, val_data_loader)[source]¶ Bases:
imaginaire.trainers.spade.Trainer
Initialize pix2pixHD trainer.
- Parameters
cfg (obj) – Global configuration.
net_G (obj) – Generator network.
net_D (obj) – Discriminator network.
opt_G (obj) – Optimizer for the generator network.
opt_D (obj) – Optimizer for the discriminator network.
sch_G (obj) – Scheduler for the generator optimizer.
sch_D (obj) – Scheduler for the discriminator optimizer.
train_data_loader (obj) – Train data loader.
val_data_loader (obj) – Validation data loader.
-
dis_forward
(data)[source]¶ Compute the loss for pix2pixHD discriminator.
- Parameters
data (dict) – Training data at the current iteration.
-
gen_forward
(data)[source]¶ Compute the loss for pix2pixHD generator.
- Parameters
data (dict) – Training data at the current iteration.
-
pre_process
(data)[source]¶ Data pre-processing step for the pix2pixHD method. It takes a dictionary as input, where the dictionary contains a label field. The label field is the concatenation of the segmentation mask and the instance map. In this function, we replace the instance map with an edge map. We also add an "instance_maps" field to the dictionary.
- Parameters
data (dict) – Input dictionary.
data['label'] – Input label map where the last channel is the instance map.
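The pre-processing step described for pix2pixHD can be sketched on plain nested lists. This is an illustration under stated assumptions, not the library's implementation: the label is modeled as a list of 2D channels rather than a tensor, an edge is assumed to be any pixel whose instance ID differs from a horizontal or vertical neighbor, and `instance_map_to_edges` and `pre_process_sketch` are hypothetical names.

```python
def instance_map_to_edges(instance_map):
    """Mark pixels whose instance ID differs from a horizontal or vertical neighbor."""
    height, width = len(instance_map), len(instance_map[0])
    edges = [[0] * width for _ in range(height)]
    for y in range(height):
        for x in range(width):
            # Compare with the right neighbor; mark both sides of the boundary.
            if x + 1 < width and instance_map[y][x] != instance_map[y][x + 1]:
                edges[y][x] = edges[y][x + 1] = 1
            # Compare with the bottom neighbor; mark both sides of the boundary.
            if y + 1 < height and instance_map[y][x] != instance_map[y + 1][x]:
                edges[y][x] = edges[y + 1][x] = 1
    return edges


def pre_process_sketch(data):
    """Replace the last label channel (instance map) with an edge map."""
    *segmentation, instance_map = data["label"]
    return {
        **data,
        "instance_maps": instance_map,  # Keep the raw instance map around.
        "label": segmentation + [instance_map_to_edges(instance_map)],
    }


instance = [[1, 1, 2],
            [1, 1, 2],
            [3, 3, 3]]
segmentation_channel = [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
out = pre_process_sketch({"label": [segmentation_channel, instance]})
```

The number of label channels is unchanged; only the content of the last channel differs, and the original instance map stays available under "instance_maps".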
imaginaire.trainers.spade module¶
-
class
imaginaire.trainers.spade.
Trainer
(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, train_data_loader, val_data_loader)[source]¶ Bases:
imaginaire.trainers.base.BaseTrainer
Initialize SPADE trainer.
- Parameters
cfg (Config) – Global configuration.
net_G (obj) – Generator network.
net_D (obj) – Discriminator network.
opt_G (obj) – Optimizer for the generator network.
opt_D (obj) – Optimizer for the discriminator network.
sch_G (obj) – Scheduler for the generator optimizer.
sch_D (obj) – Scheduler for the discriminator optimizer.
train_data_loader (obj) – Train data loader.
val_data_loader (obj) – Validation data loader.
-
dis_forward
(data)[source]¶ Compute the loss for SPADE discriminator.
- Parameters
data (dict) – Training data at the current iteration.
-
gen_forward
(data)[source]¶ Compute the loss for SPADE generator.
- Parameters
data (dict) – Training data at the current iteration.
imaginaire.trainers.unit module¶
-
class
imaginaire.trainers.unit.
Trainer
(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, train_data_loader, val_data_loader)[source]¶ Bases:
imaginaire.trainers.base.BaseTrainer
Reimplementation of the UNIT (https://arxiv.org/abs/1703.00848) algorithm.
- Parameters
cfg (obj) – Global configuration.
net_G (obj) – Generator network.
net_D (obj) – Discriminator network.
opt_G (obj) – Optimizer for the generator network.
opt_D (obj) – Optimizer for the discriminator network.
sch_G (obj) – Scheduler for the generator optimizer.
sch_D (obj) – Scheduler for the discriminator optimizer.
train_data_loader (obj) – Train data loader.
val_data_loader (obj) – Validation data loader.
-
dis_forward
(data)[source]¶ Compute the loss for UNIT discriminator.
- Parameters
data (dict) – Training data at the current iteration.
imaginaire.trainers.vid2vid module¶
-
class
imaginaire.trainers.vid2vid.
Trainer
(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, train_data_loader, val_data_loader)[source]¶ Bases:
imaginaire.trainers.base.BaseTrainer
Initialize vid2vid trainer.
- Parameters
cfg (obj) – Global configuration.
net_G (obj) – Generator network.
net_D (obj) – Discriminator network.
opt_G (obj) – Optimizer for the generator network.
opt_D (obj) – Optimizer for the discriminator network.
sch_G (obj) – Scheduler for the generator optimizer.
sch_D (obj) – Scheduler for the discriminator optimizer.
train_data_loader (obj) – Train data loader.
val_data_loader (obj) – Validation data loader.
-
compute_gan_losses
(net_D_output, dis_update)[source]¶ Compute GAN loss and feature matching loss.
- Parameters
net_D_output (dict) – Output of the discriminator.
dis_update (bool) – Whether to update discriminator.
-
create_sequence_output_dir
(output_dir, key)[source]¶ Create output subdir for this sequence.
- Parameters
output_dir (str) – Root output dir.
key (str) – LMDB key which contains sequence name and file name.
- Returns
output_dir (str) – Output subdir for this sequence.
seq_name (str) – Name of this sequence.
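As a rough illustration of create_sequence_output_dir, the sketch below derives a per-sequence subdirectory from a key. The key layout `seq_name/file_name` is an assumption made for this example (the documentation only says the key contains both names), and `create_sequence_output_dir_sketch` is a hypothetical helper.

```python
import os
import tempfile


def create_sequence_output_dir_sketch(output_dir, key):
    """Return (subdir, seq_name), assuming key looks like 'seq_name/file_name'."""
    seq_name = key.split("/")[0]  # Assumption: sequence name is the first component.
    seq_dir = os.path.join(output_dir, seq_name)
    os.makedirs(seq_dir, exist_ok=True)  # Safe to call once per frame of a sequence.
    return seq_dir, seq_name


root = tempfile.mkdtemp()
seq_dir, seq_name = create_sequence_output_dir_sketch(root, "seq0001/frame_000.jpg")
```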
-
dis_update
(data)[source]¶ The update is already done in gen_update.
- Parameters
data (dict) – Training data at the current iteration.
-
gen_frames
(data, use_model_average=False)[source]¶ Generate a sequence of frames given a sequence of data.
- Parameters
data (dict) – Training data at the current iteration.
use_model_average (bool) – Whether to use model average for update or not.
-
gen_update
(data)[source]¶ Update the vid2vid generator. We update in the fashion of dis_update (frame 1), gen_update (frame 1), dis_update (frame 2), gen_update (frame 2), … in each iteration.
- Parameters
data (dict) – Training data at the current iteration.
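The interleaving described for gen_update can be written out as an explicit schedule. This is only a sketch of the stated order; `interleaved_update_order` is a hypothetical helper, and the real method performs the updates rather than listing them.

```python
def interleaved_update_order(num_frames):
    """List the per-frame update steps in the order the description gives."""
    order = []
    for frame in range(1, num_frames + 1):
        order.append(("dis_update", frame))  # Discriminator step for this frame.
        order.append(("gen_update", frame))  # Generator step for the same frame.
    return order


schedule = interleaved_update_order(2)
```

Because both steps happen per frame inside gen_update, the separate dis_update method of this trainer has nothing left to do.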
-
get_data_t
(data, net_G_output, data_prev, t)[source]¶ Get data at current time frame given the sequence of data.
- Parameters
data (dict) – Training data for current iteration.
net_G_output (dict) – Output of the generator (for previous frame).
data_prev (dict) – Data for previous frame.
t (int) – Current time.
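To make the roles of data, net_G_output, and data_prev concrete, here is a minimal sketch of selecting frame t while carrying forward a short history. The key names ("label", "images", "prev_labels", "prev_images", "fake_image"), the list-based data layout, and the `history` parameter are all assumptions for illustration, and `get_data_t_sketch` is a hypothetical function.

```python
def get_data_t_sketch(data, net_G_output, data_prev, t, history=2):
    """Pick frame t and carry a short history of past labels and generated images."""
    data_t = {"label": data["label"][t], "image": data["images"][t]}
    if data_prev is None:
        # First frame: nothing has been generated yet.
        data_t["prev_labels"], data_t["prev_images"] = [], []
    else:
        # Append the previous frame's label and generated image, keep the newest ones.
        data_t["prev_labels"] = (
            data_prev["prev_labels"] + [data_prev["label"]]
        )[-history:]
        data_t["prev_images"] = (
            data_prev["prev_images"] + [net_G_output["fake_image"]]
        )[-history:]
    return data_t


sequence = {"label": ["l0", "l1", "l2"], "images": ["x0", "x1", "x2"]}
d0 = get_data_t_sketch(sequence, None, None, t=0)
d1 = get_data_t_sketch(sequence, {"fake_image": "g0"}, d0, t=1)
d2 = get_data_t_sketch(sequence, {"fake_image": "g1"}, d1, t=2)
```

The point of the sketch is that the history is built from generated frames rather than ground-truth frames, which is why the generator output for the previous frame is an argument.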
-
get_dis_losses
(net_D_output)[source]¶ Compute discriminator losses.
- Parameters
net_D_output (dict) – Output of the discriminator.
-
get_gen_losses
(data_t, net_G_output, net_D_output)[source]¶ Compute generator losses.
- Parameters
data_t (dict) – Training data at the current time t.
net_G_output (dict) – Output of the generator.
net_D_output (dict) – Output of the discriminator.
-
get_test_output_images
(data)[source]¶ Get the visualization output of test function.
- Parameters
data (dict) – Training data at the current iteration.
-
init_temporal_network
()[source]¶ Initialize temporal training when beginning to train multiple frames. Set the sequence length to $(initial_sequence_length).
-
post_process
(data, net_G_output)[source]¶ Do any postprocessing of the data / output here.
- Parameters
data (dict) – Training data at the current iteration.
net_G_output (dict) – Output of the generator.
-
pre_process
(data)[source]¶ Do any data pre-processing here.
- Parameters
data (dict) – Data used for the current iteration.
-
save_image
(path, data)[source]¶ Save the output images to path. Note that when generate_raw_output is False, first_net_G_output['fake_raw_images'] is None and will not be displayed. In model average mode, the flow visualization is plotted twice.
- Parameters
path (str) – Save path.
data (dict) – Training data for current iteration.
-
test
(test_data_loader, root_output_dir, inference_args)[source]¶ Run inference on all sequences.
- Parameters
test_data_loader (object) – Test data loader.
root_output_dir (str) – Location to dump outputs.
inference_args (optional) – Optional args.
-
test_single
(data, output_dir=None, inference_args=None)[source]¶ The inference function. If output_dir exists, also save the output image.
- Parameters
data (dict) – Training data at the current iteration.
output_dir (str) – Save image directory.
inference_args (obj) – Inference args.
imaginaire.trainers.wc_vid2vid module¶
-
class
imaginaire.trainers.wc_vid2vid.
Trainer
(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, train_data_loader, val_data_loader)[source]¶ Bases:
imaginaire.trainers.vid2vid.Trainer
Initialize world consistent vid2vid trainer.
- Parameters
cfg (obj) – Global configuration.
net_G (obj) – Generator network.
net_D (obj) – Discriminator network.
opt_G (obj) – Optimizer for the generator network.
opt_D (obj) – Optimizer for the discriminator network.
sch_G (obj) – Scheduler for the generator optimizer.
sch_D (obj) – Scheduler for the discriminator optimizer.
train_data_loader (obj) – Train data loader.
val_data_loader (obj) – Validation data loader.
-
create_sequence_output_dir
(output_dir, key)[source]¶ Create output subdir for this sequence.
- Parameters
output_dir (str) – Root output dir.
key (str) – LMDB key which contains sequence name and file name.
- Returns
output_dir (str) – Output subdir for this sequence.
seq_name (str) – Name of this sequence.
-
gen_frames
(data, use_model_average=False)[source]¶ Generate a sequence of frames given a sequence of data.
- Parameters
data (dict) – Training data at the current iteration.
use_model_average (bool) – Whether to use model average for update or not.
-
get_data_t
(data, net_G_output, data_prev, t)[source]¶ Get data at current time frame given the sequence of data.
- Parameters
data (dict) – Training data for current iteration.
net_G_output (dict) – Output of the generator (for previous frame).
data_prev (dict) – Data for previous frame.
t (int) – Current time.
-
get_test_output_images
(data)[source]¶ Get the visualization output of test function.
- Parameters
data (dict) – Training data at the current iteration.
-
load_checkpoint
(cfg, checkpoint_path, resume=None, load_sch=True)[source]¶ Load network weights, optimizer parameters, scheduler parameters from a checkpoint.
- Parameters
cfg (obj) – Global configuration.
checkpoint_path (str) – Path to the checkpoint.
-
save_image
(path, data)[source]¶ Save the output images to path. Note that when generate_raw_output is False, first_net_G_output['fake_raw_images'] is None and will not be displayed. In model average mode, the flow visualization is plotted twice.
- Parameters
path (str) – Save path.
data (dict) – Training data for current iteration.
-
start_of_iteration
(data, current_iteration)[source]¶ Things to do before an iteration.
- Parameters
data (dict) – Data used for the current iteration.
current_iteration (int) – Current iteration number.
-
test
(test_data_loader, root_output_dir, inference_args)[source]¶ Run inference on all sequences.
- Parameters
test_data_loader (object) – Test data loader.
root_output_dir (str) – Location to dump outputs.
inference_args (optional) – Optional args.
-
test_single
(data, output_dir=None, save_fake_only=False)[source]¶ The inference function. If output_dir exists, also save the output image.
- Parameters
data (dict) – Training data at the current iteration.
output_dir (str) – Save image directory.
save_fake_only (bool) – Only save the fake output image.