# --------------------------------------------------------
# SimMIM
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Zhenda Xie
# --------------------------------------------------------

from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_

from .swin_transformer import SwinTransformer
from .swin_transformer_v2 import SwinTransformerV2


def norm_targets(targets, patch_size):
    """Normalize ``targets`` with the mean / std of a sliding local window.

    For every spatial position a ``patch_size`` x ``patch_size`` window
    (stride 1, zero padding of ``patch_size // 2``) is used to estimate a
    local mean and an unbiased (Bessel-corrected) variance; padded positions
    are excluded from both statistics.

    Args:
        targets: tensor accepted by ``F.avg_pool2d``
            (assumed to be an image batch of shape (B, C, H, W) -- TODO confirm).
        patch_size: odd window size.

    Returns:
        Tensor with the same shape as ``targets``:
        ``(targets - local_mean) / sqrt(local_var + 1e-6)``.
    """
    assert patch_size % 2 == 1, "patch_size must be odd"

    pad = patch_size // 2
    targets_square = targets ** 2.

    # count_include_pad=False -> averages are taken over real pixels only.
    targets_mean = F.avg_pool2d(
        targets, kernel_size=patch_size, stride=1, padding=pad, count_include_pad=False)
    targets_square_mean = F.avg_pool2d(
        targets_square, kernel_size=patch_size, stride=1, padding=pad, count_include_pad=False)

    # Averaging ones *with* the zero padding and rescaling by the window area
    # yields the number of real pixels that fall inside each window.
    targets_count = F.avg_pool2d(
        torch.ones_like(targets), kernel_size=patch_size, stride=1, padding=pad,
        count_include_pad=True) * (patch_size ** 2)

    # Bessel correction n / (n - 1). A window holding a single pixel has
    # n - 1 == 0, which used to produce 0 * inf = NaN; clamp the denominator
    # so such windows get zero variance (and a zero normalized target).
    dof = torch.clamp(targets_count - 1, min=1.)
    targets_var = (targets_square_mean - targets_mean ** 2.) * (targets_count / dof)
    # Guard against small negative values caused by floating-point error.
    targets_var = torch.clamp(targets_var, min=0.)

    return (targets - targets_mean) / (targets_var + 1.e-6) ** 0.5


def _encode_masked(encoder, x, mask):
    """Shared SimMIM forward pass for the Swin / SwinV2 encoders.

    Embeds ``x`` into patch tokens, replaces the tokens selected by ``mask``
    with ``encoder.mask_token``, runs all encoder layers and the final norm,
    and reshapes the token sequence into a square feature map.

    Args:
        encoder: module exposing ``patch_embed``, ``mask_token``, ``ape``,
            ``absolute_pos_embed`` (read only if ``ape`` is truthy),
            ``pos_drop``, ``layers`` and ``norm``.
        x: input batch fed to ``encoder.patch_embed``.
        mask: per-patch mask; after ``flatten(1)`` it must line up with the
            (B, L) token grid. A value of 1 swaps the token for the mask token.

    Returns:
        Tensor of shape (B, C, H, W) with ``H == W == int(L ** 0.5)``; the
        final token count L is therefore assumed to be a perfect square.
    """
    assert mask is not None, "SimMIM encoders require a mask"

    x = encoder.patch_embed(x)

    B, L, _ = x.shape
    mask_tokens = encoder.mask_token.expand(B, L, -1)
    w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)
    # Blend: keep visible tokens (w == 0), substitute masked ones (w == 1).
    x = x * (1. - w) + mask_tokens * w

    if encoder.ape:
        x = x + encoder.absolute_pos_embed
    x = encoder.pos_drop(x)

    for layer in encoder.layers:
        x = layer(x)
    x = encoder.norm(x)

    # (B, L, C) -> (B, C, H, W)
    x = x.transpose(1, 2)
    B, C, L = x.shape
    H = W = int(L ** 0.5)
    return x.reshape(B, C, H, W)


class SwinTransformerForSimMIM(SwinTransformer):
    """Swin Transformer encoder with a learnable mask token for SimMIM.

    Must be built with ``num_classes=0``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        assert self.num_classes == 0

        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        trunc_normal_(self.mask_token, mean=0., std=.02)

    def forward(self, x, mask):
        """Encode ``x`` with masked patches replaced by the mask token.

        Returns a (B, C, H, W) feature map; see ``_encode_masked``.
        """
        return _encode_masked(self, x, mask)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parent's no-weight-decay names plus ``'mask_token'``."""
        return super().no_weight_decay() | {'mask_token'}


class SwinTransformerV2ForSimMIM(SwinTransformerV2):
    """Swin Transformer V2 encoder with a learnable mask token for SimMIM.

    Must be built with ``num_classes=0``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        assert self.num_classes == 0

        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        trunc_normal_(self.mask_token, mean=0., std=.02)

    def forward(self, x, mask):
        """Encode ``x`` with masked patches replaced by the mask token.

        Returns a (B, C, H, W) feature map; see ``_encode_masked``.
        """
        return _encode_masked(self, x, mask)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parent's no-weight-decay names plus ``'mask_token'``."""
        return super().no_weight_decay() | {'mask_token'}


class SimMIM(nn.Module):
    """Masked-image-modeling wrapper: encoder + 1x1-conv / pixel-shuffle decoder.

    Args:
        config: config node read for ``NORM_TARGET.ENABLE`` and
            ``NORM_TARGET.PATCH_SIZE``.
        encoder: module called as ``encoder(x, mask)`` and exposing
            ``num_features``.
        encoder_stride: upscale factor handed to ``nn.PixelShuffle``.
        in_chans: number of image channels to reconstruct.
        patch_size: factor by which the mask is upsampled along each spatial
            axis before it weights the loss.
    """

    def __init__(self, config, encoder, encoder_stride, in_chans, patch_size):
        super().__init__()
        self.config = config
        self.encoder = encoder
        self.encoder_stride = encoder_stride
        self.in_chans = in_chans
        self.patch_size = patch_size

        # PixelShuffle(r) maps C * r**2 channels to C channels, so the conv
        # must emit in_chans * stride**2 channels for the reconstruction to
        # have the same channel count as the input (previously fixed to 3).
        self.decoder = nn.Sequential(
            nn.Conv2d(
                in_channels=self.encoder.num_features,
                out_channels=self.encoder_stride ** 2 * self.in_chans, kernel_size=1),
            nn.PixelShuffle(self.encoder_stride),
        )

    def forward(self, x, mask):
        """Return the masked L1 reconstruction loss (a scalar tensor).

        Args:
            x: input batch; also serves as the reconstruction target.
            mask: per-patch mask (presumably (B, H / patch_size,
                W / patch_size) with 1 = masked -- verify against the data
                loader). Only positions where the upsampled mask is non-zero
                contribute to the loss.
        """
        z = self.encoder(x, mask)
        x_rec = self.decoder(z)

        # Upsample the patch-level mask and add a channel axis so it
        # broadcasts against the per-pixel loss.
        mask = mask.repeat_interleave(self.patch_size, 1).repeat_interleave(
            self.patch_size, 2).unsqueeze(1).contiguous()

        # Optionally regress locally normalized pixels instead of raw ones.
        if self.config.NORM_TARGET.ENABLE:
            x = norm_targets(x, self.config.NORM_TARGET.PATCH_SIZE)

        loss_recon = F.l1_loss(x, x_rec, reduction='none')
        # Mean over masked pixels and channels; 1e-5 avoids 0 / 0 when the
        # mask is empty.
        loss = (loss_recon * mask).sum() / (mask.sum() + 1e-5) / self.in_chans
        return loss

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the encoder's no-weight-decay names prefixed with ``'encoder.'``.

        Returns an empty set when the encoder does not define the hook.
        """
        if hasattr(self.encoder, 'no_weight_decay'):
            return {'encoder.' + i for i in self.encoder.no_weight_decay()}
        return set()

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Return the encoder's no-weight-decay keywords prefixed with ``'encoder.'``.

        Returns an empty set when the encoder does not define the hook.
        """
        if hasattr(self.encoder, 'no_weight_decay_keywords'):
            return {'encoder.' + i for i in self.encoder.no_weight_decay_keywords()}
        return set()


def build_simmim(config):
    """Build a ``SimMIM`` model from ``config``.

    ``config.MODEL.TYPE`` selects the encoder: ``'swin'`` or ``'swinv2'``.

    Raises:
        NotImplementedError: for any other model type.
    """
    model_type = config.MODEL.TYPE
    if model_type == 'swin':
        encoder = SwinTransformerForSimMIM(
            img_size=config.DATA.IMG_SIZE,
            patch_size=config.MODEL.SWIN.PATCH_SIZE,
            in_chans=config.MODEL.SWIN.IN_CHANS,
            num_classes=0,
            embed_dim=config.MODEL.SWIN.EMBED_DIM,
            depths=config.MODEL.SWIN.DEPTHS,
            num_heads=config.MODEL.SWIN.NUM_HEADS,
            window_size=config.MODEL.SWIN.WINDOW_SIZE,
            mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
            qkv_bias=config.MODEL.SWIN.QKV_BIAS,
            qk_scale=config.MODEL.SWIN.QK_SCALE,
            drop_rate=config.MODEL.DROP_RATE,
            drop_path_rate=config.MODEL.DROP_PATH_RATE,
            ape=config.MODEL.SWIN.APE,
            patch_norm=config.MODEL.SWIN.PATCH_NORM,
            use_checkpoint=config.TRAIN.USE_CHECKPOINT)
        # NOTE(review): 32 presumably matches the default setup (patch size 4
        # followed by three 2x patch-merging stages) -- confirm before using
        # a non-default patch size or number of stages.
        encoder_stride = 32
        in_chans = config.MODEL.SWIN.IN_CHANS
        patch_size = config.MODEL.SWIN.PATCH_SIZE
    elif model_type == 'swinv2':
        encoder = SwinTransformerV2ForSimMIM(
            img_size=config.DATA.IMG_SIZE,
            patch_size=config.MODEL.SWINV2.PATCH_SIZE,
            in_chans=config.MODEL.SWINV2.IN_CHANS,
            num_classes=0,
            embed_dim=config.MODEL.SWINV2.EMBED_DIM,
            depths=config.MODEL.SWINV2.DEPTHS,
            num_heads=config.MODEL.SWINV2.NUM_HEADS,
            window_size=config.MODEL.SWINV2.WINDOW_SIZE,
            mlp_ratio=config.MODEL.SWINV2.MLP_RATIO,
            qkv_bias=config.MODEL.SWINV2.QKV_BIAS,
            drop_rate=config.MODEL.DROP_RATE,
            drop_path_rate=config.MODEL.DROP_PATH_RATE,
            ape=config.MODEL.SWINV2.APE,
            patch_norm=config.MODEL.SWINV2.PATCH_NORM,
            use_checkpoint=config.TRAIN.USE_CHECKPOINT)
        # NOTE(review): same assumption as the 'swin' branch -- confirm.
        encoder_stride = 32
        in_chans = config.MODEL.SWINV2.IN_CHANS
        patch_size = config.MODEL.SWINV2.PATCH_SIZE
    else:
        raise NotImplementedError(f"Unknown pre-train model: {model_type}")

    model = SimMIM(config=config.MODEL.SIMMIM, encoder=encoder, encoder_stride=encoder_stride,
                   in_chans=in_chans, patch_size=patch_size)

    return model