Source code for imaginaire.evaluation.common

# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import math
import os
from functools import partial
import torch
import torch.distributed as dist
from torch import nn
from torch.nn import functional as F
from torchvision.models import inception_v3
from cleanfid.features import feature_extractor
from cleanfid.resize import build_resizer

from imaginaire.evaluation.lpips import get_lpips_model
from imaginaire.evaluation.segmentation import get_segmentation_hist_model, get_miou
from imaginaire.evaluation.caption import get_image_encoder, get_r_precision
from imaginaire.evaluation.pretrained import TFInceptionV3, InceptionV3, Vgg16, SwAV
from imaginaire.utils.distributed import (dist_all_gather_tensor, get_rank,
                                          get_world_size, is_master,
                                          is_local_master)
from imaginaire.utils.distributed import master_only_print
from imaginaire.utils.misc import apply_imagenet_normalization, to_cuda


[docs]@torch.no_grad() def compute_all_metrics(act_dir, data_loader, net_G, key_real='images', key_fake='fake_images', sample_size=None, preprocess=None, is_video=False, few_shot_video=False, kid_num_subsets=1, kid_subset_size=None, key_prefix='', prdc_k=5, metrics=None, dataset_name='', aws_credentials=None, **kwargs): r""" Args: act_dir (string): Path to a directory to temporarily save feature activations. data_loader (obj): PyTorch dataloader object. net_G (obj): The generator module. key_real (str): Dictionary key value for the real data. key_fake (str): Dictionary key value for the fake data. sample_size (int or None): How many samples to use for FID. preprocess (func or None): Pre-processing function to use. is_video (bool): Whether we are handling video sequences. few_shot_video (bool): If ``True``, uses few-shot video synthesis. kid_num_subsets (int): Number of subsets for KID evaluation. kid_subset_size (int or None): The number of samples in each subset for KID evaluation. key_prefix (string): Add this string before all keys of the output dictionary. prdc_k (int): The K used for computing K-NN when evaluating precision/recall/density/coverage. metrics (list of strings): Which metrics we want to evaluate. dataset_name (string): The name of the dataset, currently only used to determine which segmentation network to use for segmentation evaluation. Returns: batch_y (tensor): Inception features of the current batch. Note that only the master gpu will get it. """ from imaginaire.evaluation.fid import _calculate_frechet_distance from imaginaire.evaluation.kid import _polynomial_mmd_averages from imaginaire.evaluation.prdc import _get_prdc from imaginaire.evaluation.msid import _get_msid from imaginaire.evaluation.knn import _get_1nn_acc if metrics is None: metrics = [] act_path = os.path.join(act_dir, 'activations_real.pt') # Get feature activations and other outputs computed from fake images. output_module_dict = nn.ModuleDict() if "seg_mIOU" in metrics: output_module_dict["seg_mIOU"] = get_segmentation_hist_model(dataset_name, aws_credentials) if "caption_rprec" in metrics: output_module_dict["caption_rprec"] = get_image_encoder(aws_credentials) if "LPIPS" in metrics: output_module_dict["LPIPS"] = get_lpips_model() fake_outputs = get_outputs( data_loader, key_real, key_fake, net_G, sample_size, preprocess, output_module_dict=output_module_dict, **kwargs ) fake_act = fake_outputs["activations"] # Get feature activations computed from real images. real_act = load_or_compute_activations( act_path, data_loader, key_real, key_fake, None, sample_size, preprocess, is_video=is_video, few_shot_video=few_shot_video, **kwargs ) metrics_from_activations = { "1NN": _get_1nn_acc, "MSID": _get_msid, "FID": _calculate_frechet_distance, "KID": partial(_polynomial_mmd_averages, n_subsets=kid_num_subsets, subset_size=kid_subset_size, ret_var=True), "PRDC": partial(_get_prdc, nearest_k=prdc_k) } other_metrics = { "seg_mIOU": get_miou, "caption_rprec": get_r_precision, "LPIPS": lambda x: {"LPIPS": torch.mean(x).item()} } all_metrics = {} if is_master(): for metric in metrics: if metric in metrics_from_activations: metric_function = metrics_from_activations[metric] metric_dict = metric_function(real_act, fake_act) elif metric in other_metrics: metric_function = other_metrics[metric] if fake_outputs[metric] is not None: metric_dict = metric_function(fake_outputs[metric]) else: print(f"{metric} is not implemented!") raise NotImplementedError for k, v in metric_dict.items(): all_metrics.update({key_prefix + k: v}) if dist.is_initialized(): dist.barrier() return all_metrics
[docs]@torch.no_grad() def compute_all_metrics_data(data_loader_a, data_loader_b, key_a='images', key_b='images', sample_size=None, preprocess=None, kid_num_subsets=1, kid_subset_size=None, key_prefix='', prdc_k=5, metrics=None, dataset_name='', aws_credentials=None, **kwargs): r""" Args: act_dir (string): Path to a directory to temporarily save feature activations. data_loader (obj): PyTorch dataloader object. net_G (obj): The generator module. key_a (str): Dictionary key value for the real data. key_b (str): Dictionary key value for the fake data. sample_size (int or None): How many samples to use for FID. preprocess (func or None): Pre-processing function to use. is_video (bool): Whether we are handling video sequences. few_shot_video (bool): If ``True``, uses few-shot video synthesis. kid_num_subsets (int): Number of subsets for KID evaluation. kid_subset_size (int or None): The number of samples in each subset for KID evaluation. key_prefix (string): Add this string before all keys of the output dictionary. prdc_k (int): The K used for computing K-NN when evaluating precision/recall/density/coverage. metrics (list of strings): Which metrics we want to evaluate. dataset_name (string): The name of the dataset, currently only used to determine which segmentation network to use for segmentation evaluation. Returns: batch_y (tensor): Inception features of the current batch. Note that only the master gpu will get it. """ from imaginaire.evaluation.fid import _calculate_frechet_distance from imaginaire.evaluation.kid import _polynomial_mmd_averages from imaginaire.evaluation.prdc import _get_prdc from imaginaire.evaluation.msid import _get_msid from imaginaire.evaluation.knn import _get_1nn_acc if metrics is None: metrics = [] min_data_size = min(len(data_loader_a.dataset), len(data_loader_b.dataset)) if sample_size is None: sample_size = min_data_size else: sample_size = min(sample_size, min_data_size) # Get feature activations and other outputs computed from fake images. output_module_dict = nn.ModuleDict() if "seg_mIOU" in metrics: output_module_dict["seg_mIOU"] = get_segmentation_hist_model(dataset_name, aws_credentials) if "caption_rprec" in metrics: output_module_dict["caption_rprec"] = get_image_encoder(aws_credentials) if "LPIPS" in metrics: output_module_dict["LPIPS"] = get_lpips_model() fake_outputs = get_outputs( data_loader_b, key_a, key_b, None, sample_size, preprocess, output_module_dict=output_module_dict, **kwargs ) act_b = fake_outputs["activations"] act_a = load_or_compute_activations( None, data_loader_a, key_a, key_b, None, sample_size, preprocess, output_module_dict=output_module_dict, **kwargs ) # act_b = load_or_compute_activations( # None, data_loader_b, key_a, key_b, None, sample_size, preprocess, # output_module_dict=output_module_dict, generate_twice=generate_twice, **kwargs # ) metrics_from_activations = { "1NN": _get_1nn_acc, "MSID": _get_msid, "FID": _calculate_frechet_distance, "KID": partial(_polynomial_mmd_averages, n_subsets=kid_num_subsets, subset_size=kid_subset_size, ret_var=True), "PRDC": partial(_get_prdc, nearest_k=prdc_k) } other_metrics = { "seg_mIOU": get_miou, "caption_rprec": get_r_precision, "LPIPS": lambda x: {"LPIPS": torch.mean(x).item()} } all_metrics = {} if is_master(): for metric in metrics: if metric in metrics_from_activations: metric_function = metrics_from_activations[metric] metric_dict = metric_function(act_a, act_b) elif metric in other_metrics: metric_function = other_metrics[metric] if fake_outputs[metric] is not None: metric_dict = metric_function(fake_outputs[metric]) else: print(f"{metric} is not implemented!") raise NotImplementedError for k, v in metric_dict.items(): all_metrics.update({key_prefix + k: v}) if dist.is_initialized(): dist.barrier() return all_metrics
[docs]@torch.no_grad() def get_activations(data_loader, key_real, key_fake, generator=None, sample_size=None, preprocess=None, align_corners=True, network='inception', **kwargs): r"""Compute activation values and pack them in a list. Args: data_loader (obj): PyTorch dataloader object. key_real (str): Dictionary key value for the real data. key_fake (str): Dictionary key value for the fake data. generator (obj): PyTorch trainer network. sample_size (int): How many samples to use for FID. preprocess (func): Pre-processing function to use. align_corners (bool): The ``'align_corners'`` parameter to be used for `torch.nn.functional.interpolate`. Returns: batch_y (tensor): Inception features of the current batch. Note that only the master gpu will get it. """ if dist.is_initialized() and not is_local_master(): # Make sure only the first process in distributed training downloads # the model, and the others will use the cache # noinspection PyUnresolvedReferences torch.distributed.barrier() if network == 'tf_inception': model = TFInceptionV3() elif network == 'inception': model = InceptionV3() elif network == 'vgg16': model = Vgg16() elif network == 'swav': model = SwAV() elif network == 'clean_inception': model = CleanInceptionV3() else: raise NotImplementedError(f'Network "{network}" is not supported!') if dist.is_initialized() and is_local_master(): # Make sure only the first process in distributed training downloads # the model, and the others will use the cache # noinspection PyUnresolvedReferences dist.barrier() model = model.to('cuda').eval() world_size = get_world_size() batch_y = [] # Iterate through the dataset to compute the activation. for it, data in enumerate(data_loader): data = to_cuda(data) # Preprocess the data. if preprocess is not None: data = preprocess(data) # Load real data if the generator is not specified. if generator is None: images = data[key_real] else: # Compute the generated image. net_G_output = generator(data, **kwargs) images = net_G_output[key_fake] # Clamp the image for models that do not set the output to between # -1, 1. For models that employ tanh, this has no effect. images.clamp_(-1, 1) y = model(images, align_corners=align_corners) batch_y.append(y) if sample_size is not None and \ data_loader.batch_size * world_size * (it + 1) >= sample_size: # Reach the number of samples we need. break batch_y = torch.cat(dist_all_gather_tensor(torch.cat(batch_y))) if sample_size is not None: batch_y = batch_y[:sample_size] print(f"Computed feature activations of size {batch_y.shape}") return batch_y
[docs]class CleanInceptionV3(nn.Module): def __init__(self): super().__init__() self.model = feature_extractor(name="torchscript_inception", resize_inside=False)
[docs] def forward(self, img_batch, transform=True, **_kwargs): if transform: # Assume the input is (-1, 1). We transform it to (0, 255) and round it to the closest integer. img_batch = torch.round(255 * (0.5 * img_batch + 0.5)) resized_batch = clean_resize(img_batch) return self.model(resized_batch)
[docs]def clean_resize(img_batch): # Resize images from arbitrary resolutions to 299x299. batch_size = img_batch.size(0) img_batch = img_batch.cpu().numpy() fn_resize = build_resizer('clean') resized_batch = torch.zeros(batch_size, 3, 299, 299, device='cuda') for idx in range(batch_size): curr_img = img_batch[idx] img_np = curr_img.transpose((1, 2, 0)) img_resize = fn_resize(img_np) resized_batch[idx] = torch.tensor(img_resize.transpose((2, 0, 1)), device='cuda') resized_batch = resized_batch.cuda() return resized_batch
[docs]@torch.no_grad() def get_outputs(data_loader, key_real, key_fake, generator=None, sample_size=None, preprocess=None, align_corners=True, network='inception', output_module_dict=None, **kwargs): r"""Compute activation values and pack them in a list. Args: data_loader (obj): PyTorch dataloader object. key_real (str): Dictionary key value for the real data. key_fake (str): Dictionary key value for the fake data. generator (obj): PyTorch trainer network. sample_size (int): How many samples to use for FID. preprocess (func): Pre-processing function to use. align_corners (bool): The ``'align_corners'`` parameter to be used for `torch.nn.functional.interpolate`. Returns: batch_y (tensor): Inception features of the current batch. Note that only the master gpu will get it. """ if output_module_dict is None: output_module_dict = nn.ModuleDict() if dist.is_initialized() and not is_local_master(): # Make sure only the first process in distributed training downloads # the model, and the others will use the cache # noinspection PyUnresolvedReferences torch.distributed.barrier() if network == 'tf_inception': model = TFInceptionV3() elif network == 'inception': model = InceptionV3() elif network == 'vgg16': model = Vgg16() elif network == 'swav': model = SwAV() elif network == 'clean_inception': model = CleanInceptionV3() else: raise NotImplementedError(f'Network "{network}" is not supported!') if dist.is_initialized() and is_local_master(): # Make sure only the first process in distributed training downloads # the model, and the others will use the cache # noinspection PyUnresolvedReferences dist.barrier() model = model.to('cuda').eval() world_size = get_world_size() output = {} for k in output_module_dict.keys(): output[k] = [] output["activations"] = [] # Iterate through the dataset to compute the activation. for it, data in enumerate(data_loader): data = to_cuda(data) # Preprocess the data. if preprocess is not None: data = preprocess(data) # Load real data if the generator is not specified. if generator is None: images = data[key_real] else: # Compute the generated image. net_G_output = generator(data, **kwargs) images = net_G_output[key_fake] for metric_name, metric_module in output_module_dict.items(): if metric_module is not None: if metric_name == 'LPIPS': assert generator is not None net_G_output_another = generator(data, **kwargs) images_another = net_G_output_another[key_fake] output[metric_name].append(metric_module(images, images_another)) else: output[metric_name].append(metric_module(data, images, align_corners=align_corners)) # Clamp the image for models that do not set the output to between # -1, 1. For models that employ tanh, this has no effect. images.clamp_(-1, 1) y = model(images, align_corners=align_corners) output["activations"].append(y) if sample_size is not None and data_loader.batch_size * world_size * (it + 1) >= sample_size: # Reach the number of samples we need. break for k, v in output.items(): if len(v) > 0: output[k] = torch.cat(dist_all_gather_tensor(torch.cat(v)))[:sample_size] else: output[k] = None return output
[docs]@torch.no_grad() def get_video_activations(data_loader, key_real, key_fake, trainer=None, sample_size=None, preprocess=None, few_shot=False): r"""Compute activation values and pack them in a list. We do not do all reduce here. Args: data_loader (obj): PyTorch dataloader object. key_real (str): Dictionary key value for the real data. key_fake (str): Dictionary key value for the fake data. trainer (obj): Trainer. Video generation is more involved, we rely on the "reset" and "test" function to conduct the evaluation. sample_size (int): For computing video activation, we will use . preprocess (func): The preprocess function to be applied to the data. few_shot (bool): If ``True``, uses the few-shot setting. Returns: batch_y (tensor): Inception features of the current batch. Note that only the master gpu will get it. """ inception = inception_init() batch_y = [] # We divide video sequences to different GPUs for testing. num_sequences = data_loader.dataset.num_inference_sequences() if sample_size is None: num_videos_to_test = 10 num_frames_per_video = 5 else: num_videos_to_test, num_frames_per_video = sample_size if num_videos_to_test == -1: num_videos_to_test = num_sequences else: num_videos_to_test = min(num_videos_to_test, num_sequences) master_only_print('Number of videos used for evaluation: {}'.format(num_videos_to_test)) master_only_print('Number of frames per video used for evaluation: {}'.format(num_frames_per_video)) world_size = get_world_size() if num_videos_to_test < world_size: seq_to_run = [get_rank() % num_videos_to_test] else: num_videos_to_test = num_videos_to_test // world_size * world_size seq_to_run = range(get_rank(), num_videos_to_test, world_size) for sequence_idx in seq_to_run: data_loader = set_sequence_idx(few_shot, data_loader, sequence_idx) if trainer is not None: trainer.reset() for it, data in enumerate(data_loader): if few_shot and it == 0: continue if it >= num_frames_per_video: break # preprocess the data is preprocess is not none. if trainer is not None: data = trainer.pre_process(data) elif preprocess is not None: data = preprocess(data) data = to_cuda(data) if trainer is None: images = data[key_real][:, -1] else: net_G_output = trainer.test_single(data) images = net_G_output[key_fake] y = inception_forward(inception, images) batch_y += [y] batch_y = torch.cat(batch_y) batch_y = dist_all_gather_tensor(batch_y) if is_local_master(): batch_y = torch.cat(batch_y) return batch_y
[docs]def inception_init(): inception = inception_v3(pretrained=True, transform_input=False) inception = inception.to('cuda') inception.eval() inception.fc = torch.nn.Sequential() return inception
[docs]def inception_forward(inception, images): images.clamp_(-1, 1) images = apply_imagenet_normalization(images) images = F.interpolate(images, size=(299, 299), mode='bicubic', align_corners=True) return inception(images)
[docs]def gather_tensors(batch_y): batch_y = torch.cat(batch_y) batch_y = dist_all_gather_tensor(batch_y) if is_local_master(): batch_y = torch.cat(batch_y) return batch_y
[docs]def set_sequence_idx(few_shot, data_loader, sequence_idx): r"""Get sequence index Args: few_shot (bool): If ``True``, uses the few-shot setting. data_loader: dataloader object sequence_idx (int): which sequence to use. """ if few_shot: data_loader.dataset.set_inference_sequence_idx(sequence_idx, sequence_idx, 0) else: data_loader.dataset.set_inference_sequence_idx(sequence_idx) return data_loader
[docs]def load_or_compute_activations(act_path, data_loader, key_real, key_fake, generator=None, sample_size=None, preprocess=None, is_video=False, few_shot_video=False, **kwargs): r"""Load mean and covariance from saved npy file if exists. Otherwise, compute the mean and covariance. Args: act_path (str or None): Location for the numpy file to store or to load the activations. data_loader (obj): PyTorch dataloader object. key_real (str): Dictionary key value for the real data. key_fake (str): Dictionary key value for the fake data. generator (obj): PyTorch trainer network. sample_size (int): How many samples to be used for computing the KID. preprocess (func): The preprocess function to be applied to the data. is_video (bool): Whether we are handling video sequences. few_shot_video (bool): If ``True``, uses few-shot video synthesis. Returns: (torch.Tensor) Feature activations. """ if act_path is not None and os.path.exists(act_path): # Loading precomputed activations. print('Load activations from {}'.format(act_path)) act = torch.load(act_path, map_location='cpu').cuda() else: # Compute activations. if is_video: act = get_video_activations( data_loader, key_real, key_fake, generator, sample_size, preprocess, few_shot_video, **kwargs ) else: act = get_activations( data_loader, key_real, key_fake, generator, sample_size, preprocess, **kwargs ) if act_path is not None and is_local_master(): print('Save activations to {}'.format(act_path)) if not os.path.exists(os.path.dirname(act_path)): os.makedirs(os.path.dirname(act_path), exist_ok=True) torch.save(act, act_path) return act
[docs]def compute_pairwise_distance(data_x, data_y=None, num_splits=10): r""" Args: data_x: numpy.ndarray([N, feature_dim], dtype=np.float32) data_y: numpy.ndarray([N, feature_dim], dtype=np.float32) Returns: numpy.ndarray([N, N], dtype=np.float32) of pairwise distances. """ if data_y is None: data_y = data_x num_samples = data_x.shape[0] assert data_x.shape[0] == data_y.shape[0] dists = [] for i in range(num_splits): batch_size = math.ceil(num_samples / num_splits) start_idx = i * batch_size end_idx = min((i + 1) * batch_size, num_samples) dists.append(torch.cdist(data_x[start_idx:end_idx], data_y).cpu()) dists = torch.cat(dists, dim=0) return dists
[docs]def compute_nn(input_features, k, num_splits=50): num_samples = input_features.shape[0] all_indices = [] all_values = [] for i in range(num_splits): batch_size = math.ceil(num_samples / num_splits) start_idx = i * batch_size end_idx = min((i + 1) * batch_size, num_samples) dist = torch.cdist(input_features[start_idx:end_idx], input_features) dist[:, start_idx:end_idx] += torch.diag( float('inf') * torch.ones(dist.size(0), device=dist.device) ) k_smallests, indices = torch.topk(dist, k, dim=-1, largest=False) all_indices.append(indices) all_values.append(k_smallests) return torch.cat(all_values, dim=0), torch.cat(all_indices, dim=0)