# imagenet22k_dataset.py
  1. import os
  2. import json
  3. import torch.utils.data as data
  4. import numpy as np
  5. from PIL import Image
  6. import warnings
  7. warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
  8. class IN22KDATASET(data.Dataset):
  9. def __init__(self, root, ann_file='', transform=None, target_transform=None):
  10. super(IN22KDATASET, self).__init__()
  11. self.data_path = root
  12. self.ann_path = os.path.join(self.data_path, ann_file)
  13. self.transform = transform
  14. self.target_transform = target_transform
  15. # id & label: https://github.com/google-research/big_transfer/issues/7
  16. # total: 21843; only 21841 class have images: map 21841->9205; 21842->15027
  17. self.database = json.load(open(self.ann_path))
  18. def _load_image(self, path):
  19. try:
  20. im = Image.open(path)
  21. except:
  22. print("ERROR IMG LOADED: ", path)
  23. random_img = np.random.rand(224, 224, 3) * 255
  24. im = Image.fromarray(np.uint8(random_img))
  25. return im
  26. def __getitem__(self, index):
  27. """
  28. Args:
  29. index (int): Index
  30. Returns:
  31. tuple: (image, target) where target is class_index of the target class.
  32. """
  33. idb = self.database[index]
  34. # images
  35. images = self._load_image(self.data_path + '/' + idb[0]).convert('RGB')
  36. if self.transform is not None:
  37. images = self.transform(images)
  38. # target
  39. target = int(idb[1])
  40. if self.target_transform is not None:
  41. target = self.target_transform(target)
  42. return images, target
  43. def __len__(self):
  44. return len(self.database)