# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------

import io
import os
import time
import torch.distributed as dist
import torch.utils.data as data
from PIL import Image

from .zipreader import is_zip_path, ZipReader


def has_file_allowed_extension(filename, extensions):
    """Check whether a file name ends with one of the allowed extensions.

    Only ``filename`` is lower-cased before the comparison; ``extensions``
    are compared exactly as given.

    Args:
        filename (string): path to a file
        extensions (iterable of string): allowed suffixes, e.g. ``'.jpg'``

    Returns:
        bool: True if the lower-cased filename ends with any of ``extensions``
    """
    # str.endswith accepts a tuple of suffixes, so one call covers them all.
    return filename.lower().endswith(tuple(extensions))


def find_classes(dir):
    """List the sub-directories of ``dir`` and assign each a class index.

    Args:
        dir (string): directory whose immediate sub-directories are classes

    Returns:
        tuple: ``(classes, class_to_idx)`` where ``classes`` is the sorted
        list of sub-directory names and ``class_to_idx`` maps each name to
        its position in that list.
    """
    classes = sorted(d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)))
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx


def make_dataset(dir, class_to_idx, extensions):
    """Collect ``(path, class_index)`` pairs from a class-per-folder tree.

    Every sub-directory of ``dir`` is walked recursively; files whose name
    has an allowed extension are recorded with the index of the top-level
    sub-directory they live under.

    Args:
        dir (string): root directory (``~`` is expanded)
        class_to_idx (dict): maps top-level sub-directory name to class index
        extensions (iterable of string): allowed file extensions

    Returns:
        list: ``(path, class_index)`` tuples in sorted traversal order
    """
    images = []
    dir = os.path.expanduser(dir)
    for target in sorted(os.listdir(dir)):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue

        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                if has_file_allowed_extension(fname, extensions):
                    path = os.path.join(root, fname)
                    # Looked up lazily so that a folder without any matching
                    # file never needs an entry in class_to_idx.
                    item = (path, class_to_idx[target])
                    images.append(item)

    return images


def make_dataset_with_ann(ann_file, img_prefix, extensions):
    """Collect ``(path, class_index)`` pairs from a tab-separated annotation file.

    Each non-blank line is split on tabs: the first field is the image file
    name (joined onto ``img_prefix``), the second is the integer class index.

    Args:
        ann_file (string): path to the annotation file
        img_prefix (string): prefix joined in front of every image file name
        extensions (iterable of string): allowed file extensions

    Returns:
        list: ``(path, class_index)`` tuples in file order

    Raises:
        AssertionError: if an image file name has an extension not in
            ``extensions``
    """
    images = []
    with open(ann_file, "r") as f:
        for line_str in f:
            # Blank lines (typically a trailing newline at end of file) carry
            # no annotation; skip them instead of failing on the missing field.
            if not line_str.strip():
                continue
            path_contents = line_str.split('\t')
            im_file_name = path_contents[0]
            class_index = int(path_contents[1])

            assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions, \
                "unsupported file extension in annotation: " + im_file_name
            item = (os.path.join(img_prefix, im_file_name), class_index)

            images.append(item)

    return images


class DatasetFolder(data.Dataset):
    """A generic data loader where the samples are arranged in this way: ::

        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/xxz.ext

        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/asd932_.ext

    Alternatively, when ``ann_file`` is given, samples are read from a
    tab-separated annotation file instead of being discovered by walking
    ``root``.

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path
            (or, when caching is enabled, the cached raw bytes).
        extensions (list[string]): A list of allowed extensions.
        ann_file (string, optional): Annotation file, relative to ``root``.
            Empty string selects image-folder mode.
        img_prefix (string, optional): Prefix, relative to ``root``, joined
            in front of every file name listed in ``ann_file``.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        cache_mode (string, optional): ``"no"`` (default), ``"part"`` or
            ``"full"``; see :meth:`init_cache`.

    Attributes:
        samples (list): List of (sample path, class_index) tuples. After
            caching, the first element may be raw bytes instead of a path.
        labels (list): class_index of every sample, in order.
        classes (list): the distinct class indices found in ``labels``.
    """

    def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None,
                 cache_mode="no"):
        # image folder mode
        if ann_file == '':
            _, class_to_idx = find_classes(root)
            samples = make_dataset(root, class_to_idx, extensions)
        # zip mode
        else:
            samples = make_dataset_with_ann(os.path.join(root, ann_file),
                                            os.path.join(root, img_prefix),
                                            extensions)

        if not samples:
            raise RuntimeError("Found 0 files in subfolders of: " + root + "\n" +
                               "Supported extensions are: " + ",".join(extensions))

        self.root = root
        self.loader = loader
        self.extensions = extensions

        self.samples = samples
        self.labels = [y_1k for _, y_1k in samples]
        self.classes = list(set(self.labels))

        self.transform = transform
        self.target_transform = target_transform

        self.cache_mode = cache_mode
        if self.cache_mode != "no":
            self.init_cache()

    def init_cache(self):
        """Replace sample paths with their raw bytes, read via ``ZipReader``.

        In ``"full"`` mode every sample is read. In ``"part"`` mode only the
        samples whose index is congruent to this process's rank modulo the
        world size are read; the others keep their path. Progress is printed
        roughly every tenth of the dataset.

        NOTE(review): calls ``dist.get_rank()`` / ``dist.get_world_size()``,
        so presumably the default process group must already be initialised.
        """
        assert self.cache_mode in ["part", "full"]
        n_sample = len(self.samples)
        global_rank = dist.get_rank()
        world_size = dist.get_world_size()

        # Clamp to 1: with fewer than 10 samples ``n_sample // 10`` is 0 and
        # the modulo below would raise ZeroDivisionError.
        log_interval = max(n_sample // 10, 1)
        is_full = self.cache_mode == "full"

        samples_bytes = [None for _ in range(n_sample)]
        start_time = time.time()
        for index, (path, target) in enumerate(self.samples):
            if index % log_interval == 0:
                t = time.time() - start_time
                print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block')
                start_time = time.time()
            if is_full or index % world_size == global_rank:
                samples_bytes[index] = (ZipReader.read(path), target)
            else:
                samples_bytes[index] = (path, target)
        self.samples = samples_bytes

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        return len(self.samples)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += ' Root Location: {}\n'.format(self.root)
        # Continuation lines of a multi-line transform repr are indented so
        # that they line up under the first one.
        tmp = ' Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = ' Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str


IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']


def pil_loader(path):
    """Load an image with PIL and convert it to RGB.

    Args:
        path (bytes or string): already-read image bytes, a path accepted by
            ``is_zip_path`` (read through ``ZipReader``), or a regular file path.

    Returns:
        PIL.Image.Image: the image converted to ``'RGB'`` mode
    """
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    if isinstance(path, bytes):
        img = Image.open(io.BytesIO(path))
    elif is_zip_path(path):
        # Named ``raw`` rather than ``data`` so the module alias
        # ``torch.utils.data as data`` is not shadowed.
        raw = ZipReader.read(path)
        img = Image.open(io.BytesIO(raw))
    else:
        with open(path, 'rb') as f:
            img = Image.open(f)
            # Convert while the file is still open, before the handle closes.
            return img.convert('RGB')
    return img.convert('RGB')


def accimage_loader(path):
    """Load an image with accimage, falling back to :func:`pil_loader` on IOError."""
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_img_loader(path):
    """Load an image using the backend torchvision is configured with."""
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)


class CachedImageFolder(DatasetFolder):
    """A generic data loader where the images are arranged in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Args:
        root (string): Root directory path.
        ann_file (string, optional): Annotation file, relative to ``root``.
            Empty string selects image-folder mode.
        img_prefix (string, optional): Prefix, relative to ``root``, joined
            in front of every file name listed in ``ann_file``.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        cache_mode (string, optional): ``"no"`` (default), ``"part"`` or ``"full"``.

    Attributes:
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None,
                 loader=default_img_loader, cache_mode="no"):
        super().__init__(root, loader, IMG_EXTENSIONS,
                         ann_file=ann_file, img_prefix=img_prefix,
                         transform=transform, target_transform=target_transform,
                         cache_mode=cache_mode)
        self.imgs = self.samples

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        image = self.loader(path)
        img = self.transform(image) if self.transform is not None else image
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target