123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252 |
- # --------------------------------------------------------
- # Swin Transformer
- # Copyright (c) 2021 Microsoft
- # Licensed under The MIT License [see LICENSE for details]
- # Written by Ze Liu
- # --------------------------------------------------------
- import io
- import os
- import time
- import torch.distributed as dist
- import torch.utils.data as data
- from PIL import Image
- from .zipreader import is_zip_path, ZipReader
def has_file_allowed_extension(filename, extensions):
    """Tell whether ``filename`` ends with one of the given extensions.

    Only ``filename`` is lower-cased before comparing; the entries of
    ``extensions`` are used exactly as given.

    Args:
        filename (string): path to a file
        extensions (iterable of string): accepted suffixes, e.g. ``'.jpg'``

    Returns:
        bool: True if the lower-cased filename ends with any of the extensions
    """
    lowered = filename.lower()
    for suffix in extensions:
        if lowered.endswith(suffix):
            return True
    return False
def find_classes(dir):
    """List the immediate sub-directories of ``dir`` and number them.

    Args:
        dir (string): directory whose sub-directories are treated as classes

    Returns:
        tuple: ``(classes, class_to_idx)`` where ``classes`` is the sorted list
            of sub-directory names and ``class_to_idx`` maps each name to its
            position in that list.
    """
    classes = sorted(
        entry for entry in os.listdir(dir)
        if os.path.isdir(os.path.join(dir, entry))
    )
    class_to_idx = {name: position for position, name in enumerate(classes)}
    return classes, class_to_idx
def make_dataset(dir, class_to_idx, extensions):
    """Collect ``(path, class_index)`` pairs from a class-per-folder tree.

    Every sub-directory of ``dir`` is walked recursively; files whose name
    passes ``has_file_allowed_extension`` are recorded with the index that
    ``class_to_idx`` assigns to the top-level sub-directory name.

    Args:
        dir (string): root directory; ``~`` is expanded
        class_to_idx (dict): maps a top-level sub-directory name to its class index
        extensions (list[string]): allowed file extensions

    Returns:
        list: ``(file path, class_index)`` tuples, ordered by class directory,
            then by walked directory, then by file name.
    """
    base_dir = os.path.expanduser(dir)
    samples = []
    for class_name in sorted(os.listdir(base_dir)):
        class_dir = os.path.join(base_dir, class_name)
        # Plain files directly under the root are not classes.
        if os.path.isdir(class_dir):
            for folder, _, file_names in sorted(os.walk(class_dir)):
                for file_name in sorted(file_names):
                    if not has_file_allowed_extension(file_name, extensions):
                        continue
                    full_path = os.path.join(folder, file_name)
                    samples.append((full_path, class_to_idx[class_name]))
    return samples
def make_dataset_with_ann(ann_file, img_prefix, extensions):
    """Build the sample list from a tab-separated annotation file.

    Each non-blank line of ``ann_file`` is read as
    ``<image path><TAB><class index>``; any further tab-separated columns are
    ignored. Blank lines (for example a trailing empty line at the end of the
    file) are skipped instead of raising ``IndexError``.

    Args:
        ann_file (string): path to the annotation file
        img_prefix (string): prefix joined in front of every image path
        extensions (list[string]): allowed (lower-case) file extensions

    Returns:
        list: ``(os.path.join(img_prefix, image path), class_index)`` tuples in
            file order.

    Raises:
        AssertionError: if an image path's lower-cased extension is not in
            ``extensions``.
        IndexError: if a non-blank line has no second tab-separated column.
        ValueError: if the class index column is not an integer.
    """
    images = []
    with open(ann_file, "r") as f:
        # Stream the file line by line rather than materialising it first.
        for line_str in f:
            if not line_str.strip():
                continue
            path_contents = line_str.rstrip('\r\n').split('\t')
            im_file_name = path_contents[0]
            class_index = int(path_contents[1])
            file_ext = os.path.splitext(im_file_name)[-1].lower()
            # Kept as an assertion so the exception type seen by callers is unchanged.
            assert file_ext in extensions, \
                f"unsupported extension {file_ext!r} for {im_file_name!r}"
            images.append((os.path.join(img_prefix, im_file_name), class_index))
    return images
class DatasetFolder(data.Dataset):
    """A generic data loader where the samples are arranged in this way: ::

        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/xxz.ext

        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/asd932_.ext

    Alternatively, when ``ann_file`` is given, the samples are listed in an
    annotation file (see ``make_dataset_with_ann``).

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path (or,
            for cached samples, the cached content).
        extensions (list[string]): A list of allowed extensions.
        ann_file (string, optional): Annotation file, relative to ``root``.
            Empty string selects image folder mode.
        img_prefix (string, optional): Image prefix, relative to ``root``;
            only used together with ``ann_file``.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        cache_mode (string, optional): ``"no"`` (default), ``"part"`` or
            ``"full"``; anything other than ``"no"`` triggers ``init_cache``.

    Attributes:
        samples (list): List of (sample path, class_index) tuples. After
            ``init_cache`` the first element of a cached entry is whatever
            ``ZipReader.read`` returned instead of the path.
        labels (list): class_index of every sample, in sample order.
        classes (list): the distinct class indices found in ``labels``.
    """

    def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None,
                 cache_mode="no"):
        # image folder mode
        if ann_file == '':
            _, class_to_idx = find_classes(root)
            samples = make_dataset(root, class_to_idx, extensions)
        # zip mode
        else:
            samples = make_dataset_with_ann(os.path.join(root, ann_file),
                                            os.path.join(root, img_prefix),
                                            extensions)

        if len(samples) == 0:
            raise RuntimeError("Found 0 files in subfolders of: " + root + "\n" +
                               "Supported extensions are: " + ",".join(extensions))

        self.root = root
        self.loader = loader
        self.extensions = extensions

        self.samples = samples
        self.labels = [y_1k for _, y_1k in samples]
        self.classes = list(set(self.labels))

        self.transform = transform
        self.target_transform = target_transform

        self.cache_mode = cache_mode
        if self.cache_mode != "no":
            self.init_cache()

    def init_cache(self):
        """Replace sample paths by their content read through ``ZipReader``.

        In ``"full"`` mode every sample is read. In ``"part"`` mode only the
        samples whose index satisfies ``index % world_size == global_rank``
        are read and the remaining entries keep their path. Requires
        ``torch.distributed`` to be usable for ``get_rank`` /
        ``get_world_size``.

        Raises:
            AssertionError: if ``cache_mode`` is neither ``"part"`` nor ``"full"``.
        """
        assert self.cache_mode in ["part", "full"]
        n_sample = len(self.samples)
        global_rank = dist.get_rank()
        world_size = dist.get_world_size()
        cache_everything = self.cache_mode == "full"

        # Progress is reported about every 10% of the samples. The interval is
        # clamped to 1 so that datasets with fewer than 10 samples do not end
        # up computing ``index % 0``.
        report_every = max(n_sample // 10, 1)

        samples_bytes = [None for _ in range(n_sample)]
        start_time = time.time()
        for index in range(n_sample):
            if index % report_every == 0:
                t = time.time() - start_time
                print(f'global_rank {global_rank} cached {index}/{n_sample} takes {t:.2f}s per block')
                start_time = time.time()
            path, target = self.samples[index]
            if cache_everything or index % world_size == global_rank:
                # presumably ZipReader.read returns the raw file bytes — confirm in zipreader
                samples_bytes[index] = (ZipReader.read(path), target)
            else:
                samples_bytes[index] = (path, target)
        self.samples = samples_bytes

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def __repr__(self):
        """Return a multi-line summary: class name, size, root and transforms."""
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += ' Root Location: {}\n'.format(self.root)
        tmp = ' Transforms (if any): '
        # Continuation lines of a multi-line transform repr are aligned under the label.
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = ' Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str
# Lower-case file extensions accepted as images by CachedImageFolder.
IMG_EXTENSIONS = [
    '.jpg',
    '.jpeg',
    '.png',
    '.ppm',
    '.bmp',
    '.pgm',
    '.tif',
]
def pil_loader(path):
    """Load an image with PIL and return it converted to RGB.

    Args:
        path: either the already-read image content as ``bytes``, a path for
            which ``is_zip_path`` is true (read through ``ZipReader``), or a
            regular file path.

    Returns:
        The result of ``Image.open(...).convert('RGB')``.
    """
    # Cached sample: the image content itself was passed in.
    if isinstance(path, bytes):
        return Image.open(io.BytesIO(path)).convert('RGB')

    # Image stored inside a zip archive.
    if is_zip_path(path):
        zipped_bytes = ZipReader.read(path)
        return Image.open(io.BytesIO(zipped_bytes)).convert('RGB')

    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    # The conversion happens inside the ``with`` block, while the file is still open.
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')
def accimage_loader(path):
    """Load an image with ``accimage``, falling back to ``pil_loader``.

    Args:
        path: image location handed to ``accimage.Image``.

    Returns:
        An ``accimage.Image``, or the result of ``pil_loader(path)`` when
        ``accimage`` raises ``IOError``.
    """
    import accimage

    try:
        decoded = accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)
    return decoded
def default_img_loader(path):
    """Load an image with the loader matching torchvision's image backend.

    Args:
        path: image location (or cached content) passed on to the chosen loader.

    Returns:
        Whatever ``accimage_loader`` returns when the backend is
        ``'accimage'``, otherwise whatever ``pil_loader`` returns.
    """
    from torchvision import get_image_backend

    use_accimage = get_image_backend() == 'accimage'
    chosen_loader = accimage_loader if use_accimage else pil_loader
    return chosen_loader(path)
class CachedImageFolder(DatasetFolder):
    """A generic data loader where the images are arranged in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    It is a ``DatasetFolder`` restricted to ``IMG_EXTENSIONS``.

    Args:
        root (string): Root directory path.
        ann_file (string, optional): Annotation file, relative to ``root``.
        img_prefix (string, optional): Image prefix, relative to ``root``.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        cache_mode (string, optional): forwarded to ``DatasetFolder``.

    Attributes:
        imgs (list): Alias of ``samples`` as left by ``DatasetFolder.__init__``
            (so it reflects any caching done there).
    """

    def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None,
                 loader=default_img_loader, cache_mode="no"):
        super().__init__(
            root, loader, IMG_EXTENSIONS,
            ann_file=ann_file, img_prefix=img_prefix,
            transform=transform, target_transform=target_transform,
            cache_mode=cache_mode,
        )
        self.imgs = self.samples

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        source, target = self.samples[index]
        img = self.loader(source)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target
|