Source code for imaginaire.utils.lmdb

# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import glob
import os

import lmdb
from tqdm import tqdm

from imaginaire.utils import path


[docs]def construct_file_path(root, data_type, sequence, filename, ext): """Get file path for our dataset structure.""" return '%s/%s/%s/%s.%s' % (root, data_type, sequence, filename, ext)
[docs]def check_and_add(filepath, key, filepaths, keys, remove_missing=False): r"""Add filepath and key to list of filepaths and keys. Args: filepath (str): Filepath to add. key (str): LMDB key for this filepath. filepaths (list): List of filepaths added so far. keys (list): List of keys added so far. remove_missing (bool): If ``True``, removes missing files, otherwise raises an error. Returns: (int): Size of file at filepath. """ if not os.path.exists(filepath): print(filepath + ' does not exist.') if remove_missing: return -1 else: raise FileNotFoundError(filepath + ' does not exist.') filepaths.append(filepath) keys.append(key) return os.path.getsize(filepath)
[docs]def write_entry(txn, key, filepath): r"""Dump binary contents of file associated with key to LMDB. Args: txn: handle to LMDB. key (str): LMDB key for this filepath. filepath (str): Filepath to add. """ with open(filepath, 'rb') as f: data = f.read() txn.put(key.encode('ascii'), data)
[docs]def build_lmdb(filepaths, keys, output_filepath, map_size, large): r"""Write out lmdb containing (key, contents of filepath) to file. Args: filepaths (list): List of filepath strings. keys (list): List of key strings associated with filepaths. output_filepath (str): Location to write LMDB to. map_size (int): Size of LMDB. large (bool): Is the dataset large? """ if large: db = lmdb.open(output_filepath, map_size=map_size, writemap=True) else: db = lmdb.open(output_filepath, map_size=map_size) txn = db.begin(write=True) print('Writing LMDB to:', output_filepath) for filepath, key in tqdm(zip(filepaths, keys), total=len(keys)): write_entry(txn, key, filepath) txn.commit()
[docs]def get_all_filenames_from_list(list_name): r"""Get all filenames from list. Args: list_name (str): Path to filename list. Returns: all_filenames (dict): Folder name for key, and filename for values. """ with open(list_name, 'rt') as f: lines = f.readlines() lines = [line.strip() for line in lines] all_filenames = dict() for line in lines: if '/' in line: file_str = line.split('/')[0:-1] folder_name = os.path.join(*file_str) image_name = line.split('/')[-1].replace('.jpg', '') else: folder_name = '.' image_name = line.replace('.jpg', '') if folder_name in all_filenames: all_filenames[folder_name].append(image_name) else: all_filenames[folder_name] = [image_name] return all_filenames
[docs]def get_lmdb_data_types(cfg): r"""Get the data types which should be put in LMDB. Args: cfg: Configuration object. """ data_types, extensions = [], [] for data_type in cfg.data.input_types: name = list(data_type.keys()) assert len(name) == 1 name = name[0] info = data_type[name] if 'computed_on_the_fly' not in info: info['computed_on_the_fly'] = False is_lmdb = not info['computed_on_the_fly'] if not is_lmdb: continue ext = info['ext'] data_types.append(name) extensions.append(ext) cfg.data.data_types = data_types cfg.data.extensions = extensions return cfg
[docs]def create_metadata(data_root=None, cfg=None, paired=None, input_list=''): r"""Main function. Args: data_root (str): Location of dataset root. cfg (object): Loaded config object. paired (bool): Paired or unpaired data. input_list (str): Path to filename containing list of inputs. Returns: (tuple): - all_filenames (dict): Key of data type, values with sequences. - extensions (dict): Extension of each data type. """ cfg = get_lmdb_data_types(cfg) # Get list of all data_types in the dataset. available_data_types = path.get_immediate_subdirectories(data_root) print(available_data_types) required_data_types = cfg.data.data_types data_exts = cfg.data.extensions # Find filenames. assert set(required_data_types).issubset(set(available_data_types)), \ print(set(required_data_types) - set(available_data_types), 'missing') # Find extensions for each data type. extensions = {} for data_type, data_ext in zip(required_data_types, data_exts): extensions[data_type] = data_ext print('Data file extensions:', extensions) if paired: if input_list != '': all_filenames = get_all_filenames_from_list(input_list) else: # Get list of all sequences in the dataset. if 'data_keypoint' in required_data_types: search_dir = 'data_keypoint' elif 'data_segmaps' in required_data_types: search_dir = 'data_segmaps' else: search_dir = required_data_types[0] print('Searching in dir: %s' % search_dir) sequences = path.get_recursive_subdirectories( os.path.join(data_root, search_dir), extensions[search_dir]) print('Found %d sequences' % (len(sequences))) # Get filenames in each sequence. all_filenames = {} for sequence in sequences: folder = '%s/%s/%s/*.%s' % ( data_root, search_dir, sequence, extensions[search_dir]) filenames = sorted(glob.glob(folder)) filenames = [ os.path.splitext(os.path.basename(filename))[0] for filename in filenames] all_filenames[sequence] = filenames total_filenames = [len(filenames) for _, filenames in all_filenames.items()] print('Found %d files' % (sum(total_filenames))) else: # Get sequences in each data type. all_filenames = {} for data_type in required_data_types: all_filenames[data_type] = {} sequences = path.get_recursive_subdirectories( os.path.join(data_root, data_type), extensions[data_type]) # Get filenames in each sequence. total_filenames = 0 for sequence in sequences: folder = '%s/%s/%s/*.%s' % ( data_root, data_type, sequence, extensions[data_type]) filenames = sorted(glob.glob(folder)) filenames = [ os.path.splitext(os.path.basename(filename))[0] for filename in filenames] all_filenames[data_type][sequence] = filenames total_filenames += len(filenames) print('Data type: %s, Found %d sequences, Found %d files' % (data_type, len(sequences), total_filenames)) return all_filenames, extensions