swin_transformer_v2.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633
  1. # --------------------------------------------------------
  2. # Swin Transformer V2
  3. # Copyright (c) 2022 Microsoft
  4. # Licensed under The MIT License [see LICENSE for details]
  5. # Written by Ze Liu
  6. # --------------------------------------------------------
  7. import torch
  8. import torch.nn as nn
  9. import torch.nn.functional as F
  10. import torch.utils.checkpoint as checkpoint
  11. from timm.models.layers import DropPath, to_2tuple, trunc_normal_
  12. import numpy as np
  13. class Mlp(nn.Module):
  14. def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
  15. super().__init__()
  16. out_features = out_features or in_features
  17. hidden_features = hidden_features or in_features
  18. self.fc1 = nn.Linear(in_features, hidden_features)
  19. self.act = act_layer()
  20. self.fc2 = nn.Linear(hidden_features, out_features)
  21. self.drop = nn.Dropout(drop)
  22. def forward(self, x):
  23. x = self.fc1(x)
  24. x = self.act(x)
  25. x = self.drop(x)
  26. x = self.fc2(x)
  27. x = self.drop(x)
  28. return x
  29. def window_partition(x, window_size):
  30. """
  31. Args:
  32. x: (B, H, W, C)
  33. window_size (int): window size
  34. Returns:
  35. windows: (num_windows*B, window_size, window_size, C)
  36. """
  37. B, H, W, C = x.shape
  38. x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
  39. windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
  40. return windows
  41. def window_reverse(windows, window_size, H, W):
  42. """
  43. Args:
  44. windows: (num_windows*B, window_size, window_size, C)
  45. window_size (int): Window size
  46. H (int): Height of image
  47. W (int): Width of image
  48. Returns:
  49. x: (B, H, W, C)
  50. """
  51. B = int(windows.shape[0] / (H * W / window_size / window_size))
  52. x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
  53. x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
  54. return x
  55. class WindowAttention(nn.Module):
  56. r""" Window based multi-head self attention (W-MSA) module with relative position bias.
  57. It supports both of shifted and non-shifted window.
  58. Args:
  59. dim (int): Number of input channels.
  60. window_size (tuple[int]): The height and width of the window.
  61. num_heads (int): Number of attention heads.
  62. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  63. attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
  64. proj_drop (float, optional): Dropout ratio of output. Default: 0.0
  65. pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
  66. """
  67. def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
  68. pretrained_window_size=[0, 0]):
  69. super().__init__()
  70. self.dim = dim
  71. self.window_size = window_size # Wh, Ww
  72. self.pretrained_window_size = pretrained_window_size
  73. self.num_heads = num_heads
  74. self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
  75. # mlp to generate continuous relative position bias
  76. self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
  77. nn.ReLU(inplace=True),
  78. nn.Linear(512, num_heads, bias=False))
  79. # get relative_coords_table
  80. relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
  81. relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
  82. relative_coords_table = torch.stack(
  83. torch.meshgrid([relative_coords_h,
  84. relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
  85. if pretrained_window_size[0] > 0:
  86. relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
  87. relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
  88. else:
  89. relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
  90. relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
  91. relative_coords_table *= 8 # normalize to -8, 8
  92. relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
  93. torch.abs(relative_coords_table) + 1.0) / np.log2(8)
  94. self.register_buffer("relative_coords_table", relative_coords_table)
  95. # get pair-wise relative position index for each token inside the window
  96. coords_h = torch.arange(self.window_size[0])
  97. coords_w = torch.arange(self.window_size[1])
  98. coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
  99. coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
  100. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
  101. relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
  102. relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
  103. relative_coords[:, :, 1] += self.window_size[1] - 1
  104. relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
  105. relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
  106. self.register_buffer("relative_position_index", relative_position_index)
  107. self.qkv = nn.Linear(dim, dim * 3, bias=False)
  108. if qkv_bias:
  109. self.q_bias = nn.Parameter(torch.zeros(dim))
  110. self.v_bias = nn.Parameter(torch.zeros(dim))
  111. else:
  112. self.q_bias = None
  113. self.v_bias = None
  114. self.attn_drop = nn.Dropout(attn_drop)
  115. self.proj = nn.Linear(dim, dim)
  116. self.proj_drop = nn.Dropout(proj_drop)
  117. self.softmax = nn.Softmax(dim=-1)
  118. def forward(self, x, mask=None):
  119. """
  120. Args:
  121. x: input features with shape of (num_windows*B, N, C)
  122. mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
  123. """
  124. B_, N, C = x.shape
  125. qkv_bias = None
  126. if self.q_bias is not None:
  127. qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
  128. qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
  129. qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
  130. q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
  131. # cosine attention
  132. attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
  133. logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()
  134. attn = attn * logit_scale
  135. relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
  136. relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
  137. self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
  138. relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
  139. relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
  140. attn = attn + relative_position_bias.unsqueeze(0)
  141. if mask is not None:
  142. nW = mask.shape[0]
  143. attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
  144. attn = attn.view(-1, self.num_heads, N, N)
  145. attn = self.softmax(attn)
  146. else:
  147. attn = self.softmax(attn)
  148. attn = self.attn_drop(attn)
  149. x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
  150. x = self.proj(x)
  151. x = self.proj_drop(x)
  152. return x
  153. def extra_repr(self) -> str:
  154. return f'dim={self.dim}, window_size={self.window_size}, ' \
  155. f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
  156. def flops(self, N):
  157. # calculate flops for 1 window with token length of N
  158. flops = 0
  159. # qkv = self.qkv(x)
  160. flops += N * self.dim * 3 * self.dim
  161. # attn = (q @ k.transpose(-2, -1))
  162. flops += self.num_heads * N * (self.dim // self.num_heads) * N
  163. # x = (attn @ v)
  164. flops += self.num_heads * N * N * (self.dim // self.num_heads)
  165. # x = self.proj(x)
  166. flops += N * self.dim * self.dim
  167. return flops
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Window attention (on a cyclically shifted map when ``shift_size > 0``) followed by an MLP,
    each in a residual connection. Normalization is applied to the branch output
    (``norm1`` after attention, ``norm2`` after the MLP) before it is added back to the input.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution (H, W) of the token grid.
        num_heads (int): Number of attention heads.
        window_size (int): Window size. Replaced by min(input_resolution) when that is not larger.
        shift_size (int): Shift size for SW-MSA. Forced to 0 when the window covers the input.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        pretrained_window_size (int): Window size in pre-training.
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        # NOTE: module creation order below fixes parameter registration and RNG consumption.
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
            pretrained_window_size=to_2tuple(pretrained_window_size))

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            # Label the 3x3 grid of regions induced by the cyclic shift with distinct ids, so that
            # inside each shifted window tokens only attend to tokens with the same id.
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            # Pairwise id difference: 0 for same-region pairs, non-zero otherwise. Non-zero pairs get
            # -100, which the attention adds to its logits before softmax (near-zero weight).
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        # Buffer (not a parameter): moved along with the module by .to()/.cuda().
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        """Apply (shifted) window attention and the MLP, each with a post-norm residual.

        Args:
            x: (B, H*W, C) tokens, with H, W = self.input_resolution.

        Returns:
            Tensor of shape (B, H*W, C).
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)
        # Post-norm residual: normalize the attention output, then add it to the block input.
        x = shortcut + self.drop_path(self.norm1(x))

        # FFN (also post-norm)
        x = x + self.drop_path(self.norm2(self.mlp(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        # Approximate FLOP count; nW uses true division, so the total may be a float.
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops
  276. class PatchMerging(nn.Module):
  277. r""" Patch Merging Layer.
  278. Args:
  279. input_resolution (tuple[int]): Resolution of input feature.
  280. dim (int): Number of input channels.
  281. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
  282. """
  283. def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
  284. super().__init__()
  285. self.input_resolution = input_resolution
  286. self.dim = dim
  287. self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
  288. self.norm = norm_layer(2 * dim)
  289. def forward(self, x):
  290. """
  291. x: B, H*W, C
  292. """
  293. H, W = self.input_resolution
  294. B, L, C = x.shape
  295. assert L == H * W, "input feature has wrong size"
  296. assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
  297. x = x.view(B, H, W, C)
  298. x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
  299. x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
  300. x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
  301. x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
  302. x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
  303. x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
  304. x = self.reduction(x)
  305. x = self.norm(x)
  306. return x
  307. def extra_repr(self) -> str:
  308. return f"input_resolution={self.input_resolution}, dim={self.dim}"
  309. def flops(self):
  310. H, W = self.input_resolution
  311. flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
  312. flops += H * W * self.dim // 2
  313. return flops
  314. class BasicLayer(nn.Module):
  315. """ A basic Swin Transformer layer for one stage.
  316. Args:
  317. dim (int): Number of input channels.
  318. input_resolution (tuple[int]): Input resolution.
  319. depth (int): Number of blocks.
  320. num_heads (int): Number of attention heads.
  321. window_size (int): Local window size.
  322. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  323. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  324. drop (float, optional): Dropout rate. Default: 0.0
  325. attn_drop (float, optional): Attention dropout rate. Default: 0.0
  326. drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
  327. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
  328. downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
  329. use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
  330. pretrained_window_size (int): Local window size in pre-training.
  331. """
  332. def __init__(self, dim, input_resolution, depth, num_heads, window_size,
  333. mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
  334. drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
  335. pretrained_window_size=0):
  336. super().__init__()
  337. self.dim = dim
  338. self.input_resolution = input_resolution
  339. self.depth = depth
  340. self.use_checkpoint = use_checkpoint
  341. # build blocks
  342. self.blocks = nn.ModuleList([
  343. SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
  344. num_heads=num_heads, window_size=window_size,
  345. shift_size=0 if (i % 2 == 0) else window_size // 2,
  346. mlp_ratio=mlp_ratio,
  347. qkv_bias=qkv_bias,
  348. drop=drop, attn_drop=attn_drop,
  349. drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
  350. norm_layer=norm_layer,
  351. pretrained_window_size=pretrained_window_size)
  352. for i in range(depth)])
  353. # patch merging layer
  354. if downsample is not None:
  355. self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
  356. else:
  357. self.downsample = None
  358. def forward(self, x):
  359. for blk in self.blocks:
  360. if self.use_checkpoint:
  361. x = checkpoint.checkpoint(blk, x)
  362. else:
  363. x = blk(x)
  364. if self.downsample is not None:
  365. x = self.downsample(x)
  366. return x
  367. def extra_repr(self) -> str:
  368. return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
  369. def flops(self):
  370. flops = 0
  371. for blk in self.blocks:
  372. flops += blk.flops()
  373. if self.downsample is not None:
  374. flops += self.downsample.flops()
  375. return flops
  376. def _init_respostnorm(self):
  377. for blk in self.blocks:
  378. nn.init.constant_(blk.norm1.bias, 0)
  379. nn.init.constant_(blk.norm1.weight, 0)
  380. nn.init.constant_(blk.norm2.bias, 0)
  381. nn.init.constant_(blk.norm2.weight, 0)
  382. class PatchEmbed(nn.Module):
  383. r""" Image to Patch Embedding
  384. Args:
  385. img_size (int): Image size. Default: 224.
  386. patch_size (int): Patch token size. Default: 4.
  387. in_chans (int): Number of input image channels. Default: 3.
  388. embed_dim (int): Number of linear projection output channels. Default: 96.
  389. norm_layer (nn.Module, optional): Normalization layer. Default: None
  390. """
  391. def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
  392. super().__init__()
  393. img_size = to_2tuple(img_size)
  394. patch_size = to_2tuple(patch_size)
  395. patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
  396. self.img_size = img_size
  397. self.patch_size = patch_size
  398. self.patches_resolution = patches_resolution
  399. self.num_patches = patches_resolution[0] * patches_resolution[1]
  400. self.in_chans = in_chans
  401. self.embed_dim = embed_dim
  402. self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
  403. if norm_layer is not None:
  404. self.norm = norm_layer(embed_dim)
  405. else:
  406. self.norm = None
  407. def forward(self, x):
  408. B, C, H, W = x.shape
  409. # FIXME look at relaxing size constraints
  410. assert H == self.img_size[0] and W == self.img_size[1], \
  411. f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
  412. x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
  413. if self.norm is not None:
  414. x = self.norm(x)
  415. return x
  416. def flops(self):
  417. Ho, Wo = self.patches_resolution
  418. flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
  419. if self.norm is not None:
  420. flops += Ho * Wo * self.embed_dim
  421. return flops
  422. class SwinTransformerV2(nn.Module):
  423. r""" Swin Transformer
  424. A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
  425. https://arxiv.org/pdf/2103.14030
  426. Args:
  427. img_size (int | tuple(int)): Input image size. Default 224
  428. patch_size (int | tuple(int)): Patch size. Default: 4
  429. in_chans (int): Number of input image channels. Default: 3
  430. num_classes (int): Number of classes for classification head. Default: 1000
  431. embed_dim (int): Patch embedding dimension. Default: 96
  432. depths (tuple(int)): Depth of each Swin Transformer layer.
  433. num_heads (tuple(int)): Number of attention heads in different layers.
  434. window_size (int): Window size. Default: 7
  435. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
  436. qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
  437. drop_rate (float): Dropout rate. Default: 0
  438. attn_drop_rate (float): Attention dropout rate. Default: 0
  439. drop_path_rate (float): Stochastic depth rate. Default: 0.1
  440. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
  441. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
  442. patch_norm (bool): If True, add normalization after patch embedding. Default: True
  443. use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
  444. pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer.
  445. """
  446. def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
  447. embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
  448. window_size=7, mlp_ratio=4., qkv_bias=True,
  449. drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
  450. norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
  451. use_checkpoint=False, pretrained_window_sizes=[0, 0, 0, 0], **kwargs):
  452. super().__init__()
  453. self.num_classes = num_classes
  454. self.num_layers = len(depths)
  455. self.embed_dim = embed_dim
  456. self.ape = ape
  457. self.patch_norm = patch_norm
  458. self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
  459. self.mlp_ratio = mlp_ratio
  460. # split image into non-overlapping patches
  461. self.patch_embed = PatchEmbed(
  462. img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
  463. norm_layer=norm_layer if self.patch_norm else None)
  464. num_patches = self.patch_embed.num_patches
  465. patches_resolution = self.patch_embed.patches_resolution
  466. self.patches_resolution = patches_resolution
  467. # absolute position embedding
  468. if self.ape:
  469. self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
  470. trunc_normal_(self.absolute_pos_embed, std=.02)
  471. self.pos_drop = nn.Dropout(p=drop_rate)
  472. # stochastic depth
  473. dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
  474. # build layers
  475. self.layers = nn.ModuleList()
  476. for i_layer in range(self.num_layers):
  477. layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
  478. input_resolution=(patches_resolution[0] // (2 ** i_layer),
  479. patches_resolution[1] // (2 ** i_layer)),
  480. depth=depths[i_layer],
  481. num_heads=num_heads[i_layer],
  482. window_size=window_size,
  483. mlp_ratio=self.mlp_ratio,
  484. qkv_bias=qkv_bias,
  485. drop=drop_rate, attn_drop=attn_drop_rate,
  486. drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
  487. norm_layer=norm_layer,
  488. downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
  489. use_checkpoint=use_checkpoint,
  490. pretrained_window_size=pretrained_window_sizes[i_layer])
  491. self.layers.append(layer)
  492. self.norm = norm_layer(self.num_features)
  493. self.avgpool = nn.AdaptiveAvgPool1d(1)
  494. self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
  495. self.apply(self._init_weights)
  496. for bly in self.layers:
  497. bly._init_respostnorm()
  498. def _init_weights(self, m):
  499. if isinstance(m, nn.Linear):
  500. trunc_normal_(m.weight, std=.02)
  501. if isinstance(m, nn.Linear) and m.bias is not None:
  502. nn.init.constant_(m.bias, 0)
  503. elif isinstance(m, nn.LayerNorm):
  504. nn.init.constant_(m.bias, 0)
  505. nn.init.constant_(m.weight, 1.0)
  506. @torch.jit.ignore
  507. def no_weight_decay(self):
  508. return {'absolute_pos_embed'}
  509. @torch.jit.ignore
  510. def no_weight_decay_keywords(self):
  511. return {"cpb_mlp", "logit_scale", 'relative_position_bias_table'}
  512. def forward_features(self, x):
  513. x = self.patch_embed(x)
  514. if self.ape:
  515. x = x + self.absolute_pos_embed
  516. x = self.pos_drop(x)
  517. for layer in self.layers:
  518. x = layer(x)
  519. x = self.norm(x) # B L C
  520. x = self.avgpool(x.transpose(1, 2)) # B C 1
  521. x = torch.flatten(x, 1)
  522. return x
  523. def forward(self, x):
  524. x = self.forward_features(x)
  525. x = self.head(x)
  526. return x
  527. def flops(self):
  528. flops = 0
  529. flops += self.patch_embed.flops()
  530. for i, layer in enumerate(self.layers):
  531. flops += layer.flops()
  532. flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
  533. flops += self.num_features * self.num_classes
  534. return flops