simmim.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. # --------------------------------------------------------
  2. # SimMIM
  3. # Copyright (c) 2021 Microsoft
  4. # Licensed under The MIT License [see LICENSE for details]
  5. # Written by Zhenda Xie
  6. # --------------------------------------------------------
  7. from functools import partial
  8. import torch
  9. import torch.nn as nn
  10. import torch.nn.functional as F
  11. from timm.models.layers import trunc_normal_
  12. from .swin_transformer import SwinTransformer
  13. from .swin_transformer_v2 import SwinTransformerV2
  14. def norm_targets(targets, patch_size):
  15. assert patch_size % 2 == 1
  16. targets_ = targets
  17. targets_count = torch.ones_like(targets)
  18. targets_square = targets ** 2.
  19. targets_mean = F.avg_pool2d(targets, kernel_size=patch_size, stride=1, padding=patch_size // 2, count_include_pad=False)
  20. targets_square_mean = F.avg_pool2d(targets_square, kernel_size=patch_size, stride=1, padding=patch_size // 2, count_include_pad=False)
  21. targets_count = F.avg_pool2d(targets_count, kernel_size=patch_size, stride=1, padding=patch_size // 2, count_include_pad=True) * (patch_size ** 2)
  22. targets_var = (targets_square_mean - targets_mean ** 2.) * (targets_count / (targets_count - 1))
  23. targets_var = torch.clamp(targets_var, min=0.)
  24. targets_ = (targets_ - targets_mean) / (targets_var + 1.e-6) ** 0.5
  25. return targets_
  26. class SwinTransformerForSimMIM(SwinTransformer):
  27. def __init__(self, **kwargs):
  28. super().__init__(**kwargs)
  29. assert self.num_classes == 0
  30. self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
  31. trunc_normal_(self.mask_token, mean=0., std=.02)
  32. def forward(self, x, mask):
  33. x = self.patch_embed(x)
  34. assert mask is not None
  35. B, L, _ = x.shape
  36. mask_tokens = self.mask_token.expand(B, L, -1)
  37. w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)
  38. x = x * (1. - w) + mask_tokens * w
  39. if self.ape:
  40. x = x + self.absolute_pos_embed
  41. x = self.pos_drop(x)
  42. for layer in self.layers:
  43. x = layer(x)
  44. x = self.norm(x)
  45. x = x.transpose(1, 2)
  46. B, C, L = x.shape
  47. H = W = int(L ** 0.5)
  48. x = x.reshape(B, C, H, W)
  49. return x
  50. @torch.jit.ignore
  51. def no_weight_decay(self):
  52. return super().no_weight_decay() | {'mask_token'}
  53. class SwinTransformerV2ForSimMIM(SwinTransformerV2):
  54. def __init__(self, **kwargs):
  55. super().__init__(**kwargs)
  56. assert self.num_classes == 0
  57. self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
  58. trunc_normal_(self.mask_token, mean=0., std=.02)
  59. def forward(self, x, mask):
  60. x = self.patch_embed(x)
  61. assert mask is not None
  62. B, L, _ = x.shape
  63. mask_tokens = self.mask_token.expand(B, L, -1)
  64. w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)
  65. x = x * (1. - w) + mask_tokens * w
  66. if self.ape:
  67. x = x + self.absolute_pos_embed
  68. x = self.pos_drop(x)
  69. for layer in self.layers:
  70. x = layer(x)
  71. x = self.norm(x)
  72. x = x.transpose(1, 2)
  73. B, C, L = x.shape
  74. H = W = int(L ** 0.5)
  75. x = x.reshape(B, C, H, W)
  76. return x
  77. @torch.jit.ignore
  78. def no_weight_decay(self):
  79. return super().no_weight_decay() | {'mask_token'}
  80. class SimMIM(nn.Module):
  81. def __init__(self, config, encoder, encoder_stride, in_chans, patch_size):
  82. super().__init__()
  83. self.config = config
  84. self.encoder = encoder
  85. self.encoder_stride = encoder_stride
  86. self.decoder = nn.Sequential(
  87. nn.Conv2d(
  88. in_channels=self.encoder.num_features,
  89. out_channels=self.encoder_stride ** 2 * 3, kernel_size=1),
  90. nn.PixelShuffle(self.encoder_stride),
  91. )
  92. self.in_chans = in_chans
  93. self.patch_size = patch_size
  94. def forward(self, x, mask):
  95. z = self.encoder(x, mask)
  96. x_rec = self.decoder(z)
  97. mask = mask.repeat_interleave(self.patch_size, 1).repeat_interleave(self.patch_size, 2).unsqueeze(1).contiguous()
  98. # norm target as prompted
  99. if self.config.NORM_TARGET.ENABLE:
  100. x = norm_targets(x, self.config.NORM_TARGET.PATCH_SIZE)
  101. loss_recon = F.l1_loss(x, x_rec, reduction='none')
  102. loss = (loss_recon * mask).sum() / (mask.sum() + 1e-5) / self.in_chans
  103. return loss
  104. @torch.jit.ignore
  105. def no_weight_decay(self):
  106. if hasattr(self.encoder, 'no_weight_decay'):
  107. return {'encoder.' + i for i in self.encoder.no_weight_decay()}
  108. return {}
  109. @torch.jit.ignore
  110. def no_weight_decay_keywords(self):
  111. if hasattr(self.encoder, 'no_weight_decay_keywords'):
  112. return {'encoder.' + i for i in self.encoder.no_weight_decay_keywords()}
  113. return {}
  114. def build_simmim(config):
  115. model_type = config.MODEL.TYPE
  116. if model_type == 'swin':
  117. encoder = SwinTransformerForSimMIM(
  118. img_size=config.DATA.IMG_SIZE,
  119. patch_size=config.MODEL.SWIN.PATCH_SIZE,
  120. in_chans=config.MODEL.SWIN.IN_CHANS,
  121. num_classes=0,
  122. embed_dim=config.MODEL.SWIN.EMBED_DIM,
  123. depths=config.MODEL.SWIN.DEPTHS,
  124. num_heads=config.MODEL.SWIN.NUM_HEADS,
  125. window_size=config.MODEL.SWIN.WINDOW_SIZE,
  126. mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
  127. qkv_bias=config.MODEL.SWIN.QKV_BIAS,
  128. qk_scale=config.MODEL.SWIN.QK_SCALE,
  129. drop_rate=config.MODEL.DROP_RATE,
  130. drop_path_rate=config.MODEL.DROP_PATH_RATE,
  131. ape=config.MODEL.SWIN.APE,
  132. patch_norm=config.MODEL.SWIN.PATCH_NORM,
  133. use_checkpoint=config.TRAIN.USE_CHECKPOINT)
  134. encoder_stride = 32
  135. in_chans = config.MODEL.SWIN.IN_CHANS
  136. patch_size = config.MODEL.SWIN.PATCH_SIZE
  137. elif model_type == 'swinv2':
  138. encoder = SwinTransformerV2ForSimMIM(
  139. img_size=config.DATA.IMG_SIZE,
  140. patch_size=config.MODEL.SWINV2.PATCH_SIZE,
  141. in_chans=config.MODEL.SWINV2.IN_CHANS,
  142. num_classes=0,
  143. embed_dim=config.MODEL.SWINV2.EMBED_DIM,
  144. depths=config.MODEL.SWINV2.DEPTHS,
  145. num_heads=config.MODEL.SWINV2.NUM_HEADS,
  146. window_size=config.MODEL.SWINV2.WINDOW_SIZE,
  147. mlp_ratio=config.MODEL.SWINV2.MLP_RATIO,
  148. qkv_bias=config.MODEL.SWINV2.QKV_BIAS,
  149. drop_rate=config.MODEL.DROP_RATE,
  150. drop_path_rate=config.MODEL.DROP_PATH_RATE,
  151. ape=config.MODEL.SWINV2.APE,
  152. patch_norm=config.MODEL.SWINV2.PATCH_NORM,
  153. use_checkpoint=config.TRAIN.USE_CHECKPOINT)
  154. encoder_stride = 32
  155. in_chans = config.MODEL.SWINV2.IN_CHANS
  156. patch_size = config.MODEL.SWINV2.PATCH_SIZE
  157. else:
  158. raise NotImplementedError(f"Unknown pre-train model: {model_type}")
  159. model = SimMIM(config=config.MODEL.SIMMIM, encoder=encoder, encoder_stride=encoder_stride, in_chans=in_chans, patch_size=patch_size)
  160. return model