build.py

# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import os
import torch
import numpy as np
import torch.distributed as dist
from torchvision import datasets, transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import Mixup
from timm.data import create_transform

from .cached_image_folder import CachedImageFolder
from .imagenet22k_dataset import IN22KDATASET
from .samplers import SubsetRandomSampler
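
# Compatibility shim: newer timm releases no longer expose the private
# `_pil_interp` helper in `timm.data.transforms`, so the block below rebuilds
# it on top of torchvision's `InterpolationMode` and monkey-patches it back
# into timm; older environments simply fall back to importing the original.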
try:
    from torchvision.transforms import InterpolationMode

    def _pil_interp(method):
        if method == 'bicubic':
            return InterpolationMode.BICUBIC
        elif method == 'lanczos':
            return InterpolationMode.LANCZOS
        elif method == 'hamming':
            return InterpolationMode.HAMMING
        else:
            # default bilinear, do we want to allow nearest?
            return InterpolationMode.BILINEAR

    import timm.data.transforms as timm_transforms

    timm_transforms._pil_interp = _pil_interp
except ImportError:
    from timm.data.transforms import _pil_interp


def build_loader(config):
    """Build train/val datasets, distributed samplers, data loaders and the optional Mixup/CutMix fn."""
    config.defrost()
    dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
    config.freeze()
    print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully built train dataset")
    dataset_val, _ = build_dataset(is_train=False, config=config)
    print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully built val dataset")

    num_tasks = dist.get_world_size()
    global_rank = dist.get_rank()
    if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part':
        # each rank draws a disjoint, strided subset of the cached (zip) dataset
        indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size())
        sampler_train = SubsetRandomSampler(indices)
    else:
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )

    if config.TEST.SEQUENTIAL:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    else:
        sampler_val = torch.utils.data.distributed.DistributedSampler(
            dataset_val, shuffle=config.TEST.SHUFFLE
        )
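
    # Note: with a DistributedSampler the training loop is expected to call
    # data_loader_train.sampler.set_epoch(epoch) at the start of each epoch so
    # that shuffling differs across epochs (standard PyTorch DDP practice).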
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=config.DATA.BATCH_SIZE,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
        drop_last=True,
    )

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=config.DATA.BATCH_SIZE,
        shuffle=False,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
        drop_last=False,
    )

    # setup mixup / cutmix
    mixup_fn = None
    mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
            prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
            label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)

    return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
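
# Minimal usage sketch (illustration only, not part of this module): assumes
# torch.distributed is already initialised and `config` is the project's
# yacs-style CfgNode carrying the DATA/AUG/MODEL/TEST fields referenced above.
#
#     dataset_train, dataset_val, loader_train, loader_val, mixup_fn = build_loader(config)
#     for samples, targets in loader_train:
#         if mixup_fn is not None:
#             samples, targets = mixup_fn(samples, targets)  # mixed inputs, soft labels
#         ...  # forward/backward pass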


def build_dataset(is_train, config):
    """Build the train or eval dataset and return (dataset, number_of_classes)."""
    transform = build_transform(is_train, config)
    if config.DATA.DATASET == 'imagenet':
        # NOTE: the stock ImageNet settings ('val' split, 1000 classes) are left
        # commented out; this copy uses a 'test' split and two classes instead.
        # prefix = 'train' if is_train else 'val'
        prefix = 'train' if is_train else 'test'
        if config.DATA.ZIP_MODE:
            ann_file = prefix + "_map.txt"
            prefix = prefix + ".zip@/"
            dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform,
                                        cache_mode=config.DATA.CACHE_MODE if is_train else 'part')
        else:
            root = os.path.join(config.DATA.DATA_PATH, prefix)
            dataset = datasets.ImageFolder(root, transform=transform)
        # nb_classes = 1000
        nb_classes = 2
    elif config.DATA.DATASET == 'imagenet22K':
        prefix = 'ILSVRC2011fall_whole'
        if is_train:
            ann_file = prefix + "_map_train.txt"
        else:
            ann_file = prefix + "_map_val.txt"
        dataset = IN22KDATASET(config.DATA.DATA_PATH, ann_file, transform)
        nb_classes = 21841
    else:
        raise NotImplementedError("We only support ImageNet now.")

    return dataset, nb_classes
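
# Expected on-disk layout for the non-zip 'imagenet' branch above (standard
# torchvision ImageFolder convention; directory and file names are examples):
#
#     DATA_PATH/
#         train/<class_name>/*.jpeg
#         test/<class_name>/*.jpeg   # 'test' because the prefix was changed from 'val'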


def build_transform(is_train, config):
    """Build the timm training transform, or the resize/center-crop/normalize eval transform."""
    resize_im = config.DATA.IMG_SIZE > 32
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=config.DATA.IMG_SIZE,
            is_training=True,
            color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
            auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
            re_prob=config.AUG.REPROB,
            re_mode=config.AUG.REMODE,
            re_count=config.AUG.RECOUNT,
            interpolation=config.DATA.INTERPOLATION,
        )
        if not resize_im:
            # replace RandomResizedCropAndInterpolation with RandomCrop
            transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
        return transform

    t = []
    if resize_im:
        if config.TEST.CROP:
            # resize the shorter side to maintain the same ratio w.r.t. 224 images, then center-crop
            size = int((256 / 224) * config.DATA.IMG_SIZE)
            t.append(transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)))
            t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
        else:
            t.append(
                transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
                                  interpolation=_pil_interp(config.DATA.INTERPOLATION))
            )

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)
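

# Example (illustration only; "example.jpg" is a placeholder path): the eval
# pipeline returned above can be applied directly to a single PIL image.
#
#     from PIL import Image
#     eval_transform = build_transform(is_train=False, config=config)
#     img = eval_transform(Image.open("example.jpg").convert("RGB"))  # tensor [3, IMG_SIZE, IMG_SIZE]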