# Source code for imaginaire.datasets.object_store

# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import io
import json

# import cv2
import boto3
from botocore.config import Config
import numpy as np
import torch.utils.data as data
from PIL import Image
import imageio
from botocore.exceptions import ClientError

from imaginaire.datasets.cache import Cache
from imaginaire.utils.data import IMG_EXTENSIONS, HDR_IMG_EXTENSIONS

Image.MAX_IMAGE_PIXELS = None


class ObjectStoreDataset(data.Dataset):
    r"""This deals with opening, and reading from an AWS S3 bucket.

    Args:
        root (str): Path to the AWS S3 bucket, as ``<bucket>/<prefix>``.
            ``<prefix>/all_filenames.json`` and ``<prefix>/metadata.json``
            are read from the bucket on construction.
        aws_credentials_file (str): Path to a JSON file with AWS credentials;
            its contents are passed as keyword arguments to
            ``boto3.client('s3', ...)``.
        data_type (str): Which data type should this dataset load?
        cache: Optional object exposing ``root`` and ``size_GB``; when given,
            downloaded objects are stored in / served from a local ``Cache``.
    """

    def __init__(self, root, aws_credentials_file, data_type='', cache=None):
        # Cache. ``False`` means "no cache"; callers test it by truthiness.
        self.cache = False
        if cache is not None:
            self.cache = Cache(cache.root, cache.size_GB)

        # Get bucket info, and keys to info about dataset.
        with open(aws_credentials_file) as fin:
            self.credentials = json.load(fin)

        # First path component is the bucket; the rest is the key prefix.
        bucket, *prefix_parts = root.split('/')
        prefix = '/'.join(prefix_parts)
        self.bucket = bucket
        self.all_filenames_key = prefix + '/all_filenames.json'
        self.metadata_key = prefix + '/metadata.json'

        # Get list of filenames: a JSON object mapping each sequence to its
        # list of items.
        filename_info = self._get_object(self.all_filenames_key)
        self.sequence_list = json.loads(filename_info.decode('utf-8'))

        # Get length: total number of items across all sequences.
        self.length = sum(len(value) for value in self.sequence_list.values())

        # Read metadata: maps data_type -> file extension (see getitem_by_path).
        metadata_info = self._get_object(self.metadata_key)
        self.extensions = json.loads(metadata_info.decode('utf-8'))
        self.data_type = data_type

        print('AWS S3 bucket at %s opened.' % (root + '/' + self.data_type))

    def _get_object(self, key):
        r"""Download object from bucket, going through the local cache if any.

        Args:
            key (str): Key inside bucket.
        Returns:
            (bytes or falsy): Object contents. If the download fails, the
            error is printed and a falsy placeholder is returned (``False``,
            or whatever the cache returned on a miss). Failed downloads are
            NOT written to the cache.
        """
        # Look up value in cache.
        object_content = self.cache.read(key) if self.cache else False
        if object_content:
            return object_content

        # Either no cache used or key not found in cache.
        # NOTE(review): a fresh client is built on every call; presumably this
        # keeps the dataset usable from forked DataLoader workers — confirm
        # before hoisting the client into __init__.
        config = Config(connect_timeout=30,
                        signature_version="s3",
                        retries={"max_attempts": 999999})
        s3 = boto3.client('s3', **self.credentials, config=config)
        try:
            s3_response_object = s3.get_object(Bucket=self.bucket, Key=key)
            object_content = s3_response_object['Body'].read()
        except Exception as e:
            # Best effort: report and hand back the falsy placeholder without
            # poisoning the cache with it.
            print('%s not found' % (key))
            print(e)
            return object_content

        # Save content to cache (successful downloads only).
        if self.cache:
            self.cache.write(key, object_content)
        return object_content

    def getitem_by_path(self, path, data_type):
        r"""Load data item stored for key = path.

        Args:
            path (str): Path into AWS S3 bucket, without data_type prefix.
            data_type (str): Key into self.extensions e.g. data/data_segmaps/...
        Returns:
            img (PIL.Image) for extensions in IMG_EXTENSIONS, the result of
            ``imageio.imread`` for extensions in HDR_IMG_EXTENSIONS, or
            buf (bytes): the raw S3 object contents for anything else.
        Raises:
            Exception: Whatever ``imageio`` raises when an HDR object cannot
            be decoded (the offending path is printed first).
        """
        # Figure out decoding params.
        ext = self.extensions[data_type]
        # data_type is inserted after the first path component:
        # '<first>/<data_type>/<rest>.<ext>'.
        parts = path.split('/')
        key = parts[0] + '/' + data_type + '/' + '/'.join(parts[1:]) + '.' + ext

        # Regular image extensions take precedence over HDR ones.
        is_image = ext in IMG_EXTENSIONS
        is_hdr = not is_image and ext in HDR_IMG_EXTENSIONS
        # JPEG-like (non-tif) extensions are forced to 3-channel RGB; other
        # image formats keep whatever mode PIL decodes them to.
        convert_to_rgb = (
            is_image and 'tif' not in ext
            and any(tag in ext for tag in ('JPEG', 'JPG', 'jpeg', 'jpg')))

        # Get value from key.
        buf = self._get_object(key)

        # Decode and return.
        if is_image:
            # This is totally a hack.
            # We should have a better way to handle grayscale images.
            img = Image.open(io.BytesIO(buf))
            if convert_to_rgb:
                img = img.convert('RGB')
            return img
        if is_hdr:
            try:
                imageio.plugins.freeimage.download()
                img = imageio.imread(buf)
            except Exception:
                # Re-raise so callers see the real decoding error instead of
                # an UnboundLocalError on ``img``.
                print(path)
                raise
            return img  # Return a numpy array
        return buf

    def __len__(self):
        r"""Return number of items listed in the bucket's all_filenames.json."""
        return self.length