# Source code for imaginaire.utils.io

# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import os
import requests
import torch.distributed as dist
import torchvision.utils
from imaginaire.utils.distributed import is_master
def save_pilimage_in_jpeg(fullname, output_img):
    r"""Save PIL Image to JPEG.

    The parent directory of ``fullname`` is created first if it does not
    exist yet.

    Args:
        fullname (str): Full save path.
        output_img (PIL Image): Image to be saved.
    """
    dirname = os.path.dirname(fullname)
    # A bare filename has no directory component, and os.makedirs('') raises
    # FileNotFoundError, so only create a directory when there is one.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    # Near-maximum JPEG quality to keep compression artifacts low.
    output_img.save(fullname, 'JPEG', quality=99)
def save_intermediate_training_results(
        visualization_images, logdir, current_epoch, current_iteration):
    r"""Save intermediate training results for debugging purpose.

    Pixel values are mapped from [-1, 1] to [0, 1], the images are tiled
    into a grid with one image per row, and the grid is written to
    ``<logdir>/images/epoch_<epoch:05>_iteration_<iteration:09>.jpg``.

    Args:
        visualization_images (tensor): Image where pixel values are in [-1, 1].
            Presumably a batch tensor accepted by
            ``torchvision.utils.make_grid`` (e.g. B x C x H x W) -- confirm
            against callers.
        logdir (str): Where to save the image.
        current_epoch (int): Current training epoch.
        current_iteration (int): Current training iteration.
    """
    visualization_images = (visualization_images + 1) / 2
    output_filename = os.path.join(
        logdir, 'images',
        'epoch_{:05}_iteration_{:09}.jpg'.format(
            current_epoch, current_iteration))
    print('Save output images to {}'.format(output_filename))
    os.makedirs(os.path.dirname(output_filename), exist_ok=True)
    # detach() so the saved grid does not carry autograd history; nrow=1
    # places one image per grid row.
    image_grid = torchvision.utils.make_grid(
        visualization_images.detach(), nrow=1, padding=0, normalize=False)
    torchvision.utils.save_image(image_grid, output_filename, nrow=1)
def download_file_from_google_drive(URL, destination):
    r"""Download a file from google drive.

    Args:
        URL: GDrive file ID (not a full URL).
        destination: Path to save the file.
    """
    # download_file issues an HTTP GET on its first argument, so the bare
    # file ID must first be turned into a direct-download URL.
    download_file(
        f"https://docs.google.com/uc?export=download&id={URL}", destination)
def download_file(URL, destination):
    r"""Download a file from google drive or pbss by using the url.

    If the first response carries a confirmation token (as detected by
    ``get_confirm_token``), the request is repeated with
    ``confirm=<token>`` and that second response is saved instead.

    Args:
        URL: GDrive URL or PBSS pre-signed URL for the checkpoint.
        destination: Path to save the file.

    Raises:
        requests.HTTPError: If the final response has an HTTP error status,
            so that an error page is never written to ``destination``.
    """
    # The context managers release pooled connections even if saving fails.
    with requests.Session() as session:
        response = session.get(URL, stream=True)
        token = get_confirm_token(response)
        if token:
            params = {'confirm': token}
            # Release the first streamed response before re-requesting.
            response.close()
            response = session.get(URL, params=params, stream=True)
        with response:
            response.raise_for_status()
            save_response_content(response, destination)
def get_confirm_token(response):
    r"""Look up the download-confirmation token in a response's cookies.

    Args:
        response: HTTP response whose ``cookies`` mapping is inspected.
    Returns:
        The value of the first cookie whose name starts with
        ``download_warning``, or ``None`` if no such cookie exists.
    """
    warning_values = (
        cookie_value
        for cookie_name, cookie_value in response.cookies.items()
        if cookie_name.startswith('download_warning')
    )
    return next(warning_values, None)
def save_response_content(response, destination):
    r"""Stream a response body into a file.

    Args:
        response: Response object whose body is read via
            ``iter_content`` in 32 KiB chunks.
        destination: Path to save the file; an existing file is overwritten.
    """
    block_size = 32 * 1024
    with open(destination, "wb") as out_file:
        # filter(None, ...) skips empty chunks, which carry no data.
        out_file.writelines(filter(None, response.iter_content(block_size)))
def get_checkpoint(checkpoint_path, url=''):
    r"""Get the checkpoint path. If it does not exist yet, download it from
    the url.

    Checkpoints are stored under ``$TORCH_HOME/checkpoints``. If
    ``TORCH_HOME`` is not set, it is set (process-wide) to the current working
    directory. Only the master process downloads. When torch.distributed is
    initialized, every process then waits at a barrier before returning.

    Args:
        checkpoint_path (str): Checkpoint path, relative to the checkpoint
            directory.
        url (str): URL to download checkpoint. A value that does not start
            with ``http://`` or ``https://`` is treated as a Google Drive
            file ID.
    Returns:
        (str): Full checkpoint path.
    """
    os.environ.setdefault('TORCH_HOME', os.getcwd())
    save_dir = os.path.join(os.environ['TORCH_HOME'], 'checkpoints')
    os.makedirs(save_dir, exist_ok=True)
    full_checkpoint_path = os.path.join(save_dir, checkpoint_path)
    if not os.path.exists(full_checkpoint_path):
        os.makedirs(os.path.dirname(full_checkpoint_path), exist_ok=True)
        if is_master():
            print('Downloading {}'.format(url))
            if not url.startswith(('http://', 'https://')):
                url = f"https://docs.google.com/uc?export=download&id={url}"
            # Download under a temporary name and rename on success, so an
            # interrupted download never leaves a truncated file at the final
            # path (which later runs would mistake for a complete checkpoint).
            partial_path = full_checkpoint_path + '.part'
            download_file(url, partial_path)
            os.replace(partial_path, full_checkpoint_path)
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
    return full_checkpoint_path