123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209 |
- # --------------------------------------------------------
- # SimMIM
- # Copyright (c) 2021 Microsoft
- # Licensed under The MIT License [see LICENSE for details]
- # Written by Zhenda Xie
- # --------------------------------------------------------
- from functools import partial
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from timm.models.layers import trunc_normal_
- from .swin_transformer import SwinTransformer
- from .swin_transformer_v2 import SwinTransformerV2
def norm_targets(targets, patch_size):
    """Standardise every pixel with the statistics of its local window.

    For each spatial position the mean and the unbiased variance are taken
    over the ``patch_size`` x ``patch_size`` window centred on it. Windows
    that overlap the border only count the pixels that actually exist
    (padding is excluded from the statistics).

    Args:
        targets: Tensor accepted by ``F.avg_pool2d`` -- presumably a
            ``(B, C, H, W)`` image batch; confirm against the caller.
        patch_size: Side length of the sliding window. Must be odd so the
            window can be centred and the output keeps the input size.

    Returns:
        Tensor of the same shape as ``targets`` holding
        ``(targets - local_mean) / sqrt(local_var + 1e-6)``.

    Raises:
        AssertionError: If ``patch_size`` is even.
    """
    assert patch_size % 2 == 1, f"patch_size must be odd, got {patch_size}"

    pool_kwargs = dict(kernel_size=patch_size, stride=1, padding=patch_size // 2)

    # Local first and second moments; padded zeros are left out of the average.
    targets_mean = F.avg_pool2d(targets, count_include_pad=False, **pool_kwargs)
    targets_square_mean = F.avg_pool2d(targets ** 2., count_include_pad=False, **pool_kwargs)

    # Number of real pixels inside each window: averaging ones *with* padding
    # counted yields valid / patch_size**2, so scale back up.
    targets_count = F.avg_pool2d(
        torch.ones_like(targets), count_include_pad=True, **pool_kwargs
    ) * (patch_size ** 2)

    # Bessel correction n / (n - 1). A window with a single sample would give
    # n - 1 == 0 and turn the (zero) variance into NaN via 0 * inf, so the
    # denominator is kept at 1 or above; windows with n >= 2 are unaffected.
    bessel = targets_count / torch.clamp(targets_count - 1, min=1.)
    targets_var = (targets_square_mean - targets_mean ** 2.) * bessel
    # Rounding can push a tiny variance slightly below zero.
    targets_var = torch.clamp(targets_var, min=0.)

    return (targets - targets_mean) / (targets_var + 1.e-6) ** 0.5
class SwinTransformerForSimMIM(SwinTransformer):
    """Swin Transformer encoder whose masked patch embeddings are replaced by a learnable token.

    The parent class is built from ``**kwargs`` unchanged; this subclass only
    adds ``mask_token`` and a ``forward`` that takes a mask and returns a
    spatial feature map.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # The wrapper insists on a backbone configured with zero classes
        # (presumably so no classification head exists -- confirm in parent).
        assert self.num_classes == 0

        # One embedding, shared by every masked position.
        token_init = torch.zeros(1, 1, self.embed_dim)
        self.mask_token = nn.Parameter(token_init)
        trunc_normal_(self.mask_token, mean=0., std=.02)

    def forward(self, x, mask):
        """Encode ``x`` with masked positions swapped for the mask token.

        Args:
            x: Input passed straight to ``self.patch_embed``.
            mask: Required. Flattened from dim 1 onward and used as a
                per-token blend weight (assumes 1 = masked, 0 = visible and
                that it flattens to the token count -- TODO confirm).

        Returns:
            Features reshaped to ``(B, C, H, W)`` with ``H == W == int(sqrt(L))``.
        """
        tokens = self.patch_embed(x)

        assert mask is not None
        batch, num_tokens, _ = tokens.shape

        learned = self.mask_token.expand(batch, num_tokens, -1)
        weight = mask.flatten(1).unsqueeze(-1).type_as(learned)
        # Linear blend: weight 1 selects the mask token, weight 0 the patch.
        tokens = tokens * (1. - weight) + learned * weight

        if self.ape:
            tokens = tokens + self.absolute_pos_embed
        tokens = self.pos_drop(tokens)

        for stage in self.layers:
            tokens = stage(tokens)
        tokens = self.norm(tokens)

        # (B, L, C) -> (B, C, L) -> (B, C, H, W), taking the token grid as square.
        feats = tokens.transpose(1, 2)
        batch, channels, num_tokens = feats.shape
        side = int(num_tokens ** 0.5)
        return feats.reshape(batch, channels, side, side)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parent's no-weight-decay names plus ``'mask_token'``."""
        inherited = super().no_weight_decay()
        return inherited | {'mask_token'}
class SwinTransformerV2ForSimMIM(SwinTransformerV2):
    """Swin Transformer V2 encoder that substitutes a learnable token at masked patches.

    Construction is delegated to the parent via ``**kwargs``; the subclass
    contributes the ``mask_token`` parameter and a mask-aware ``forward``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Only a backbone configured with zero classes is accepted.
        assert self.num_classes == 0

        # Single learnable vector broadcast to all masked positions.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        trunc_normal_(self.mask_token, mean=0., std=.02)

    def forward(self, x, mask):
        """Run the encoder on ``x`` after blending in the mask token.

        Args:
            x: Input handed to ``self.patch_embed``.
            mask: Must not be ``None``. Flattened from dim 1 and applied as a
                per-token weight (presumably binary with 1 = masked; verify
                against the data pipeline).

        Returns:
            Tensor of shape ``(B, C, H, W)`` where ``H == W == int(sqrt(L))``.
        """
        embedded = self.patch_embed(x)

        assert mask is not None
        n_batch, n_tokens, _ = embedded.shape

        token_bank = self.mask_token.expand(n_batch, n_tokens, -1)
        w = mask.flatten(1).unsqueeze(-1).type_as(token_bank)
        keep = 1. - w
        hidden = embedded * keep + token_bank * w

        if self.ape:
            hidden = hidden + self.absolute_pos_embed
        hidden = self.pos_drop(hidden)

        for block in self.layers:
            hidden = block(hidden)
        hidden = self.norm(hidden)

        # Move channels ahead of tokens, then fold tokens into a square grid.
        hidden = hidden.transpose(1, 2)
        n_batch, n_channels, n_tokens = hidden.shape
        height = width = int(n_tokens ** 0.5)
        return hidden.reshape(n_batch, n_channels, height, width)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Extend the parent's no-weight-decay set with ``'mask_token'``."""
        return super().no_weight_decay() | {'mask_token'}
class SimMIM(nn.Module):
    """Masked image modelling head: encoder, 1x1-conv pixel decoder and masked L1 loss.

    Args:
        config: Object exposing ``NORM_TARGET.ENABLE`` and
            ``NORM_TARGET.PATCH_SIZE``.
        encoder: Callable as ``encoder(x, mask)`` and exposing
            ``num_features``; its output is fed to the decoder.
        encoder_stride: Upscaling factor of the decoder's ``PixelShuffle``.
        in_chans: Number of image channels. Sizes the reconstruction and
            normalises the loss.
        patch_size: Factor by which the mask is repeated along each spatial
            axis to reach pixel resolution.
    """

    def __init__(self, config, encoder, encoder_stride, in_chans, patch_size):
        super().__init__()
        self.config = config
        self.encoder = encoder
        self.encoder_stride = encoder_stride
        self.in_chans = in_chans
        self.patch_size = patch_size

        # The conv emits encoder_stride**2 values per image channel and
        # PixelShuffle rearranges them into pixels, so the reconstruction has
        # `in_chans` channels like the input it is compared with. With
        # in_chans == 3 the layer shape is the same as the former hard-coded
        # `* 3`, so 3-channel checkpoints are unaffected.
        self.decoder = nn.Sequential(
            nn.Conv2d(
                in_channels=self.encoder.num_features,
                out_channels=self.encoder_stride ** 2 * self.in_chans,
                kernel_size=1),
            nn.PixelShuffle(self.encoder_stride),
        )

    def forward(self, x, mask):
        """Return the mean L1 reconstruction error over masked pixels.

        Args:
            x: Input images, also used as the reconstruction target.
            mask: Patch-level mask; indexed on dims 1 and 2, so presumably
                ``(B, H/patch_size, W/patch_size)`` with 1 = masked -- confirm
                against the mask generator.

        Returns:
            Scalar loss:
            ``sum(|x - x_rec| * mask) / (sum(mask) + 1e-5) / in_chans``.
        """
        z = self.encoder(x, mask)
        x_rec = self.decoder(z)

        # Expand the patch-level mask to pixel resolution and add a channel
        # axis so it broadcasts over the image channels.
        mask = (
            mask.repeat_interleave(self.patch_size, 1)
            .repeat_interleave(self.patch_size, 2)
            .unsqueeze(1)
            .contiguous()
        )

        # Optionally regress locally normalised pixels instead of raw ones.
        if self.config.NORM_TARGET.ENABLE:
            x = norm_targets(x, self.config.NORM_TARGET.PATCH_SIZE)

        loss_recon = F.l1_loss(x, x_rec, reduction='none')
        # 1e-5 guards against an all-zero mask.
        loss = (loss_recon * mask).sum() / (mask.sum() + 1e-5) / self.in_chans
        return loss

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the encoder's no-weight-decay names prefixed with ``'encoder.'``.

        Returns an empty set when the encoder has no ``no_weight_decay``.
        """
        if hasattr(self.encoder, 'no_weight_decay'):
            return {'encoder.' + i for i in self.encoder.no_weight_decay()}
        # Empty set (not `{}`, which is a dict) to match the branch above.
        return set()

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Return the encoder's no-weight-decay keywords prefixed with ``'encoder.'``.

        Returns an empty set when the encoder has no
        ``no_weight_decay_keywords``.
        """
        if hasattr(self.encoder, 'no_weight_decay_keywords'):
            return {'encoder.' + i for i in self.encoder.no_weight_decay_keywords()}
        return set()
def build_simmim(config):
    """Build a ``SimMIM`` model around the encoder named by ``config.MODEL.TYPE``.

    Args:
        config: Configuration object. Reads ``MODEL.TYPE``, the matching
            ``MODEL.SWIN`` or ``MODEL.SWINV2`` section, ``DATA.IMG_SIZE``,
            ``MODEL.DROP_RATE``, ``MODEL.DROP_PATH_RATE``,
            ``TRAIN.USE_CHECKPOINT`` and ``MODEL.SIMMIM``.

    Returns:
        A ``SimMIM`` instance wrapping the constructed encoder.

    Raises:
        NotImplementedError: If ``MODEL.TYPE`` is neither ``'swin'`` nor
            ``'swinv2'``.
    """
    model_type = config.MODEL.TYPE

    # Reject unsupported types before touching any other config section.
    if model_type not in ('swin', 'swinv2'):
        raise NotImplementedError(f"Unknown pre-train model: {model_type}")

    if model_type == 'swin':
        section = config.MODEL.SWIN
        encoder_cls = SwinTransformerForSimMIM
        # Only the V1 encoder is given a qk_scale argument.
        variant_kwargs = {'qk_scale': section.QK_SCALE}
    else:
        section = config.MODEL.SWINV2
        encoder_cls = SwinTransformerV2ForSimMIM
        variant_kwargs = {}

    encoder = encoder_cls(
        img_size=config.DATA.IMG_SIZE,
        patch_size=section.PATCH_SIZE,
        in_chans=section.IN_CHANS,
        num_classes=0,
        embed_dim=section.EMBED_DIM,
        depths=section.DEPTHS,
        num_heads=section.NUM_HEADS,
        window_size=section.WINDOW_SIZE,
        mlp_ratio=section.MLP_RATIO,
        qkv_bias=section.QKV_BIAS,
        drop_rate=config.MODEL.DROP_RATE,
        drop_path_rate=config.MODEL.DROP_PATH_RATE,
        ape=section.APE,
        patch_norm=section.PATCH_NORM,
        use_checkpoint=config.TRAIN.USE_CHECKPOINT,
        **variant_kwargs)

    # Both supported encoders are paired with the same fixed decoder stride.
    return SimMIM(
        config=config.MODEL.SIMMIM,
        encoder=encoder,
        encoder_stride=32,
        in_chans=section.IN_CHANS,
        patch_size=section.PATCH_SIZE)
|