imaginaire.discriminators package¶
Submodules¶
imaginaire.discriminators.dummy module¶
-
class
imaginaire.discriminators.dummy.
Discriminator
(dis_cfg, data_cfg)[source]¶ Bases:
torch.nn.modules.module.Module
Dummy Discriminator constructor.
- Parameters
dis_cfg (obj) – Discriminator definition part of the yaml config file.
data_cfg (obj) – Data definition part of the yaml config file
-
training
= None¶
imaginaire.discriminators.fpse module¶
-
class
imaginaire.discriminators.fpse.
FPSEDiscriminator
(num_input_channels, num_labels, num_filters, kernel_size, weight_norm_type, activation_norm_type)[source]¶ Bases:
torch.nn.modules.module.Module
Feature-Pyramid Semantics Embedding Discriminator. This is a copy of the discriminator in https://arxiv.org/pdf/1910.06809.pdf
-
forward
(images, segmaps)[source]¶ - Parameters
images – image tensors.
segmaps – segmentation map tensors.
-
training
= None¶
-
imaginaire.discriminators.fs_vid2vid module¶
-
class
imaginaire.discriminators.fs_vid2vid.
Discriminator
(dis_cfg, data_cfg)[source]¶ Bases:
torch.nn.modules.module.Module
Image and video discriminator constructor.
- Parameters
dis_cfg (obj) – Discriminator part of the yaml config file.
data_cfg (obj) – Data definition part of the yaml config file
-
discriminate_video
(real_B, fake_B, scale)[source]¶ Discriminate a sequence of images.
- Parameters
real_B (NxCxHxW tensor) – Real image.
fake_B (NxCxHxW tensor) – Fake image.
scale (int) – Temporal scale.
- Returns
pred_real (NxC2xH2xW2 tensor): Output of net_D for real images.
pred_fake (NxC2xH2xW2 tensor): Output of net_D for fake images.
- Return type
(tuple)
-
discrminate_image
(net_D, real_A, real_B, fake_B)[source]¶ Discriminate individual images.
- Parameters
net_D (obj) – Discriminator network.
real_A (NxC1xHxW tensor) – Input label map.
real_B (NxC2xHxW tensor) – Real image.
fake_B (NxC2xHxW tensor) – Fake image.
- Returns
pred_real (NxC3xH2xW2 tensor): Output of net_D for real images.
pred_fake (NxC3xH2xW2 tensor): Output of net_D for fake images.
- Return type
(tuple)
-
forward
(data, net_G_output, past_frames)[source]¶ Discriminator forward.
- Parameters
data (dict) – Input data.
net_G_output (dict) – Generator output.
past_frames (list of tensors) – Past real frames / generator outputs.
- Returns
output (dict): Discriminator output.
past_frames (list of tensors): New past frames by adding current outputs.
- Return type
(tuple)
-
training
= None¶
-
class
imaginaire.discriminators.fs_vid2vid.
MultiPatchDiscriminator
(dis_cfg, num_input_channels)[source]¶ Bases:
torch.nn.modules.module.Module
Multi-resolution patch discriminator.
- Parameters
dis_cfg (obj) – Discriminator part of the yaml config file.
num_input_channels (int) – Number of input channels.
-
forward
(input_x)[source]¶ Multi-resolution patch discriminator forward.
- Parameters
input_x (N x C x H x W tensor) – Concatenation of images and semantic representations.
- Returns
output (list): list of output tensors produced by individual patch discriminators.
features (list): list of lists of features produced by individual patch discriminators.
- Return type
(dict)
-
training
= None¶
-
imaginaire.discriminators.fs_vid2vid.
get_all_skipped_frames
(past_frames, new_frames, t_scales, tD)[source]¶ Get temporally skipped frames from the input frames.
- Parameters
past_frames (list of tensors) – Past real frames / generator outputs.
new_frames (list of tensors) – Current real frame / generated output.
t_scales (int) – Temporal scale.
tD (int) – Number of frames as input to the temporal discriminator.
- Returns
new_past_frames (list of tensors): Past + current frames.
skipped_frames (list of tensors): Temporally skipped frames using the given t_scales.
- Return type
(tuple)
-
imaginaire.discriminators.fs_vid2vid.
get_skipped_frames
(all_frames, frame, t_scales, tD)[source]¶ Get temporally skipped frames from the input frames.
- Parameters
all_frames (NxTxCxHxW tensor) – All past frames.
frame (Nx1xCxHxW tensor) – Current frame.
t_scales (int) – Temporal scale.
tD (int) – Number of frames as input to the temporal discriminator.
- Returns
all_frames (NxTxCxHxW tensor): Past + current frames.
skipped_frames (list of NxTxCxHxW tensors): Temporally skipped frames.
- Return type
(tuple)
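The temporal skipping described above can be sketched on plain frame indices. This is a minimal illustration, not the library's implementation: the helper name skipped_frame_indices is hypothetical, and it assumes that scale s samples tD frames spaced tD**s apart, ending at the newest frame, and that a scale yields None until enough frames have accumulated.

```python
def skipped_frame_indices(num_frames, t_scales, tD):
    """Return, per temporal scale, the frame indices a temporal discriminator would see.

    Assumption (hypothetical, not confirmed by the reference text): scale s
    samples tD frames spaced tD**s apart, ending at the newest frame.
    """
    indices = list(range(num_frames))
    picks = []
    for s in range(t_scales):
        step = tD ** s              # spacing between sampled frames at this scale
        span = step * (tD - 1)      # distance between first and last sampled frame
        if num_frames > span:
            picks.append(indices[-(span + 1)::step])
        else:
            picks.append(None)      # not enough history yet for this scale
    return picks


# With 10 frames and tD = 3, the two finest scales are populated.
picks = skipped_frame_indices(num_frames=10, t_scales=3, tD=3)
```

Under these assumptions, coarser scales need more history before they produce any input, which is why past_frames is carried between calls.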
imaginaire.discriminators.funit module¶
-
class
imaginaire.discriminators.funit.
Discriminator
(dis_cfg, data_cfg)[source]¶ Bases:
torch.nn.modules.module.Module
Discriminator in the improved FUNIT baseline in the COCO-FUNIT paper.
- Parameters
dis_cfg (obj) – Discriminator definition part of the yaml config file.
data_cfg (obj) – Data definition part of the yaml config file.
-
forward
(data, net_G_output, recon=True)[source]¶ Improved FUNIT discriminator forward function.
- Parameters
data (dict) – Training data at the current iteration.
net_G_output (dict) – Fake data generated at the current iteration.
recon (bool) – If True, also classifies reconstructed images.
-
training
= None¶
-
class
imaginaire.discriminators.funit.
ResDiscriminator
(image_channels=3, num_classes=119, num_filters=64, max_num_filters=1024, num_layers=6, padding_mode='reflect', weight_norm_type='', **kwargs)[source]¶ Bases:
torch.nn.modules.module.Module
Residual discriminator architecture used in the FUNIT paper.
-
forward
(images, labels=None)[source]¶ Forward function of the projection discriminator.
- Parameters
images (image tensor) – Images input to the discriminator.
labels (long int tensor) – Class labels of the images.
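As a rough illustration of what a projection discriminator computes, the sketch below adds a class-conditional term (the inner product of a class embedding with the pooled features) to an unconditional linear score. The function name, the 2-dimensional features, and the embedding table are all hypothetical, and plain Python lists stand in for tensors; how this class wires the terms together is not stated in the text above.

```python
import math


def projection_score(features, weight, bias, class_embedding):
    """Projection-style score: linear(features) + <class_embedding, features>."""
    unconditional = sum(w * f for w, f in zip(weight, features)) + bias
    conditional = sum(e * f for e, f in zip(class_embedding, features))
    return unconditional + conditional


# Hypothetical pooled feature vector and a 2-class embedding table.
features = [1.0, 2.0]
embeddings = {0: [1.0, 0.0], 1: [0.0, 1.0]}
score = projection_score(features, weight=[0.5, -1.0], bias=0.1,
                         class_embedding=embeddings[1])
```

The class label only selects which embedding is projected onto the features, so the same feature extractor serves every class.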
-
training
= None¶
-
imaginaire.discriminators.gancraft module¶
-
class
imaginaire.discriminators.gancraft.
Discriminator
(dis_cfg, data_cfg)[source]¶ Bases:
torch.nn.modules.module.Module
Multi-resolution patch discriminator. Based on FPSE discriminator but with N+1 labels.
- Parameters
dis_cfg (obj) – Discriminator definition part of the yaml config file.
data_cfg (obj) – Data definition part of the yaml config file.
-
forward
(data, net_G_output, weights=None, incl_real=False, incl_pseudo_real=False)[source]¶ GANcraft discriminator forward.
- Parameters
data (dict) –
data (N x C1 x H x W tensor) : Ground truth images.
label (N x C2 x H x W tensor) : Semantic representations.
z (N x style_dims tensor): Gaussian random noise.
net_G_output (dict) –
fake_images (N x C1 x H x W tensor) : Fake images.
- Returns
real_outputs (list): list of output tensors produced by individual patch discriminators for real images.
real_features (list): list of lists of features produced by individual patch discriminators for real images.
fake_outputs (list): list of output tensors produced by individual patch discriminators for fake images.
fake_features (list): list of lists of features produced by individual patch discriminators for fake images.
- Return type
output_x (dict)
-
training
= None¶
-
class
imaginaire.discriminators.gancraft.
FPSEDiscriminator
(num_input_channels, num_labels, num_filters, kernel_size, weight_norm_type, activation_norm_type, do_multiscale, smooth_resample, no_label_except_largest_scale)[source]¶ Bases:
torch.nn.modules.module.Module
-
forward
(images, segmaps, weights=None)[source]¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
-
static
smooth_interp
(x, size)[source]¶ Smooth interpolation of segmentation maps.
- Parameters
x (4D tensor) – Segmentation maps.
size (2D list) – Target size (H, W).
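One plausible reading of "smooth interpolation" for one-hot segmentation maps is area averaging followed by re-binarizing to the dominant class. The sketch below illustrates that reading only; it is an assumption, not the method's confirmed behavior. It uses nested lists for a single C x H x W map, the name smooth_interp_sketch is hypothetical, and it assumes the source size is an integer multiple of the target size.

```python
def smooth_interp_sketch(x, size):
    """Downsample a one-hot map x (C x H x W nested lists) to size (h, w).

    Assumption: H and W are integer multiples of h and w.
    """
    num_classes, height, width = len(x), len(x[0]), len(x[0][0])
    out_h, out_w = size
    block_h, block_w = height // out_h, width // out_w
    out = [[[0.0] * out_w for _ in range(out_h)] for _ in range(num_classes)]
    for i in range(out_h):
        for j in range(out_w):
            # Area averaging: mean of each class channel over the source block.
            means = [
                sum(x[c][i * block_h + di][j * block_w + dj]
                    for di in range(block_h) for dj in range(block_w))
                / (block_h * block_w)
                for c in range(num_classes)
            ]
            # Re-one-hot: keep only the dominant class for this output pixel.
            winner = max(range(num_classes), key=lambda c: means[c])
            out[winner][i][j] = 1.0
    return out
```

Compared with nearest-neighbor resizing, this keeps the output a valid one-hot map while letting every source pixel vote on the result.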
-
training
= None¶
-
imaginaire.discriminators.mlp_multiclass module¶
-
class
imaginaire.discriminators.mlp_multiclass.
Discriminator
(dis_cfg, data_cfg)[source]¶ Bases:
torch.nn.modules.module.Module
Multi-layer Perceptron Classifier constructor.
- Parameters
dis_cfg (obj) – Discriminator definition part of the yaml config file.
data_cfg (obj) – Data definition part of the yaml config file
-
forward
(data)[source]¶ MLP classifier forward.
- Parameters
data (dict) –
data (N x -1 tensor): We will reshape the tensor to this format.
- Returns
results (N x C tensor): Output scores before softmax.
- Return type
(dict)
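Two details of this forward pass can be sketched in plain Python: each sample is flattened to a single vector (the "N x -1" reshape), and the returned scores are pre-softmax, so the caller applies softmax to obtain class probabilities. The helpers flatten and softmax below are hypothetical stand-ins operating on lists rather than tensors.

```python
import math


def flatten(sample):
    """Flatten an arbitrarily nested list into one flat list of numbers."""
    flat = []
    for item in sample:
        if isinstance(item, list):
            flat.extend(flatten(item))
        else:
            flat.append(item)
    return flat


def softmax(scores):
    """Convert raw scores to probabilities (max subtracted for stability)."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# A hypothetical 2 x 2 sample becomes a length-4 input vector.
vector = flatten([[1, 2], [3, 4]])
# Hypothetical raw scores for three classes, turned into probabilities.
probabilities = softmax([1.0, 2.0, 3.0])
```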
-
training
= None¶
imaginaire.discriminators.multires_patch module¶
-
class
imaginaire.discriminators.multires_patch.
Discriminator
(dis_cfg, data_cfg)[source]¶ Bases:
torch.nn.modules.module.Module
Multi-resolution patch discriminator.
- Parameters
dis_cfg (obj) – Discriminator definition part of the yaml config file.
data_cfg (obj) – Data definition part of the yaml config file.
-
forward
(data, net_G_output, real=True)[source]¶ Multi-resolution patch discriminator forward.
- Parameters
data (dict) –
data (N x C1 x H x W tensor) : Ground truth images.
label (N x C2 x H x W tensor) : Semantic representations.
z (N x style_dims tensor): Gaussian random noise.
net_G_output (dict) – fake_images (N x C1 x H x W tensor) : Fake images.
real (bool) – If True, also classifies real images. Otherwise it only classifies generated images to save computation during the generator update.
- Returns
real_outputs (list): list of output tensors produced by individual patch discriminators for real images.
real_features (list): list of lists of features produced by individual patch discriminators for real images.
fake_outputs (list): list of output tensors produced by individual patch discriminators for fake images.
fake_features (list): list of lists of features produced by individual patch discriminators for fake images.
- Return type
(tuple)
-
training
= None¶
-
class
imaginaire.discriminators.multires_patch.
MultiResPatchDiscriminator
(num_discriminators=3, kernel_size=3, num_image_channels=3, num_filters=64, num_layers=4, max_num_filters=512, activation_norm_type='', weight_norm_type='', **kwargs)[source]¶ Bases:
torch.nn.modules.module.Module
Multi-resolution patch discriminator.
- Parameters
num_discriminators (int) – Num. of discriminators (one per scale).
kernel_size (int) – Convolution kernel size.
num_image_channels (int) – Num. of channels in the real/fake image.
num_filters (int) – Num. of base filters in a layer.
num_layers (int) – Num. of layers for the patch discriminator.
max_num_filters (int) – Maximum num. of filters in a layer.
activation_norm_type (str) – batch_norm/instance_norm/none/….
weight_norm_type (str) – none/spectral_norm/weight_norm
-
forward
(input_x)[source]¶ Multi-resolution patch discriminator forward.
- Parameters
input_x (tensor) – Input images.
- Returns
output_list (list): list of output tensors produced by individual patch discriminators.
features_list (list): list of lists of features produced by individual patch discriminators.
input_list (list): list of downsampled input images.
- Return type
(tuple)
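To make the returned input_list concrete, the sketch below computes the spatial size each discriminator would receive. It assumes the input is halved between consecutive scales, which is a common choice but is not stated in the text above; the helper name multires_input_sizes is hypothetical.

```python
def multires_input_sizes(height, width, num_discriminators=3):
    """List the (H, W) seen by each discriminator, finest scale first.

    Assumption: each scale halves the previous one (integer division).
    """
    sizes = []
    for _ in range(num_discriminators):
        sizes.append((height, width))
        height, width = height // 2, width // 2
    return sizes


# One entry per discriminator, matching the length of output_list.
sizes = multires_input_sizes(256, 512)
```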
-
training
= None¶
-
class
imaginaire.discriminators.multires_patch.
NLayerPatchDiscriminator
(kernel_size, num_input_channels, num_filters, num_layers, max_num_filters, activation_norm_type, weight_norm_type)[source]¶ Bases:
torch.nn.modules.module.Module
Patch Discriminator constructor.
- Parameters
kernel_size (int) – Convolution kernel size.
num_input_channels (int) – Num. of channels in the real/fake image.
num_filters (int) – Num. of base filters in a layer.
num_layers (int) – Num. of layers for the patch discriminator.
max_num_filters (int) – Maximum num. of filters in a layer.
activation_norm_type (str) – batch_norm/instance_norm/none/….
weight_norm_type (str) – none/spectral_norm/weight_norm
-
forward
(input_x)[source]¶ Patch Discriminator forward.
- Parameters
input_x (N x C x H1 x W1 tensor) – Concatenation of images and semantic representations.
- Returns
output (N x 1 x H2 x W2 tensor): Discriminator output value. Before the sigmoid when using NSGAN.
features (list): list of intermediate activation tensors.
- Return type
(tuple)
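The interaction between num_filters, num_layers, and max_num_filters can be sketched as a per-layer channel schedule. The doubling rule below is an assumption (a common pattern for patch discriminators), not something the parameter descriptions confirm, and filter_schedule is a hypothetical helper.

```python
def filter_schedule(num_filters, num_layers, max_num_filters):
    """Channel count per layer, assuming filters double until capped."""
    return [min(num_filters * 2 ** i, max_num_filters)
            for i in range(num_layers)]


# Hypothetical settings: 64 base filters, 4 layers, capped at 512.
channels = filter_schedule(num_filters=64, num_layers=4, max_num_filters=512)
```

Under this assumption, max_num_filters only matters once num_filters * 2**i exceeds it, so deeper discriminators plateau rather than grow without bound.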
-
training
= None¶
Bases:
torch.nn.modules.module.Module
Multi-resolution patch discriminator with shared weights.
- Parameters
num_discriminators (int) – Num. of discriminators (one per scale).
kernel_size (int) – Convolution kernel size.
num_image_channels (int) – Num. of channels in the real/fake image.
num_filters (int) – Num. of base filters in a layer.
num_layers (int) – Num. of layers for the patch discriminator.
max_num_filters (int) – Maximum num. of filters in a layer.
activation_norm_type (str) – batch_norm/instance_norm/none/….
weight_norm_type (str) – none/spectral_norm/weight_norm
Multi-resolution patch discriminator forward.
- Parameters
input_x (tensor) – Input images.
- Returns
output_list (list): list of output tensors produced by individual patch discriminators.
features_list (list): list of lists of features produced by individual patch discriminators.
input_list (list): list of downsampled input images.
- Return type
(tuple)
imaginaire.discriminators.munit module¶
-
class
imaginaire.discriminators.munit.
Discriminator
(dis_cfg, data_cfg)[source]¶ Bases:
torch.nn.modules.module.Module
MUNIT discriminator. It can be either a multi-resolution patch discriminator like in the original implementation, or a global residual discriminator.
- Parameters
dis_cfg (obj) – Discriminator definition part of the yaml config file.
data_cfg (obj) – Data definition part of the yaml config file
-
forward
(data, net_G_output, gan_recon=False, real=True)[source]¶ Returns the output of the discriminator.
- Parameters
data (dict) –
images_a (tensor) : Images in domain A.
images_b (tensor) : Images in domain B.
net_G_output (dict) –
images_ab (tensor) : Images translated from domain A to B by the generator.
images_ba (tensor) : Images translated from domain B to A by the generator.
images_aa (tensor) : Reconstructed images in domain A.
images_bb (tensor) : Reconstructed images in domain B.
gan_recon (bool) – If True, also classifies reconstructed images.
real (bool) – If True, also classifies real images. Otherwise it only classifies generated images to save computation during the generator update.
- Returns
out_ab (tensor): Output of the discriminator for images translated from domain A to B by the generator.
out_ba (tensor): Output of the discriminator for images translated from domain B to A by the generator.
fea_ab (tensor): Intermediate features of the discriminator for images translated from domain A to B by the generator.
fea_ba (tensor): Intermediate features of the discriminator for images translated from domain B to A by the generator.
out_a (tensor): Output of the discriminator for images in domain A.
out_b (tensor): Output of the discriminator for images in domain B.
fea_a (tensor): Intermediate features of the discriminator for images in domain A.
fea_b (tensor): Intermediate features of the discriminator for images in domain B.
out_aa (tensor): Output of the discriminator for reconstructed images in domain A.
out_bb (tensor): Output of the discriminator for reconstructed images in domain B.
fea_aa (tensor): Intermediate features of the discriminator for reconstructed images in domain A.
fea_bb (tensor): Intermediate features of the discriminator for reconstructed images in domain B.
- Return type
(dict)
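The two flags control which entries of the returned dict are populated. The sketch below lists the keys a caller could expect for a given flag combination, following the descriptions above. The helper expected_output_keys is hypothetical, and it assumes the B-to-A entries are named out_ba and fea_ba, mirroring the naming of the generator outputs.

```python
def expected_output_keys(gan_recon=False, real=True):
    """Keys expected in the discriminator output for the given flags."""
    # Translated images are always classified.
    keys = ['out_ab', 'out_ba', 'fea_ab', 'fea_ba']
    if real:
        # Real images are skipped during the generator update (real=False).
        keys += ['out_a', 'out_b', 'fea_a', 'fea_b']
    if gan_recon:
        # Reconstructions are classified only when requested.
        keys += ['out_aa', 'out_bb', 'fea_aa', 'fea_bb']
    return keys


# Hypothetical generator-update call: fakes and reconstructions, no real images.
generator_update_keys = expected_output_keys(gan_recon=True, real=False)
```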
-
training
= None¶
imaginaire.discriminators.residual module¶
-
class
imaginaire.discriminators.residual.
ResDiscriminator
(image_channels=3, num_filters=64, max_num_filters=512, first_kernel_size=1, num_layers=4, padding_mode='zeros', activation_norm_type='', weight_norm_type='', aggregation='conv', order='pre_act', anti_aliased=False, **kwargs)[source]¶ Bases:
torch.nn.modules.module.Module
Global residual discriminator.
- Parameters
image_channels (int) – Num. of channels in the real/fake image.
num_filters (int) – Num. of base filters in a layer.
max_num_filters (int) – Maximum num. of filters in a layer.
first_kernel_size (int) – Kernel size in the first layer.
num_layers (int) – Num. of layers in discriminator.
padding_mode (str) – Padding mode.
activation_norm_type (str) – Type of activation normalization. 'none', 'instance', 'batch', 'sync_batch'.
weight_norm_type (str) – Type of weight normalization. 'none', 'spectral', or 'weight'.
aggregation (str) – Method to aggregate features across different locations in the final layer. 'conv', or 'pool'.
order (str) – Order of operations in the residual link.
anti_aliased (bool) – If True, uses anti-aliased pooling.
-
forward
(images)[source]¶ Global residual discriminator forward.
- Parameters
images (tensor) – Input images.
- Returns
outputs (tensor): Output of the discriminator.
features (tensor): Intermediate features of the discriminator.
images (tensor): Input images.
- Return type
(tuple)
-
training
= None¶
imaginaire.discriminators.spade module¶
-
class
imaginaire.discriminators.spade.
Discriminator
(dis_cfg, data_cfg)[source]¶ Bases:
torch.nn.modules.module.Module
Multi-resolution patch discriminator.
- Parameters
dis_cfg (obj) – Discriminator definition part of the yaml config file.
data_cfg (obj) – Data definition part of the yaml config file.
-
forward
(data, net_G_output)[source]¶ SPADE discriminator forward.
- Parameters
data (dict) –
data (N x C1 x H x W tensor) : Ground truth images.
label (N x C2 x H x W tensor) : Semantic representations.
z (N x style_dims tensor): Gaussian random noise.
net_G_output (dict) – fake_images (N x C1 x H x W tensor) : Fake images.
- Returns
real_outputs (list): list of output tensors produced by individual patch discriminators for real images.
real_features (list): list of lists of features produced by individual patch discriminators for real images.
fake_outputs (list): list of output tensors produced by individual patch discriminators for fake images.
fake_features (list): list of lists of features produced by individual patch discriminators for fake images.
- Return type
(dict)
-
training
= None¶
imaginaire.discriminators.unit module¶
-
class
imaginaire.discriminators.unit.
Discriminator
(dis_cfg, data_cfg)[source]¶ Bases:
torch.nn.modules.module.Module
UNIT discriminator. It can be either a multi-resolution patch discriminator like in the original implementation, or a global residual discriminator.
- Parameters
dis_cfg (obj) – Discriminator definition part of the yaml config file.
data_cfg (obj) – Data definition part of the yaml config file
-
forward
(data, net_G_output, gan_recon=False, real=True)[source]¶ Returns the output of the discriminator.
- Parameters
data (dict) –
images_a (tensor) : Images in domain A.
images_b (tensor) : Images in domain B.
net_G_output (dict) –
images_ab (tensor) : Images translated from domain A to B by the generator.
images_ba (tensor) : Images translated from domain B to A by the generator.
images_aa (tensor) : Reconstructed images in domain A.
images_bb (tensor) : Reconstructed images in domain B.
gan_recon (bool) – If True, also classifies reconstructed images.
real (bool) – If True, also classifies real images. Otherwise it only classifies generated images to save computation during the generator update.
- Returns
out_ab (tensor): Output of the discriminator for images translated from domain A to B by the generator.
out_ba (tensor): Output of the discriminator for images translated from domain B to A by the generator.
fea_ab (tensor): Intermediate features of the discriminator for images translated from domain A to B by the generator.
fea_ba (tensor): Intermediate features of the discriminator for images translated from domain B to A by the generator.
out_a (tensor): Output of the discriminator for images in domain A.
out_b (tensor): Output of the discriminator for images in domain B.
fea_a (tensor): Intermediate features of the discriminator for images in domain A.
fea_b (tensor): Intermediate features of the discriminator for images in domain B.
out_aa (tensor): Output of the discriminator for reconstructed images in domain A.
out_bb (tensor): Output of the discriminator for reconstructed images in domain B.
fea_aa (tensor): Intermediate features of the discriminator for reconstructed images in domain A.
fea_bb (tensor): Intermediate features of the discriminator for reconstructed images in domain B.
- Return type
(dict)
-
training
= None¶