123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899 |
- # --------------------------------------------------------
- # SimMIM
- # Copyright (c) 2021 Microsoft
- # Licensed under The MIT License [see LICENSE for details]
- # Written by Zhenda Xie
- # --------------------------------------------------------
- import math
- import random
- import numpy as np
- import torch
- import torch.distributed as dist
- import torchvision.transforms as T
- from torch.utils.data import DataLoader, DistributedSampler
- from torch.utils.data._utils.collate import default_collate
- from torchvision.datasets import ImageFolder
- from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
class MaskGenerator:
    """Random patch-mask generator for SimMIM pre-training.

    The square ``input_size x input_size`` image is split into a
    ``rand_size x rand_size`` grid of masking units of ``mask_patch_size``
    pixels each. Every call masks ``mask_count`` randomly chosen units. The
    unit grid is then upsampled so it has one entry per ``model_patch_size``
    pixels.

    Args:
        input_size: side length of the square input image, in pixels.
        mask_patch_size: side length of one masking unit, in pixels; must
            divide ``input_size``.
        model_patch_size: side length of the model's patch embedding, in
            pixels; must divide ``mask_patch_size``.
        mask_ratio: fraction of masking units to mask, in ``[0, 1]``. The
            resulting count is rounded up.

    Raises:
        ValueError: if a size is not positive, a divisibility constraint is
            violated, or ``mask_ratio`` lies outside ``[0, 1]``.
    """

    def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6):
        self.input_size = input_size
        self.mask_patch_size = mask_patch_size
        self.model_patch_size = model_patch_size
        self.mask_ratio = mask_ratio

        # Explicit checks rather than ``assert``, so validation survives ``python -O``.
        if min(input_size, mask_patch_size, model_patch_size) <= 0:
            raise ValueError(
                f"sizes must be positive, got input_size={input_size}, "
                f"mask_patch_size={mask_patch_size}, model_patch_size={model_patch_size}"
            )
        if input_size % mask_patch_size != 0:
            raise ValueError(
                f"input_size ({input_size}) must be divisible by mask_patch_size ({mask_patch_size})"
            )
        if mask_patch_size % model_patch_size != 0:
            raise ValueError(
                f"mask_patch_size ({mask_patch_size}) must be divisible by "
                f"model_patch_size ({model_patch_size})"
            )
        if not 0.0 <= mask_ratio <= 1.0:
            raise ValueError(f"mask_ratio must be in [0, 1], got {mask_ratio}")

        self.rand_size = input_size // mask_patch_size  # masking units per side
        self.scale = mask_patch_size // model_patch_size  # model patches per masking-unit side

        self.token_count = self.rand_size ** 2
        # Round before ceil so float noise (e.g. 100 * 0.07 == 7.000000000000001)
        # does not mask one unit too many.
        self.mask_count = int(np.ceil(round(self.token_count * mask_ratio, 9)))

    def __call__(self):
        """Draw a fresh random mask using numpy's global RNG.

        Returns:
            An integer ``np.ndarray`` of shape
            ``(rand_size * scale, rand_size * scale)``, i.e.
            ``input_size // model_patch_size`` per side. 1 marks a masked
            model patch and 0 a visible one. Exactly ``mask_count`` masking
            units, each of ``scale x scale`` entries, are set to 1.
        """
        mask_idx = np.random.permutation(self.token_count)[:self.mask_count]
        mask = np.zeros(self.token_count, dtype=int)
        mask[mask_idx] = 1

        mask = mask.reshape((self.rand_size, self.rand_size))
        # Upsample the masking-unit grid to model-patch resolution.
        mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)

        return mask
class SimMIMTransform:
    """Image-and-mask transform used for SimMIM pre-training.

    Calling an instance on an image returns ``(img, mask)``. ``img`` is the
    output of ``transform_img``: RGB conversion, random resized crop,
    horizontal flip, tensor conversion, then ImageNet normalization.
    ``mask`` is a freshly drawn mask from ``MaskGenerator``.
    """

    def __init__(self, config):
        """Build the image pipeline and mask generator from ``config``.

        Reads ``config.DATA.IMG_SIZE``, ``config.DATA.MASK_PATCH_SIZE``,
        ``config.DATA.MASK_RATIO``, ``config.MODEL.TYPE`` and
        ``config.MODEL.SWIN.PATCH_SIZE``.

        Raises:
            NotImplementedError: if ``config.MODEL.TYPE`` is neither
                ``'swin'`` nor ``'swinv2'``.
        """
        self.transform_img = T.Compose([
            # A staticmethod instead of an inline lambda keeps the transform
            # picklable, which DataLoader workers require under 'spawn'.
            T.Lambda(self._ensure_rgb),
            T.RandomResizedCrop(config.DATA.IMG_SIZE, scale=(0.67, 1.), ratio=(3. / 4., 4. / 3.)),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN), std=torch.tensor(IMAGENET_DEFAULT_STD)),
        ])

        if config.MODEL.TYPE in ['swin', 'swinv2']:
            model_patch_size = config.MODEL.SWIN.PATCH_SIZE
        else:
            raise NotImplementedError(
                f"SimMIMTransform does not support MODEL.TYPE={config.MODEL.TYPE!r}; "
                "expected 'swin' or 'swinv2'"
            )

        self.mask_generator = MaskGenerator(
            input_size=config.DATA.IMG_SIZE,
            mask_patch_size=config.DATA.MASK_PATCH_SIZE,
            model_patch_size=model_patch_size,
            mask_ratio=config.DATA.MASK_RATIO,
        )

    @staticmethod
    def _ensure_rgb(img):
        """Return ``img`` converted to RGB mode, or ``img`` itself if already RGB."""
        return img.convert('RGB') if img.mode != 'RGB' else img

    def __call__(self, img):
        """Return ``(transformed_img, mask)`` for one input image."""
        img = self.transform_img(img)
        mask = self.mask_generator()

        return img, mask
def collate_fn(batch):
    """Collate a list of ``(sample, target)`` pairs into a batch.

    If the first sample is a tuple (for example ``(img, mask)`` from
    ``SimMIMTransform``), each tuple field is collated separately with
    ``default_collate``. A field that is ``None`` in the first sample is
    emitted as ``None``. The collated targets are appended last, so the
    result is ``[field_0, ..., field_k, targets]``.

    Otherwise the whole batch is passed to ``default_collate`` unchanged.
    """
    first_sample = batch[0][0]
    if isinstance(first_sample, tuple):
        samples = [pair[0] for pair in batch]
        collated = [
            None if field is None else default_collate([sample[idx] for sample in samples])
            for idx, field in enumerate(first_sample)
        ]
        collated.append(default_collate([pair[1] for pair in batch]))
        return collated
    return default_collate(batch)
def build_loader_simmim(config):
    """Build the SimMIM pre-training DataLoader.

    Wraps an ``ImageFolder`` rooted at ``config.DATA.DATA_PATH`` with
    ``SimMIMTransform``. Batches of ``config.DATA.BATCH_SIZE`` are
    collated by ``collate_fn``; incomplete final batches are dropped and
    memory is pinned.

    Sharding uses a shuffling ``DistributedSampler``. When a process group
    is initialized, its world size and rank are used. Otherwise the sampler
    falls back to a single replica (rank 0), which allows single-process
    runs and debugging. Either way, callers can still use
    ``loader.sampler.set_epoch``.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    transform = SimMIMTransform(config)
    dataset = ImageFolder(config.DATA.DATA_PATH, transform)

    # dist.get_world_size()/get_rank() raise without an initialized process group.
    if dist.is_available() and dist.is_initialized():
        num_replicas, rank = dist.get_world_size(), dist.get_rank()
    else:
        num_replicas, rank = 1, 0

    sampler = DistributedSampler(dataset, num_replicas=num_replicas, rank=rank, shuffle=True)
    dataloader = DataLoader(
        dataset,
        batch_size=config.DATA.BATCH_SIZE,
        sampler=sampler,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=True,
        drop_last=True,
        collate_fn=collate_fn,
    )

    return dataloader
|