3
0

cached_image_folder.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. # --------------------------------------------------------
  2. # Swin Transformer
  3. # Copyright (c) 2021 Microsoft
  4. # Licensed under The MIT License [see LICENSE for details]
  5. # Written by Ze Liu
  6. # --------------------------------------------------------
  7. import io
  8. import os
  9. import time
  10. import torch.distributed as dist
  11. import torch.utils.data as data
  12. from PIL import Image
  13. from .zipreader import is_zip_path, ZipReader
  14. def has_file_allowed_extension(filename, extensions):
  15. """Checks if a file is an allowed extension.
  16. Args:
  17. filename (string): path to a file
  18. Returns:
  19. bool: True if the filename ends with a known image extension
  20. """
  21. filename_lower = filename.lower()
  22. return any(filename_lower.endswith(ext) for ext in extensions)
  23. def find_classes(dir):
  24. classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
  25. classes.sort()
  26. class_to_idx = {classes[i]: i for i in range(len(classes))}
  27. return classes, class_to_idx
  28. def make_dataset(dir, class_to_idx, extensions):
  29. images = []
  30. dir = os.path.expanduser(dir)
  31. for target in sorted(os.listdir(dir)):
  32. d = os.path.join(dir, target)
  33. if not os.path.isdir(d):
  34. continue
  35. for root, _, fnames in sorted(os.walk(d)):
  36. for fname in sorted(fnames):
  37. if has_file_allowed_extension(fname, extensions):
  38. path = os.path.join(root, fname)
  39. item = (path, class_to_idx[target])
  40. images.append(item)
  41. return images
  42. def make_dataset_with_ann(ann_file, img_prefix, extensions):
  43. images = []
  44. with open(ann_file, "r") as f:
  45. contents = f.readlines()
  46. for line_str in contents:
  47. path_contents = [c for c in line_str.split('\t')]
  48. im_file_name = path_contents[0]
  49. class_index = int(path_contents[1])
  50. assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions
  51. item = (os.path.join(img_prefix, im_file_name), class_index)
  52. images.append(item)
  53. return images
  54. class DatasetFolder(data.Dataset):
  55. """A generic data loader where the samples are arranged in this way: ::
  56. root/class_x/xxx.ext
  57. root/class_x/xxy.ext
  58. root/class_x/xxz.ext
  59. root/class_y/123.ext
  60. root/class_y/nsdf3.ext
  61. root/class_y/asd932_.ext
  62. Args:
  63. root (string): Root directory path.
  64. loader (callable): A function to load a sample given its path.
  65. extensions (list[string]): A list of allowed extensions.
  66. transform (callable, optional): A function/transform that takes in
  67. a sample and returns a transformed version.
  68. E.g, ``transforms.RandomCrop`` for images.
  69. target_transform (callable, optional): A function/transform that takes
  70. in the target and transforms it.
  71. Attributes:
  72. samples (list): List of (sample path, class_index) tuples
  73. """
  74. def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None,
  75. cache_mode="no"):
  76. # image folder mode
  77. if ann_file == '':
  78. _, class_to_idx = find_classes(root)
  79. samples = make_dataset(root, class_to_idx, extensions)
  80. # zip mode
  81. else:
  82. samples = make_dataset_with_ann(os.path.join(root, ann_file),
  83. os.path.join(root, img_prefix),
  84. extensions)
  85. if len(samples) == 0:
  86. raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" +
  87. "Supported extensions are: " + ",".join(extensions)))
  88. self.root = root
  89. self.loader = loader
  90. self.extensions = extensions
  91. self.samples = samples
  92. self.labels = [y_1k for _, y_1k in samples]
  93. self.classes = list(set(self.labels))
  94. self.transform = transform
  95. self.target_transform = target_transform
  96. self.cache_mode = cache_mode
  97. if self.cache_mode != "no":
  98. self.init_cache()
  99. def init_cache(self):
  100. assert self.cache_mode in ["part", "full"]
  101. n_sample = len(self.samples)
  102. global_rank = dist.get_rank()
  103. world_size = dist.get_world_size()
  104. samples_bytes = [None for _ in range(n_sample)]
  105. start_time = time.time()
  106. for index in range(n_sample):
  107. if index % (n_sample // 10) == 0:
  108. t = time.time() - start_time
  109. print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block')
  110. start_time = time.time()
  111. path, target = self.samples[index]
  112. if self.cache_mode == "full":
  113. samples_bytes[index] = (ZipReader.read(path), target)
  114. elif self.cache_mode == "part" and index % world_size == global_rank:
  115. samples_bytes[index] = (ZipReader.read(path), target)
  116. else:
  117. samples_bytes[index] = (path, target)
  118. self.samples = samples_bytes
  119. def __getitem__(self, index):
  120. """
  121. Args:
  122. index (int): Index
  123. Returns:
  124. tuple: (sample, target) where target is class_index of the target class.
  125. """
  126. path, target = self.samples[index]
  127. sample = self.loader(path)
  128. if self.transform is not None:
  129. sample = self.transform(sample)
  130. if self.target_transform is not None:
  131. target = self.target_transform(target)
  132. return sample, target
  133. def __len__(self):
  134. return len(self.samples)
  135. def __repr__(self):
  136. fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
  137. fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
  138. fmt_str += ' Root Location: {}\n'.format(self.root)
  139. tmp = ' Transforms (if any): '
  140. fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
  141. tmp = ' Target Transforms (if any): '
  142. fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
  143. return fmt_str
  144. IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
  145. def pil_loader(path):
  146. # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
  147. if isinstance(path, bytes):
  148. img = Image.open(io.BytesIO(path))
  149. elif is_zip_path(path):
  150. data = ZipReader.read(path)
  151. img = Image.open(io.BytesIO(data))
  152. else:
  153. with open(path, 'rb') as f:
  154. img = Image.open(f)
  155. return img.convert('RGB')
  156. return img.convert('RGB')
  157. def accimage_loader(path):
  158. import accimage
  159. try:
  160. return accimage.Image(path)
  161. except IOError:
  162. # Potentially a decoding problem, fall back to PIL.Image
  163. return pil_loader(path)
  164. def default_img_loader(path):
  165. from torchvision import get_image_backend
  166. if get_image_backend() == 'accimage':
  167. return accimage_loader(path)
  168. else:
  169. return pil_loader(path)
  170. class CachedImageFolder(DatasetFolder):
  171. """A generic data loader where the images are arranged in this way: ::
  172. root/dog/xxx.png
  173. root/dog/xxy.png
  174. root/dog/xxz.png
  175. root/cat/123.png
  176. root/cat/nsdf3.png
  177. root/cat/asd932_.png
  178. Args:
  179. root (string): Root directory path.
  180. transform (callable, optional): A function/transform that takes in an PIL image
  181. and returns a transformed version. E.g, ``transforms.RandomCrop``
  182. target_transform (callable, optional): A function/transform that takes in the
  183. target and transforms it.
  184. loader (callable, optional): A function to load an image given its path.
  185. Attributes:
  186. imgs (list): List of (image path, class_index) tuples
  187. """
  188. def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None,
  189. loader=default_img_loader, cache_mode="no"):
  190. super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
  191. ann_file=ann_file, img_prefix=img_prefix,
  192. transform=transform, target_transform=target_transform,
  193. cache_mode=cache_mode)
  194. self.imgs = self.samples
  195. def __getitem__(self, index):
  196. """
  197. Args:
  198. index (int): Index
  199. Returns:
  200. tuple: (image, target) where target is class_index of the target class.
  201. """
  202. path, target = self.samples[index]
  203. image = self.loader(path)
  204. if self.transform is not None:
  205. img = self.transform(image)
  206. else:
  207. img = image
  208. if self.target_transform is not None:
  209. target = self.target_transform(target)
  210. return img, target