data_simmim_pt.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. # --------------------------------------------------------
  2. # SimMIM
  3. # Copyright (c) 2021 Microsoft
  4. # Licensed under The MIT License [see LICENSE for details]
  5. # Written by Zhenda Xie
  6. # --------------------------------------------------------
  7. import math
  8. import random
  9. import numpy as np
  10. import torch
  11. import torch.distributed as dist
  12. import torchvision.transforms as T
  13. from torch.utils.data import DataLoader, DistributedSampler
  14. from torch.utils.data._utils.collate import default_collate
  15. from torchvision.datasets import ImageFolder
  16. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  17. class MaskGenerator:
  18. def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6):
  19. self.input_size = input_size
  20. self.mask_patch_size = mask_patch_size
  21. self.model_patch_size = model_patch_size
  22. self.mask_ratio = mask_ratio
  23. assert self.input_size % self.mask_patch_size == 0
  24. assert self.mask_patch_size % self.model_patch_size == 0
  25. self.rand_size = self.input_size // self.mask_patch_size
  26. self.scale = self.mask_patch_size // self.model_patch_size
  27. self.token_count = self.rand_size ** 2
  28. self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))
  29. def __call__(self):
  30. mask_idx = np.random.permutation(self.token_count)[:self.mask_count]
  31. mask = np.zeros(self.token_count, dtype=int)
  32. mask[mask_idx] = 1
  33. mask = mask.reshape((self.rand_size, self.rand_size))
  34. mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)
  35. return mask
  36. class SimMIMTransform:
  37. def __init__(self, config):
  38. self.transform_img = T.Compose([
  39. T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
  40. T.RandomResizedCrop(config.DATA.IMG_SIZE, scale=(0.67, 1.), ratio=(3. / 4., 4. / 3.)),
  41. T.RandomHorizontalFlip(),
  42. T.ToTensor(),
  43. T.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN),std=torch.tensor(IMAGENET_DEFAULT_STD)),
  44. ])
  45. if config.MODEL.TYPE in ['swin', 'swinv2']:
  46. model_patch_size=config.MODEL.SWIN.PATCH_SIZE
  47. else:
  48. raise NotImplementedError
  49. self.mask_generator = MaskGenerator(
  50. input_size=config.DATA.IMG_SIZE,
  51. mask_patch_size=config.DATA.MASK_PATCH_SIZE,
  52. model_patch_size=model_patch_size,
  53. mask_ratio=config.DATA.MASK_RATIO,
  54. )
  55. def __call__(self, img):
  56. img = self.transform_img(img)
  57. mask = self.mask_generator()
  58. return img, mask
  59. def collate_fn(batch):
  60. if not isinstance(batch[0][0], tuple):
  61. return default_collate(batch)
  62. else:
  63. batch_num = len(batch)
  64. ret = []
  65. for item_idx in range(len(batch[0][0])):
  66. if batch[0][0][item_idx] is None:
  67. ret.append(None)
  68. else:
  69. ret.append(default_collate([batch[i][0][item_idx] for i in range(batch_num)]))
  70. ret.append(default_collate([batch[i][1] for i in range(batch_num)]))
  71. return ret
  72. def build_loader_simmim(config):
  73. transform = SimMIMTransform(config)
  74. dataset = ImageFolder(config.DATA.DATA_PATH, transform)
  75. sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True)
  76. dataloader = DataLoader(dataset, config.DATA.BATCH_SIZE, sampler=sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, drop_last=True, collate_fn=collate_fn)
  77. return dataloader