3
0

swin_transformer.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614
  1. # --------------------------------------------------------
  2. # Swin Transformer
  3. # Copyright (c) 2021 Microsoft
  4. # Licensed under The MIT License [see LICENSE for details]
  5. # Written by Ze Liu
  6. # --------------------------------------------------------
  7. import torch
  8. import torch.nn as nn
  9. import torch.utils.checkpoint as checkpoint
  10. from timm.models.layers import DropPath, to_2tuple, trunc_normal_
  11. try:
  12. import os, sys
  13. kernel_path = os.path.abspath(os.path.join('..'))
  14. sys.path.append(kernel_path)
  15. from kernels.window_process.window_process import WindowProcess, WindowProcessReverse
  16. except:
  17. WindowProcess = None
  18. WindowProcessReverse = None
  19. print("[Warning] Fused window process have not been installed. Please refer to get_started.md for installation.")
  20. class Mlp(nn.Module):
  21. def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
  22. super().__init__()
  23. out_features = out_features or in_features
  24. hidden_features = hidden_features or in_features
  25. self.fc1 = nn.Linear(in_features, hidden_features)
  26. self.act = act_layer()
  27. self.fc2 = nn.Linear(hidden_features, out_features)
  28. self.drop = nn.Dropout(drop)
  29. def forward(self, x):
  30. x = self.fc1(x)
  31. x = self.act(x)
  32. x = self.drop(x)
  33. x = self.fc2(x)
  34. x = self.drop(x)
  35. return x
  36. def window_partition(x, window_size):
  37. """
  38. Args:
  39. x: (B, H, W, C)
  40. window_size (int): window size
  41. Returns:
  42. windows: (num_windows*B, window_size, window_size, C)
  43. """
  44. B, H, W, C = x.shape
  45. x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
  46. windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
  47. return windows
  48. def window_reverse(windows, window_size, H, W):
  49. """
  50. Args:
  51. windows: (num_windows*B, window_size, window_size, C)
  52. window_size (int): Window size
  53. H (int): Height of image
  54. W (int): Width of image
  55. Returns:
  56. x: (B, H, W, C)
  57. """
  58. B = int(windows.shape[0] / (H * W / window_size / window_size))
  59. x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
  60. x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
  61. return x
  62. class WindowAttention(nn.Module):
  63. r""" Window based multi-head self attention (W-MSA) module with relative position bias.
  64. It supports both of shifted and non-shifted window.
  65. Args:
  66. dim (int): Number of input channels.
  67. window_size (tuple[int]): The height and width of the window.
  68. num_heads (int): Number of attention heads.
  69. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  70. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
  71. attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
  72. proj_drop (float, optional): Dropout ratio of output. Default: 0.0
  73. """
  74. def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
  75. super().__init__()
  76. self.dim = dim
  77. self.window_size = window_size # Wh, Ww
  78. self.num_heads = num_heads
  79. head_dim = dim // num_heads
  80. self.scale = qk_scale or head_dim ** -0.5
  81. # define a parameter table of relative position bias
  82. self.relative_position_bias_table = nn.Parameter(
  83. torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
  84. # get pair-wise relative position index for each token inside the window
  85. coords_h = torch.arange(self.window_size[0])
  86. coords_w = torch.arange(self.window_size[1])
  87. coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
  88. coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
  89. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
  90. relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
  91. relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
  92. relative_coords[:, :, 1] += self.window_size[1] - 1
  93. relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
  94. relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
  95. self.register_buffer("relative_position_index", relative_position_index)
  96. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
  97. self.attn_drop = nn.Dropout(attn_drop)
  98. self.proj = nn.Linear(dim, dim)
  99. self.proj_drop = nn.Dropout(proj_drop)
  100. trunc_normal_(self.relative_position_bias_table, std=.02)
  101. self.softmax = nn.Softmax(dim=-1)
  102. def forward(self, x, mask=None):
  103. """
  104. Args:
  105. x: input features with shape of (num_windows*B, N, C)
  106. mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
  107. """
  108. B_, N, C = x.shape
  109. qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
  110. q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
  111. q = q * self.scale
  112. attn = (q @ k.transpose(-2, -1))
  113. relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
  114. self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
  115. relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
  116. attn = attn + relative_position_bias.unsqueeze(0)
  117. if mask is not None:
  118. nW = mask.shape[0]
  119. attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
  120. attn = attn.view(-1, self.num_heads, N, N)
  121. attn = self.softmax(attn)
  122. else:
  123. attn = self.softmax(attn)
  124. attn = self.attn_drop(attn)
  125. x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
  126. x = self.proj(x)
  127. x = self.proj_drop(x)
  128. return x
  129. def extra_repr(self) -> str:
  130. return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
  131. def flops(self, N):
  132. # calculate flops for 1 window with token length of N
  133. flops = 0
  134. # qkv = self.qkv(x)
  135. flops += N * self.dim * 3 * self.dim
  136. # attn = (q @ k.transpose(-2, -1))
  137. flops += self.num_heads * N * (self.dim // self.num_heads) * N
  138. # x = (attn @ v)
  139. flops += self.num_heads * N * N * (self.dim // self.num_heads)
  140. # x = self.proj(x)
  141. flops += N * self.dim * self.dim
  142. return flops
  143. class SwinTransformerBlock(nn.Module):
  144. r""" Swin Transformer Block.
  145. Args:
  146. dim (int): Number of input channels.
  147. input_resolution (tuple[int]): Input resulotion.
  148. num_heads (int): Number of attention heads.
  149. window_size (int): Window size.
  150. shift_size (int): Shift size for SW-MSA.
  151. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  152. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  153. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
  154. drop (float, optional): Dropout rate. Default: 0.0
  155. attn_drop (float, optional): Attention dropout rate. Default: 0.0
  156. drop_path (float, optional): Stochastic depth rate. Default: 0.0
  157. act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
  158. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
  159. fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
  160. """
  161. def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
  162. mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
  163. act_layer=nn.GELU, norm_layer=nn.LayerNorm,
  164. fused_window_process=False):
  165. super().__init__()
  166. self.dim = dim
  167. self.input_resolution = input_resolution
  168. self.num_heads = num_heads
  169. self.window_size = window_size
  170. self.shift_size = shift_size
  171. self.mlp_ratio = mlp_ratio
  172. if min(self.input_resolution) <= self.window_size:
  173. # if window size is larger than input resolution, we don't partition windows
  174. self.shift_size = 0
  175. self.window_size = min(self.input_resolution)
  176. assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
  177. self.norm1 = norm_layer(dim)
  178. self.attn = WindowAttention(
  179. dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
  180. qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
  181. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  182. self.norm2 = norm_layer(dim)
  183. mlp_hidden_dim = int(dim * mlp_ratio)
  184. self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
  185. if self.shift_size > 0:
  186. # calculate attention mask for SW-MSA
  187. H, W = self.input_resolution
  188. img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
  189. h_slices = (slice(0, -self.window_size),
  190. slice(-self.window_size, -self.shift_size),
  191. slice(-self.shift_size, None))
  192. w_slices = (slice(0, -self.window_size),
  193. slice(-self.window_size, -self.shift_size),
  194. slice(-self.shift_size, None))
  195. cnt = 0
  196. for h in h_slices:
  197. for w in w_slices:
  198. img_mask[:, h, w, :] = cnt
  199. cnt += 1
  200. mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
  201. mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
  202. attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
  203. attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
  204. else:
  205. attn_mask = None
  206. self.register_buffer("attn_mask", attn_mask)
  207. self.fused_window_process = fused_window_process
  208. def forward(self, x):
  209. H, W = self.input_resolution
  210. B, L, C = x.shape
  211. assert L == H * W, "input feature has wrong size"
  212. shortcut = x
  213. x = self.norm1(x)
  214. x = x.view(B, H, W, C)
  215. # cyclic shift
  216. if self.shift_size > 0:
  217. if not self.fused_window_process:
  218. shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
  219. # partition windows
  220. x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
  221. else:
  222. x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
  223. else:
  224. shifted_x = x
  225. # partition windows
  226. x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
  227. x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
  228. # W-MSA/SW-MSA
  229. attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
  230. # merge windows
  231. attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
  232. # reverse cyclic shift
  233. if self.shift_size > 0:
  234. if not self.fused_window_process:
  235. shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
  236. x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
  237. else:
  238. x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
  239. else:
  240. shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
  241. x = shifted_x
  242. x = x.view(B, H * W, C)
  243. x = shortcut + self.drop_path(x)
  244. # FFN
  245. x = x + self.drop_path(self.mlp(self.norm2(x)))
  246. return x
  247. def extra_repr(self) -> str:
  248. return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
  249. f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
  250. def flops(self):
  251. flops = 0
  252. H, W = self.input_resolution
  253. # norm1
  254. flops += self.dim * H * W
  255. # W-MSA/SW-MSA
  256. nW = H * W / self.window_size / self.window_size
  257. flops += nW * self.attn.flops(self.window_size * self.window_size)
  258. # mlp
  259. flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
  260. # norm2
  261. flops += self.dim * H * W
  262. return flops
  263. class PatchMerging(nn.Module):
  264. r""" Patch Merging Layer.
  265. Args:
  266. input_resolution (tuple[int]): Resolution of input feature.
  267. dim (int): Number of input channels.
  268. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
  269. """
  270. def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
  271. super().__init__()
  272. self.input_resolution = input_resolution
  273. self.dim = dim
  274. self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
  275. self.norm = norm_layer(4 * dim)
  276. def forward(self, x):
  277. """
  278. x: B, H*W, C
  279. """
  280. H, W = self.input_resolution
  281. B, L, C = x.shape
  282. assert L == H * W, "input feature has wrong size"
  283. assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
  284. x = x.view(B, H, W, C)
  285. x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
  286. x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
  287. x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
  288. x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
  289. x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
  290. x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
  291. x = self.norm(x)
  292. x = self.reduction(x)
  293. return x
  294. def extra_repr(self) -> str:
  295. return f"input_resolution={self.input_resolution}, dim={self.dim}"
  296. def flops(self):
  297. H, W = self.input_resolution
  298. flops = H * W * self.dim
  299. flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
  300. return flops
  301. class BasicLayer(nn.Module):
  302. """ A basic Swin Transformer layer for one stage.
  303. Args:
  304. dim (int): Number of input channels.
  305. input_resolution (tuple[int]): Input resolution.
  306. depth (int): Number of blocks.
  307. num_heads (int): Number of attention heads.
  308. window_size (int): Local window size.
  309. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  310. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  311. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
  312. drop (float, optional): Dropout rate. Default: 0.0
  313. attn_drop (float, optional): Attention dropout rate. Default: 0.0
  314. drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
  315. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
  316. downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
  317. use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
  318. fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
  319. """
  320. def __init__(self, dim, input_resolution, depth, num_heads, window_size,
  321. mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
  322. drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
  323. fused_window_process=False):
  324. super().__init__()
  325. self.dim = dim
  326. self.input_resolution = input_resolution
  327. self.depth = depth
  328. self.use_checkpoint = use_checkpoint
  329. # build blocks
  330. self.blocks = nn.ModuleList([
  331. SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
  332. num_heads=num_heads, window_size=window_size,
  333. shift_size=0 if (i % 2 == 0) else window_size // 2,
  334. mlp_ratio=mlp_ratio,
  335. qkv_bias=qkv_bias, qk_scale=qk_scale,
  336. drop=drop, attn_drop=attn_drop,
  337. drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
  338. norm_layer=norm_layer,
  339. fused_window_process=fused_window_process)
  340. for i in range(depth)])
  341. # patch merging layer
  342. if downsample is not None:
  343. self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
  344. else:
  345. self.downsample = None
  346. def forward(self, x):
  347. for blk in self.blocks:
  348. if self.use_checkpoint:
  349. x = checkpoint.checkpoint(blk, x)
  350. else:
  351. x = blk(x)
  352. if self.downsample is not None:
  353. x = self.downsample(x)
  354. return x
  355. def extra_repr(self) -> str:
  356. return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
  357. def flops(self):
  358. flops = 0
  359. for blk in self.blocks:
  360. flops += blk.flops()
  361. if self.downsample is not None:
  362. flops += self.downsample.flops()
  363. return flops
  364. class PatchEmbed(nn.Module):
  365. r""" Image to Patch Embedding
  366. Args:
  367. img_size (int): Image size. Default: 224.
  368. patch_size (int): Patch token size. Default: 4.
  369. in_chans (int): Number of input image channels. Default: 3.
  370. embed_dim (int): Number of linear projection output channels. Default: 96.
  371. norm_layer (nn.Module, optional): Normalization layer. Default: None
  372. """
  373. def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
  374. super().__init__()
  375. img_size = to_2tuple(img_size)
  376. patch_size = to_2tuple(patch_size)
  377. patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
  378. self.img_size = img_size
  379. self.patch_size = patch_size
  380. self.patches_resolution = patches_resolution
  381. self.num_patches = patches_resolution[0] * patches_resolution[1]
  382. self.in_chans = in_chans
  383. self.embed_dim = embed_dim
  384. self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
  385. if norm_layer is not None:
  386. self.norm = norm_layer(embed_dim)
  387. else:
  388. self.norm = None
  389. def forward(self, x):
  390. B, C, H, W = x.shape
  391. # FIXME look at relaxing size constraints
  392. assert H == self.img_size[0] and W == self.img_size[1], \
  393. f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
  394. x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
  395. if self.norm is not None:
  396. x = self.norm(x)
  397. return x
  398. def flops(self):
  399. Ho, Wo = self.patches_resolution
  400. flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
  401. if self.norm is not None:
  402. flops += Ho * Wo * self.embed_dim
  403. return flops
  404. class SwinTransformer(nn.Module):
  405. r""" Swin Transformer
  406. A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
  407. https://arxiv.org/pdf/2103.14030
  408. Args:
  409. img_size (int | tuple(int)): Input image size. Default 224
  410. patch_size (int | tuple(int)): Patch size. Default: 4
  411. in_chans (int): Number of input image channels. Default: 3
  412. num_classes (int): Number of classes for classification head. Default: 1000
  413. embed_dim (int): Patch embedding dimension. Default: 96
  414. depths (tuple(int)): Depth of each Swin Transformer layer.
  415. num_heads (tuple(int)): Number of attention heads in different layers.
  416. window_size (int): Window size. Default: 7
  417. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
  418. qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
  419. qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
  420. drop_rate (float): Dropout rate. Default: 0
  421. attn_drop_rate (float): Attention dropout rate. Default: 0
  422. drop_path_rate (float): Stochastic depth rate. Default: 0.1
  423. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
  424. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
  425. patch_norm (bool): If True, add normalization after patch embedding. Default: True
  426. use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
  427. fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
  428. """
  429. def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
  430. embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
  431. window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
  432. drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
  433. norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
  434. use_checkpoint=False, fused_window_process=False, **kwargs):
  435. super().__init__()
  436. self.num_classes = num_classes
  437. self.num_layers = len(depths)
  438. self.embed_dim = embed_dim
  439. self.ape = ape
  440. self.patch_norm = patch_norm
  441. self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
  442. self.mlp_ratio = mlp_ratio
  443. # split image into non-overlapping patches
  444. self.patch_embed = PatchEmbed(
  445. img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
  446. norm_layer=norm_layer if self.patch_norm else None)
  447. num_patches = self.patch_embed.num_patches
  448. patches_resolution = self.patch_embed.patches_resolution
  449. self.patches_resolution = patches_resolution
  450. # absolute position embedding
  451. if self.ape:
  452. self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
  453. trunc_normal_(self.absolute_pos_embed, std=.02)
  454. self.pos_drop = nn.Dropout(p=drop_rate)
  455. # stochastic depth
  456. dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
  457. # build layers
  458. self.layers = nn.ModuleList()
  459. for i_layer in range(self.num_layers):
  460. layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
  461. input_resolution=(patches_resolution[0] // (2 ** i_layer),
  462. patches_resolution[1] // (2 ** i_layer)),
  463. depth=depths[i_layer],
  464. num_heads=num_heads[i_layer],
  465. window_size=window_size,
  466. mlp_ratio=self.mlp_ratio,
  467. qkv_bias=qkv_bias, qk_scale=qk_scale,
  468. drop=drop_rate, attn_drop=attn_drop_rate,
  469. drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
  470. norm_layer=norm_layer,
  471. downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
  472. use_checkpoint=use_checkpoint,
  473. fused_window_process=fused_window_process)
  474. self.layers.append(layer)
  475. self.norm = norm_layer(self.num_features)
  476. self.avgpool = nn.AdaptiveAvgPool1d(1)
  477. self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
  478. self.apply(self._init_weights)
  479. def _init_weights(self, m):
  480. if isinstance(m, nn.Linear):
  481. trunc_normal_(m.weight, std=.02)
  482. if isinstance(m, nn.Linear) and m.bias is not None:
  483. nn.init.constant_(m.bias, 0)
  484. elif isinstance(m, nn.LayerNorm):
  485. nn.init.constant_(m.bias, 0)
  486. nn.init.constant_(m.weight, 1.0)
  487. @torch.jit.ignore
  488. def no_weight_decay(self):
  489. return {'absolute_pos_embed'}
  490. @torch.jit.ignore
  491. def no_weight_decay_keywords(self):
  492. return {'relative_position_bias_table'}
  493. def forward_features(self, x):
  494. x = self.patch_embed(x)
  495. if self.ape:
  496. x = x + self.absolute_pos_embed
  497. x = self.pos_drop(x)
  498. for layer in self.layers:
  499. x = layer(x)
  500. x = self.norm(x) # B L C
  501. x = self.avgpool(x.transpose(1, 2)) # B C 1
  502. x = torch.flatten(x, 1)
  503. return x
  504. def forward(self, x):
  505. x = self.forward_features(x)
  506. x = self.head(x)
  507. return x
  508. def flops(self):
  509. flops = 0
  510. flops += self.patch_embed.flops()
  511. for i, layer in enumerate(self.layers):
  512. flops += layer.flops()
  513. flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
  514. flops += self.num_features * self.num_classes
  515. return flops