Source code for imaginaire.datasets.object_store

# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import io
import json
# import cv2
import boto3
from botocore.config import Config
import numpy as np
import as data
from PIL import Image
import imageio
from botocore.exceptions import ClientError
from imaginaire.datasets.cache import Cache
class ObjectStoreDataset(data.Dataset):
    r"""This deals with opening, and reading from an AWS S3 bucket.

    Args:
        root (str): Path to the AWS S3 bucket, of the form
            ``bucket_name/optional/key/prefix``. The first component is the
            bucket name; the remainder is the key prefix under which
            ``all_filenames.json`` and ``metadata.json`` are stored.
        aws_credentials_file (str): Path to a JSON file whose key/value pairs
            are forwarded as keyword arguments to ``boto3.client('s3', ...)``.
        data_type (str): Which data type should this dataset load?
        cache (optional): Object exposing ``root`` and ``size_GB``; when not
            ``None``, downloaded objects are stored in a local ``Cache``.
    """

    def __init__(self, root, aws_credentials_file, data_type='', cache=None):
        # Cache. ``False`` means "no cache"; _get_object tests truthiness.
        self.cache = False
        if cache is not None:
            self.cache = Cache(cache.root, cache.size_GB)

        # Credentials are forwarded verbatim to boto3.client on each download.
        with open(aws_credentials_file) as fin:
            self.credentials = json.load(fin)

        # Get bucket info, and keys to info about dataset.
        self.bucket, prefix = self._split_root(root)
        self.all_filenames_key = self._join_key(prefix, 'all_filenames.json')
        self.metadata_key = self._join_key(prefix, 'metadata.json')

        # Get list of filenames. A mapping whose values are sized collections
        # (presumably sequence name -> list of filenames; verify upstream).
        self.sequence_list = self._read_json(self.all_filenames_key)

        # Dataset length is the total number of entries over all sequences.
        self.length = sum(len(value) for value in self.sequence_list.values())

        # Read metadata: maps each data type to the file extension used when
        # building object keys in getitem_by_path.
        self.extensions = self._read_json(self.metadata_key)
        self.data_type = data_type

        print('AWS S3 bucket at %s opened.' % (root + '/' + self.data_type))

    @staticmethod
    def _split_root(root):
        r"""Split ``root`` into ``(bucket, key_prefix)``.

        Leading/trailing slashes are ignored, so ``'bucket/a/'`` and
        ``'bucket/a'`` both give ``('bucket', 'a')`` and ``'bucket'`` gives
        ``('bucket', '')``.
        """
        bucket, _, prefix = root.strip('/').partition('/')
        return bucket, prefix

    @staticmethod
    def _join_key(prefix, name):
        r"""Join a key prefix and a file name without a stray leading slash."""
        return prefix + '/' + name if prefix else name

    def _read_json(self, key):
        r"""Download ``key`` and decode it as UTF-8 JSON.

        Raises:
            RuntimeError: If the object could not be retrieved.
        """
        content = self._get_object(key)
        if not content:
            raise RuntimeError(
                'Could not read %s from bucket %s.' % (key, self.bucket))
        return json.loads(content.decode('utf-8'))

    def _get_object(self, key):
        r"""Download object from bucket.

        Args:
            key (str): Key inside bucket.
        Returns:
            The object's raw bytes, or ``False`` if it could not be fetched.
        """
        # Look up value in cache.
        object_content = self.cache.read(key) if self.cache else False
        if object_content:
            return object_content

        # Either no cache used or key not found in cache.
        config = Config(connect_timeout=30,
                        signature_version="s3",
                        retries={"max_attempts": 999999})
        s3 = boto3.client('s3', **self.credentials, config=config)
        try:
            s3_response_object = s3.get_object(Bucket=self.bucket, Key=key)
            object_content = s3_response_object['Body'].read()
        except Exception as e:
            # Best-effort: report and signal failure without poisoning the
            # cache with a placeholder value.
            print('%s not found' % (key))
            print(e)
            return False

        # Save content to cache.
        if self.cache:
            self.cache.write(key, object_content)
        return object_content

    def getitem_by_path(self, path, data_type):
        r"""Load data item stored for key = path.

        The object key is ``<first>/<data_type>/<rest>.<ext>``, where
        ``<first>``/``<rest>`` are the first and remaining components of
        ``path`` and ``<ext>`` is ``self.extensions[data_type]``.

        Args:
            path (str): Path into AWS S3 bucket, without data_type prefix.
            data_type (str): Key into self.extensions e.g. data/data_segmaps/...
        Returns:
            img (PIL.Image) for image extensions, the result of
            ``imageio.imread`` for HDR extensions, otherwise buf (bytes): the
            raw contents of the S3 object (``False`` if the download failed).
        """
        ext = self.extensions[data_type]
        parts = path.split('/')
        key = parts[0] + '/' + data_type + '/' + '/'.join(parts[1:]) + '.' + ext

        # Get value from key.
        buf = self._get_object(key)

        # NOTE(review): IMG_EXTENSIONS / HDR_IMG_EXTENSIONS are not imported in
        # the visible module header; confirm they are brought into scope.
        if ext in IMG_EXTENSIONS:
            img = Image.open(io.BytesIO(buf))
            # This is totally a hack.
            # We should have a better way to handle grayscale images.
            # JPEGs are forced to RGB; TIFFs and other image formats keep the
            # mode PIL decoded them with.
            lowered = ext.lower()
            if 'tif' not in ext and ('jpeg' in lowered or 'jpg' in lowered):
                img = img.convert('RGB')
            return img

        if ext in HDR_IMG_EXTENSIONS:
            try:
                img = imageio.imread(buf)
            except Exception:
                # Report which item failed, then surface the real error
                # instead of an UnboundLocalError on ``img``.
                print(path)
                raise
            return img  # Return a numpy array

        return buf

    def __len__(self):
        r"""Return the total number of items listed in all_filenames.json."""
        return self.length