data_simmim_ft.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. # --------------------------------------------------------
  2. # SimMIM
  3. # Copyright (c) 2021 Microsoft
  4. # Licensed under The MIT License [see LICENSE for details]
  5. # Written by Zhenda Xie
  6. # --------------------------------------------------------
  7. import os
  8. import torch.distributed as dist
  9. from torch.utils.data import DataLoader, DistributedSampler
  10. from torchvision import datasets, transforms
  11. from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  12. from timm.data import Mixup
  13. from timm.data import create_transform
  14. from timm.data.transforms import _pil_interp
  15. def build_loader_finetune(config):
  16. config.defrost()
  17. dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
  18. config.freeze()
  19. dataset_val, _ = build_dataset(is_train=False, config=config)
  20. num_tasks = dist.get_world_size()
  21. global_rank = dist.get_rank()
  22. sampler_train = DistributedSampler(
  23. dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
  24. )
  25. sampler_val = DistributedSampler(
  26. dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False
  27. )
  28. data_loader_train = DataLoader(
  29. dataset_train, sampler=sampler_train,
  30. batch_size=config.DATA.BATCH_SIZE,
  31. num_workers=config.DATA.NUM_WORKERS,
  32. pin_memory=config.DATA.PIN_MEMORY,
  33. drop_last=True,
  34. )
  35. data_loader_val = DataLoader(
  36. dataset_val, sampler=sampler_val,
  37. batch_size=config.DATA.BATCH_SIZE,
  38. num_workers=config.DATA.NUM_WORKERS,
  39. pin_memory=config.DATA.PIN_MEMORY,
  40. drop_last=False,
  41. )
  42. # setup mixup / cutmix
  43. mixup_fn = None
  44. mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
  45. if mixup_active:
  46. mixup_fn = Mixup(
  47. mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
  48. prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
  49. label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)
  50. return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
  51. def build_dataset(is_train, config):
  52. transform = build_transform(is_train, config)
  53. if config.DATA.DATASET == 'imagenet':
  54. prefix = 'train' if is_train else 'val'
  55. root = os.path.join(config.DATA.DATA_PATH, prefix)
  56. dataset = datasets.ImageFolder(root, transform=transform)
  57. nb_classes = 1000
  58. else:
  59. raise NotImplementedError("We only support ImageNet Now.")
  60. return dataset, nb_classes
  61. def build_transform(is_train, config):
  62. resize_im = config.DATA.IMG_SIZE > 32
  63. if is_train:
  64. # this should always dispatch to transforms_imagenet_train
  65. transform = create_transform(
  66. input_size=config.DATA.IMG_SIZE,
  67. is_training=True,
  68. color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
  69. auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
  70. re_prob=config.AUG.REPROB,
  71. re_mode=config.AUG.REMODE,
  72. re_count=config.AUG.RECOUNT,
  73. interpolation=config.DATA.INTERPOLATION,
  74. )
  75. if not resize_im:
  76. # replace RandomResizedCropAndInterpolation with
  77. # RandomCrop
  78. transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
  79. return transform
  80. t = []
  81. if resize_im:
  82. if config.TEST.CROP:
  83. size = int((256 / 224) * config.DATA.IMG_SIZE)
  84. t.append(
  85. transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)),
  86. # to maintain same ratio w.r.t. 224 images
  87. )
  88. t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
  89. else:
  90. t.append(
  91. transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
  92. interpolation=_pil_interp(config.DATA.INTERPOLATION))
  93. )
  94. t.append(transforms.ToTensor())
  95. t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
  96. return transforms.Compose(t)