imaginaire.discriminators package

Submodules

imaginaire.discriminators.dummy module

class imaginaire.discriminators.dummy.Discriminator(dis_cfg, data_cfg)[source]

Bases: torch.nn.modules.module.Module

Dummy Discriminator constructor.

Parameters
  • dis_cfg (obj) – Discriminator definition part of the yaml config file.

  • data_cfg (obj) – Data definition part of the yaml config file.

forward(data)[source]

Dummy discriminator forward.

Parameters

data (dict) – Input data.

training = None
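Both config objects are attribute-style namespaces parsed from the yaml file, and the concrete Discriminator class is typically resolved from a dotted module path given in the config. A minimal sketch of that lookup mechanism, demonstrated on a stdlib module since the actual config schema is not shown here (the attribute name `type` is an assumption):

```python
import importlib
from types import SimpleNamespace

def get_class(dotted_path, class_name):
    """Resolve a class from a dotted module path, the way a yaml-driven
    framework typically turns config strings into constructors."""
    return getattr(importlib.import_module(dotted_path), class_name)

# Hypothetical config object standing in for the parsed yaml; in
# imaginaire, dis_cfg.type would name a discriminator module such as
# 'imaginaire.discriminators.dummy'.
dis_cfg = SimpleNamespace(type='collections')

# Demonstrate the lookup on a module that is always importable:
cls = get_class(dis_cfg.type, 'OrderedDict')
print(cls.__name__)  # OrderedDict
```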

imaginaire.discriminators.fpse module

class imaginaire.discriminators.fpse.FPSEDiscriminator(num_input_channels, num_labels, num_filters, kernel_size, weight_norm_type, activation_norm_type)[source]

Bases: torch.nn.modules.module.Module

Feature-Pyramid Semantics Embedding Discriminator. This is a copy of the discriminator described in https://arxiv.org/pdf/1910.06809.pdf.

forward(images, segmaps)[source]
Parameters
  • images – image tensors.

  • segmaps – segmentation map tensors.

training = None

imaginaire.discriminators.fs_vid2vid module

class imaginaire.discriminators.fs_vid2vid.Discriminator(dis_cfg, data_cfg)[source]

Bases: torch.nn.modules.module.Module

Image and video discriminator constructor.

Parameters
  • dis_cfg (obj) – Discriminator part of the yaml config file.

  • data_cfg (obj) – Data definition part of the yaml config file.

discriminate_video(real_B, fake_B, scale)[source]

Discriminate a sequence of images.

Parameters
  • real_B (NxCxHxW tensor) – Real image.

  • fake_B (NxCxHxW tensor) – Fake image.

  • scale (int) – Temporal scale.

Returns

  • pred_real (NxC2xH2xW2 tensor): Output of net_D for real images.

  • pred_fake (NxC2xH2xW2 tensor): Output of net_D for fake images.

Return type

(tuple)
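A common way to let an image discriminator score a frame sequence is to fold the temporal dimension into channels before the forward pass. This folding is an assumption about how discriminate_video treats its inputs, shown here at the shape level only:

```python
def fold_time_into_channels(shape):
    """Map an (N, T, C, H, W) video shape to the (N, T*C, H, W) shape
    an image discriminator can consume."""
    n, t, c, h, w = shape
    return (n, t * c, h, w)

# Three RGB frames per sample become a single 9-channel input.
print(fold_time_into_channels((4, 3, 3, 256, 256)))  # (4, 9, 256, 256)
```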

discrminate_image(net_D, real_A, real_B, fake_B)[source]

Discriminate individual images.

Parameters
  • net_D (obj) – Discriminator network.

  • real_A (NxC1xHxW tensor) – Input label map.

  • real_B (NxC2xHxW tensor) – Real image.

  • fake_B (NxC2xHxW tensor) – Fake image.

Returns

  • pred_real (NxC3xH2xW2 tensor): Output of net_D for real images.

  • pred_fake (NxC3xH2xW2 tensor): Output of net_D for fake images.

Return type

(tuple)

forward(data, net_G_output, past_frames)[source]

Discriminator forward.

Parameters
  • data (dict) – Input data.

  • net_G_output (dict) – Generator output.

  • past_frames (list of tensors) – Past real frames / generator outputs.

Returns

  • output (dict): Discriminator output.

  • past_frames (list of tensors): New past frames by adding current outputs.

Return type

(tuple)

training = None

class imaginaire.discriminators.fs_vid2vid.MultiPatchDiscriminator(dis_cfg, num_input_channels)[source]

Bases: torch.nn.modules.module.Module

Multi-resolution patch discriminator.

Parameters
  • dis_cfg (obj) – Discriminator part of the yaml config file.

  • num_input_channels (int) – Number of input channels.

forward(input_x)[source]

Multi-resolution patch discriminator forward.

Parameters

input_x (N x C x H x W tensor) – Concatenation of images and semantic representations.

Returns

  • output (list): list of output tensors produced by individual patch discriminators.

  • features (list): list of lists of features produced by individual patch discriminators.

Return type

(dict)

training = None

imaginaire.discriminators.fs_vid2vid.get_all_skipped_frames(past_frames, new_frames, t_scales, tD)[source]

Get temporally skipped frames from the input frames.

Parameters
  • past_frames (list of tensors) – Past real frames / generator outputs.

  • new_frames (list of tensors) – Current real frame / generated output.

  • t_scales (int) – Temporal scale.

  • tD (int) – Number of frames as input to the temporal discriminator.

Returns

  • new_past_frames (list of tensors): Past + current frames.

  • skipped_frames (list of tensors): Temporally skipped frames using the given t_scales.

Return type

(tuple)

imaginaire.discriminators.fs_vid2vid.get_skipped_frames(all_frames, frame, t_scales, tD)[source]

Get temporally skipped frames from the input frames.

Parameters
  • all_frames (NxTxCxHxW tensor) – All past frames.

  • frame (Nx1xCxHxW tensor) – Current frame.

  • t_scales (int) – Temporal scale.

  • tD (int) – Number of frames as input to the temporal discriminator.

Returns

  • all_frames (NxTxCxHxW tensor): Past + current frames.

  • skipped_frames (list of NxTxCxHxW tensors): Temporally skipped frames.

Return type

(tuple)
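The skipping pattern is easiest to see on frame indices: at temporal scale s, the temporal discriminator receives tD frames spaced s apart, counted back from the newest frame. A pure-Python sketch of that indexing (the library's exact boundary handling may differ):

```python
def skipped_indices(num_frames, t_scale, tD):
    """Indices of the tD most recent frames spaced t_scale apart.
    Returns an empty list if not enough frames are available yet."""
    needed = t_scale * (tD - 1) + 1
    if num_frames < needed:
        return []
    last = num_frames - 1
    return [last - t_scale * i for i in range(tD - 1, -1, -1)]

print(skipped_indices(10, 1, 3))  # [7, 8, 9]
print(skipped_indices(10, 3, 3))  # [3, 6, 9]
print(skipped_indices(4, 3, 3))   # [] -- not enough frames yet
```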

imaginaire.discriminators.funit module

class imaginaire.discriminators.funit.Discriminator(dis_cfg, data_cfg)[source]

Bases: torch.nn.modules.module.Module

Discriminator in the improved FUNIT baseline in the COCO-FUNIT paper.

Parameters
  • dis_cfg (obj) – Discriminator definition part of the yaml config file.

  • data_cfg (obj) – Data definition part of the yaml config file.

forward(data, net_G_output, recon=True)[source]

Improved FUNIT discriminator forward function.

Parameters
  • data (dict) – Training data at the current iteration.

  • net_G_output (dict) – Fake data generated at the current iteration.

  • recon (bool) – If True, also classifies reconstructed images.

training = None

class imaginaire.discriminators.funit.ResDiscriminator(image_channels=3, num_classes=119, num_filters=64, max_num_filters=1024, num_layers=6, padding_mode='reflect', weight_norm_type='', **kwargs)[source]

Bases: torch.nn.modules.module.Module

Residual discriminator architecture used in the FUNIT paper.

forward(images, labels=None)[source]

Forward function of the projection discriminator.

Parameters
  • images (image tensor) – Input images to the discriminator.

  • labels (long int tensor) – Class labels of the images.

training = None

imaginaire.discriminators.gancraft module

class imaginaire.discriminators.gancraft.Discriminator(dis_cfg, data_cfg)[source]

Bases: torch.nn.modules.module.Module

Multi-resolution patch discriminator. Based on FPSE discriminator but with N+1 labels.

Parameters
  • dis_cfg (obj) – Discriminator definition part of the yaml config file.

  • data_cfg (obj) – Data definition part of the yaml config file.

forward(data, net_G_output, weights=None, incl_real=False, incl_pseudo_real=False)[source]

GANcraft discriminator forward.

Parameters
  • data (dict) –

    • data (N x C1 x H x W tensor) : Ground truth images.

    • label (N x C2 x H x W tensor) : Semantic representations.

    • z (N x style_dims tensor): Gaussian random noise.

  • net_G_output (dict) –

    • fake_images (N x C1 x H x W tensor) : Fake images.

Returns

  • real_outputs (list): list of output tensors produced by individual patch discriminators for real images.

  • real_features (list): list of lists of features produced by individual patch discriminators for real images.

  • fake_outputs (list): list of output tensors produced by individual patch discriminators for fake images.

  • fake_features (list): list of lists of features produced by individual patch discriminators for fake images.

Return type

output_x (dict)

training = None

class imaginaire.discriminators.gancraft.FPSEDiscriminator(num_input_channels, num_labels, num_filters, kernel_size, weight_norm_type, activation_norm_type, do_multiscale, smooth_resample, no_label_except_largest_scale)[source]

Bases: torch.nn.modules.module.Module

forward(images, segmaps, weights=None)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

static smooth_interp(x, size)[source]

Smooth interpolation of segmentation maps.

Parameters
  • x (4D tensor) – Segmentation maps.

  • size (2D list) – Target size (H, W).

training = None

imaginaire.discriminators.mlp_multiclass module

class imaginaire.discriminators.mlp_multiclass.Discriminator(dis_cfg, data_cfg)[source]

Bases: torch.nn.modules.module.Module

Multi-layer Perceptron Classifier constructor.

Parameters
  • dis_cfg (obj) – Discriminator definition part of the yaml config file.

  • data_cfg (obj) – Data definition part of the yaml config file.

forward(data)[source]

Patch Discriminator forward.

Parameters

data (dict) –

  • data (N x -1 tensor): Input data, reshaped to this format before classification.

Returns

  • results (N x C tensor): Output scores before softmax.

Return type

(dict)

training = None
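The N x -1 reshape simply flattens every dimension except the batch dimension; a quick sketch of the shape arithmetic:

```python
from math import prod

def flatten_shape(shape):
    """Map an (N, ...) input shape to the (N, -1) shape the MLP
    classifier reshapes its input to."""
    return (shape[0], prod(shape[1:]))

print(flatten_shape((8, 3, 32, 32)))  # (8, 3072)
```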

imaginaire.discriminators.multires_patch module

class imaginaire.discriminators.multires_patch.Discriminator(dis_cfg, data_cfg)[source]

Bases: torch.nn.modules.module.Module

Multi-resolution patch discriminator.

Parameters
  • dis_cfg (obj) – Discriminator definition part of the yaml config file.

  • data_cfg (obj) – Data definition part of the yaml config file.

forward(data, net_G_output, real=True)[source]

Multi-resolution patch discriminator forward.

Parameters
  • data (dict) –

    • data (N x C1 x H x W tensor) : Ground truth images.

    • label (N x C2 x H x W tensor) : Semantic representations.

    • z (N x style_dims tensor): Gaussian random noise.

  • net_G_output (dict) – fake_images (N x C1 x H x W tensor) : Fake images.

  • real (bool) – If True, also classifies real images. Otherwise it only classifies generated images to save computation during the generator update.

Returns

  • real_outputs (list): list of output tensors produced by individual patch discriminators for real images.

  • real_features (list): list of lists of features produced by individual patch discriminators for real images.

  • fake_outputs (list): list of output tensors produced by individual patch discriminators for fake images.

  • fake_features (list): list of lists of features produced by individual patch discriminators for fake images.

Return type

(tuple)

training = None

class imaginaire.discriminators.multires_patch.MultiResPatchDiscriminator(num_discriminators=3, kernel_size=3, num_image_channels=3, num_filters=64, num_layers=4, max_num_filters=512, activation_norm_type='', weight_norm_type='', **kwargs)[source]

Bases: torch.nn.modules.module.Module

Multi-resolution patch discriminator.

Parameters
  • num_discriminators (int) – Num. of discriminators (one per scale).

  • kernel_size (int) – Convolution kernel size.

  • num_image_channels (int) – Num. of channels in the real/fake image.

  • num_filters (int) – Num. of base filters in a layer.

  • num_layers (int) – Num. of layers for the patch discriminator.

  • max_num_filters (int) – Maximum num. of filters in a layer.

  • activation_norm_type (str) – batch_norm/instance_norm/none/….

  • weight_norm_type (str) – none/spectral_norm/weight_norm

forward(input_x)[source]

Multi-resolution patch discriminator forward.

Parameters

input_x (tensor) – Input images.

Returns

  • output_list (list): list of output tensors produced by individual patch discriminators.

  • features_list (list): list of lists of features produced by individual patch discriminators.

  • input_list (list): list of downsampled input images.

Return type

(tuple)

training = None
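Each of the num_discriminators patch discriminators typically sees the input downsampled by a further factor of two, which is consistent with the input_list of downsampled images returned above. A shape-level sketch of the resulting pyramid (the factor of two per scale is an assumption):

```python
def pyramid_sizes(h, w, num_discriminators=3):
    """Spatial resolutions seen by each patch discriminator when the
    input is halved at every scale."""
    return [(h >> i, w >> i) for i in range(num_discriminators)]

print(pyramid_sizes(256, 512))  # [(256, 512), (128, 256), (64, 128)]
```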
class imaginaire.discriminators.multires_patch.NLayerPatchDiscriminator(kernel_size, num_input_channels, num_filters, num_layers, max_num_filters, activation_norm_type, weight_norm_type)[source]

Bases: torch.nn.modules.module.Module

Patch Discriminator constructor.

Parameters
  • kernel_size (int) – Convolution kernel size.

  • num_input_channels (int) – Num. of channels in the real/fake image.

  • num_filters (int) – Num. of base filters in a layer.

  • num_layers (int) – Num. of layers for the patch discriminator.

  • max_num_filters (int) – Maximum num. of filters in a layer.

  • activation_norm_type (str) – batch_norm/instance_norm/none/….

  • weight_norm_type (str) – none/spectral_norm/weight_norm

forward(input_x)[source]

Patch Discriminator forward.

Parameters

input_x (N x C x H1 x W1 tensor) – Concatenation of images and semantic representations.

Returns

  • output (N x 1 x H2 x W2 tensor): Discriminator output score, before the sigmoid when using NSGAN.

  • features (list): lists of tensors of the intermediate activations.

Return type

(tuple)

training = None
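The H2 x W2 grid of patch scores follows from the stride-2 downsampling layers. Assuming each of the num_layers layers is a stride-2 convolution with padding kernel_size // 2 (a typical patch-discriminator layout, not confirmed from the source), the output size can be computed with the standard convolution formula:

```python
def conv_out(size, kernel, stride=2, padding=None):
    """Standard convolution output-size formula."""
    if padding is None:
        padding = kernel // 2
    return (size + 2 * padding - kernel) // stride + 1

def patch_grid(size, num_layers, kernel=4):
    """Spatial extent of the patch-score map after num_layers
    stride-2 convolutions."""
    for _ in range(num_layers):
        size = conv_out(size, kernel)
    return size

print(patch_grid(256, num_layers=4))  # 17: a 256x256 input yields a 17x17 score map
```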
class imaginaire.discriminators.multires_patch.WeightSharedMultiResPatchDiscriminator(num_discriminators=3, kernel_size=3, num_image_channels=3, num_filters=64, num_layers=4, max_num_filters=512, activation_norm_type='', weight_norm_type='', **kwargs)[source]

Bases: torch.nn.modules.module.Module

Multi-resolution patch discriminator with shared weights.

Parameters
  • num_discriminators (int) – Num. of discriminators (one per scale).

  • kernel_size (int) – Convolution kernel size.

  • num_image_channels (int) – Num. of channels in the real/fake image.

  • num_filters (int) – Num. of base filters in a layer.

  • num_layers (int) – Num. of layers for the patch discriminator.

  • max_num_filters (int) – Maximum num. of filters in a layer.

  • activation_norm_type (str) – batch_norm/instance_norm/none/….

  • weight_norm_type (str) – none/spectral_norm/weight_norm

forward(input_x)[source]

Multi-resolution patch discriminator forward.

Parameters

input_x (tensor) – Input images.

Returns

  • output_list (list): list of output tensors produced by individual patch discriminators.

  • features_list (list): list of lists of features produced by individual patch discriminators.

  • input_list (list): list of downsampled input images.

Return type

(tuple)

training = None

imaginaire.discriminators.munit module

class imaginaire.discriminators.munit.Discriminator(dis_cfg, data_cfg)[source]

Bases: torch.nn.modules.module.Module

MUNIT discriminator. It can be either a multi-resolution patch discriminator like in the original implementation, or a global residual discriminator.

Parameters
  • dis_cfg (obj) – Discriminator definition part of the yaml config file.

  • data_cfg (obj) – Data definition part of the yaml config file.

forward(data, net_G_output, gan_recon=False, real=True)[source]

Returns the output of the discriminator.

Parameters
  • data (dict) –

    • images_a (tensor) : Images in domain A.

    • images_b (tensor) : Images in domain B.

  • net_G_output (dict) –

    • images_ab (tensor) : Images translated from domain A to B by the generator.

    • images_ba (tensor) : Images translated from domain B to A by the generator.

    • images_aa (tensor) : Reconstructed images in domain A.

    • images_bb (tensor) : Reconstructed images in domain B.

  • gan_recon (bool) – If True, also classifies reconstructed images.

  • real (bool) – If True, also classifies real images. Otherwise it only classifies generated images to save computation during the generator update.

Returns

  • out_ab (tensor): Output of the discriminator for images translated from domain A to B by the generator.

  • out_ba (tensor): Output of the discriminator for images translated from domain B to A by the generator.

  • fea_ab (tensor): Intermediate features of the discriminator for images translated from domain A to B by the generator.

  • fea_ba (tensor): Intermediate features of the discriminator for images translated from domain B to A by the generator.

  • out_a (tensor): Output of the discriminator for images in domain A.

  • out_b (tensor): Output of the discriminator for images in domain B.

  • fea_a (tensor): Intermediate features of the discriminator for images in domain A.

  • fea_b (tensor): Intermediate features of the discriminator for images in domain B.

  • out_aa (tensor): Output of the discriminator for reconstructed images in domain A.

  • out_bb (tensor): Output of the discriminator for reconstructed images in domain B.

  • fea_aa (tensor): Intermediate features of the discriminator for reconstructed images in domain A.

  • fea_bb (tensor): Intermediate features of the discriminator for reconstructed images in domain B.

Return type

(dict)

training = None
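The keys of the returned dictionary depend on the real and gan_recon flags; a skeleton enumerating them, following the parameter descriptions above (the grouping by flag is inferred from those descriptions, not from the source code):

```python
def munit_dis_output_keys(real=True, gan_recon=False):
    """Keys present in the dict returned by the MUNIT discriminator
    forward pass, per the description above."""
    keys = ['out_ab', 'out_ba', 'fea_ab', 'fea_ba']       # translated images
    if real:
        keys += ['out_a', 'out_b', 'fea_a', 'fea_b']      # real images
    if gan_recon:
        keys += ['out_aa', 'out_bb', 'fea_aa', 'fea_bb']  # reconstructions
    return keys

print(munit_dis_output_keys(real=False))  # ['out_ab', 'out_ba', 'fea_ab', 'fea_ba']
```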

imaginaire.discriminators.residual module

class imaginaire.discriminators.residual.ResDiscriminator(image_channels=3, num_filters=64, max_num_filters=512, first_kernel_size=1, num_layers=4, padding_mode='zeros', activation_norm_type='', weight_norm_type='', aggregation='conv', order='pre_act', anti_aliased=False, **kwargs)[source]

Bases: torch.nn.modules.module.Module

Global residual discriminator.

Parameters
  • image_channels (int) – Num. of channels in the real/fake image.

  • num_filters (int) – Num. of base filters in a layer.

  • max_num_filters (int) – Maximum num. of filters in a layer.

  • first_kernel_size (int) – Kernel size in the first layer.

  • num_layers (int) – Num. of layers in discriminator.

  • padding_mode (str) – Padding mode.

  • activation_norm_type (str) – Type of activation normalization. 'none', 'instance', 'batch', 'sync_batch'.

  • weight_norm_type (str) – Type of weight normalization. 'none', 'spectral', or 'weight'.

  • aggregation (str) – Method to aggregate features across different locations in the final layer. 'conv', or 'pool'.

  • order (str) – Order of operations in the residual link.

  • anti_aliased (bool) – If True, uses anti-aliased pooling.

forward(images)[source]

Global residual discriminator forward.

Parameters

images (tensor) – Input images.

Returns

  • outputs (tensor): Output of the discriminator.

  • features (tensor): Intermediate features of the discriminator.

  • images (tensor): Input images.

Return type

(tuple)

training = None

imaginaire.discriminators.spade module

class imaginaire.discriminators.spade.Discriminator(dis_cfg, data_cfg)[source]

Bases: torch.nn.modules.module.Module

Multi-resolution patch discriminator.

Parameters
  • dis_cfg (obj) – Discriminator definition part of the yaml config file.

  • data_cfg (obj) – Data definition part of the yaml config file.

forward(data, net_G_output)[source]

SPADE discriminator forward.

Parameters
  • data (dict) –

    • data (N x C1 x H x W tensor) : Ground truth images.

    • label (N x C2 x H x W tensor) : Semantic representations.

    • z (N x style_dims tensor): Gaussian random noise.

  • net_G_output (dict) – fake_images (N x C1 x H x W tensor) : Fake images.

Returns

  • real_outputs (list): list of output tensors produced by individual patch discriminators for real images.

  • real_features (list): list of lists of features produced by individual patch discriminators for real images.

  • fake_outputs (list): list of output tensors produced by individual patch discriminators for fake images.

  • fake_features (list): list of lists of features produced by individual patch discriminators for fake images.

Return type

(dict)

training = None

imaginaire.discriminators.unit module

class imaginaire.discriminators.unit.Discriminator(dis_cfg, data_cfg)[source]

Bases: torch.nn.modules.module.Module

UNIT discriminator. It can be either a multi-resolution patch discriminator like in the original implementation, or a global residual discriminator.

Parameters
  • dis_cfg (obj) – Discriminator definition part of the yaml config file.

  • data_cfg (obj) – Data definition part of the yaml config file.

forward(data, net_G_output, gan_recon=False, real=True)[source]

Returns the output of the discriminator.

Parameters
  • data (dict) –

    • images_a (tensor) : Images in domain A.

    • images_b (tensor) : Images in domain B.

  • net_G_output (dict) –

    • images_ab (tensor) : Images translated from domain A to B by the generator.

    • images_ba (tensor) : Images translated from domain B to A by the generator.

    • images_aa (tensor) : Reconstructed images in domain A.

    • images_bb (tensor) : Reconstructed images in domain B.

  • gan_recon (bool) – If True, also classifies reconstructed images.

  • real (bool) – If True, also classifies real images. Otherwise it only classifies generated images to save computation during the generator update.

Returns

  • out_ab (tensor): Output of the discriminator for images translated from domain A to B by the generator.

  • out_ba (tensor): Output of the discriminator for images translated from domain B to A by the generator.

  • fea_ab (tensor): Intermediate features of the discriminator for images translated from domain A to B by the generator.

  • fea_ba (tensor): Intermediate features of the discriminator for images translated from domain B to A by the generator.

  • out_a (tensor): Output of the discriminator for images in domain A.

  • out_b (tensor): Output of the discriminator for images in domain B.

  • fea_a (tensor): Intermediate features of the discriminator for images in domain A.

  • fea_b (tensor): Intermediate features of the discriminator for images in domain B.

  • out_aa (tensor): Output of the discriminator for reconstructed images in domain A.

  • out_bb (tensor): Output of the discriminator for reconstructed images in domain B.

  • fea_aa (tensor): Intermediate features of the discriminator for reconstructed images in domain A.

  • fea_bb (tensor): Intermediate features of the discriminator for reconstructed images in domain B.

Return type

(dict)

training = None

Module contents