3
0

swin_transformer_moe.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824
  1. # --------------------------------------------------------
  2. # Swin Transformer MoE
  3. # Copyright (c) 2022 Microsoft
  4. # Licensed under The MIT License [see LICENSE for details]
  5. # Written by Ze Liu
  6. # --------------------------------------------------------
  7. import torch
  8. import torch.nn as nn
  9. import torch.nn.functional as F
  10. import torch.distributed as dist
  11. import torch.utils.checkpoint as checkpoint
  12. from timm.models.layers import DropPath, to_2tuple, trunc_normal_
  13. import numpy as np
  14. try:
  15. from tutel import moe as tutel_moe
  16. except:
  17. tutel_moe = None
  18. print("Tutel has not been installed. To use Swin-MoE, please install Tutel; otherwise, just ignore this.")
  19. class Mlp(nn.Module):
  20. def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,
  21. mlp_fc2_bias=True):
  22. super().__init__()
  23. out_features = out_features or in_features
  24. hidden_features = hidden_features or in_features
  25. self.fc1 = nn.Linear(in_features, hidden_features)
  26. self.act = act_layer()
  27. self.fc2 = nn.Linear(hidden_features, out_features, bias=mlp_fc2_bias)
  28. self.drop = nn.Dropout(drop)
  29. def forward(self, x):
  30. x = self.fc1(x)
  31. x = self.act(x)
  32. x = self.drop(x)
  33. x = self.fc2(x)
  34. x = self.drop(x)
  35. return x
  36. class MoEMlp(nn.Module):
  37. def __init__(self, in_features, hidden_features, num_local_experts, top_value, capacity_factor=1.25,
  38. cosine_router=False, normalize_gate=False, use_bpr=True, is_gshard_loss=True,
  39. gate_noise=1.0, cosine_router_dim=256, cosine_router_init_t=0.5, moe_drop=0.0, init_std=0.02,
  40. mlp_fc2_bias=True):
  41. super().__init__()
  42. self.in_features = in_features
  43. self.hidden_features = hidden_features
  44. self.num_local_experts = num_local_experts
  45. self.top_value = top_value
  46. self.capacity_factor = capacity_factor
  47. self.cosine_router = cosine_router
  48. self.normalize_gate = normalize_gate
  49. self.use_bpr = use_bpr
  50. self.init_std = init_std
  51. self.mlp_fc2_bias = mlp_fc2_bias
  52. self.dist_rank = dist.get_rank()
  53. self._dropout = nn.Dropout(p=moe_drop)
  54. _gate_type = {'type': 'cosine_top' if cosine_router else 'top',
  55. 'k': top_value, 'capacity_factor': capacity_factor,
  56. 'gate_noise': gate_noise, 'fp32_gate': True}
  57. if cosine_router:
  58. _gate_type['proj_dim'] = cosine_router_dim
  59. _gate_type['init_t'] = cosine_router_init_t
  60. self._moe_layer = tutel_moe.moe_layer(
  61. gate_type=_gate_type,
  62. model_dim=in_features,
  63. experts={'type': 'ffn', 'count_per_node': num_local_experts, 'hidden_size_per_expert': hidden_features,
  64. 'activation_fn': lambda x: self._dropout(F.gelu(x))},
  65. scan_expert_func=lambda name, param: setattr(param, 'skip_allreduce', True),
  66. seeds=(1, self.dist_rank + 1, self.dist_rank + 1),
  67. batch_prioritized_routing=use_bpr,
  68. normalize_gate=normalize_gate,
  69. is_gshard_loss=is_gshard_loss,
  70. )
  71. if not self.mlp_fc2_bias:
  72. self._moe_layer.experts.batched_fc2_bias.requires_grad = False
  73. def forward(self, x):
  74. x = self._moe_layer(x)
  75. return x, x.l_aux
  76. def extra_repr(self) -> str:
  77. return f'[Statistics-{self.dist_rank}] param count for MoE, ' \
  78. f'in_features = {self.in_features}, hidden_features = {self.hidden_features}, ' \
  79. f'num_local_experts = {self.num_local_experts}, top_value = {self.top_value}, ' \
  80. f'cosine_router={self.cosine_router} normalize_gate={self.normalize_gate}, use_bpr = {self.use_bpr}'
  81. def _init_weights(self):
  82. if hasattr(self._moe_layer, "experts"):
  83. trunc_normal_(self._moe_layer.experts.batched_fc1_w, std=self.init_std)
  84. trunc_normal_(self._moe_layer.experts.batched_fc2_w, std=self.init_std)
  85. nn.init.constant_(self._moe_layer.experts.batched_fc1_bias, 0)
  86. nn.init.constant_(self._moe_layer.experts.batched_fc2_bias, 0)
  87. def window_partition(x, window_size):
  88. """
  89. Args:
  90. x: (B, H, W, C)
  91. window_size (int): window size
  92. Returns:
  93. windows: (num_windows*B, window_size, window_size, C)
  94. """
  95. B, H, W, C = x.shape
  96. x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
  97. windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
  98. return windows
  99. def window_reverse(windows, window_size, H, W):
  100. """
  101. Args:
  102. windows: (num_windows*B, window_size, window_size, C)
  103. window_size (int): Window size
  104. H (int): Height of image
  105. W (int): Width of image
  106. Returns:
  107. x: (B, H, W, C)
  108. """
  109. B = int(windows.shape[0] / (H * W / window_size / window_size))
  110. x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
  111. x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
  112. return x
  113. class WindowAttention(nn.Module):
  114. r""" Window based multi-head self attention (W-MSA) module with relative position bias.
  115. It supports both of shifted and non-shifted window.
  116. Args:
  117. dim (int): Number of input channels.
  118. window_size (tuple[int]): The height and width of the window.
  119. num_heads (int): Number of attention heads.
  120. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  121. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
  122. attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
  123. proj_drop (float, optional): Dropout ratio of output. Default: 0.0
  124. pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
  125. """
  126. def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
  127. pretrained_window_size=[0, 0]):
  128. super().__init__()
  129. self.dim = dim
  130. self.window_size = window_size # Wh, Ww
  131. self.pretrained_window_size = pretrained_window_size
  132. self.num_heads = num_heads
  133. head_dim = dim // num_heads
  134. self.scale = qk_scale or head_dim ** -0.5
  135. # mlp to generate continuous relative position bias
  136. self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
  137. nn.ReLU(inplace=True),
  138. nn.Linear(512, num_heads, bias=False))
  139. # get relative_coords_table
  140. relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
  141. relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
  142. relative_coords_table = torch.stack(
  143. torch.meshgrid([relative_coords_h,
  144. relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
  145. if pretrained_window_size[0] > 0:
  146. relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
  147. relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
  148. else:
  149. relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
  150. relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
  151. relative_coords_table *= 8 # normalize to -8, 8
  152. relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
  153. torch.abs(relative_coords_table) + 1.0) / np.log2(8)
  154. self.register_buffer("relative_coords_table", relative_coords_table)
  155. # get pair-wise relative position index for each token inside the window
  156. coords_h = torch.arange(self.window_size[0])
  157. coords_w = torch.arange(self.window_size[1])
  158. coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
  159. coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
  160. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
  161. relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
  162. relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
  163. relative_coords[:, :, 1] += self.window_size[1] - 1
  164. relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
  165. relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
  166. self.register_buffer("relative_position_index", relative_position_index)
  167. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
  168. self.attn_drop = nn.Dropout(attn_drop)
  169. self.proj = nn.Linear(dim, dim)
  170. self.proj_drop = nn.Dropout(proj_drop)
  171. self.softmax = nn.Softmax(dim=-1)
  172. def forward(self, x, mask=None):
  173. """
  174. Args:
  175. x: input features with shape of (num_windows*B, N, C)
  176. mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
  177. """
  178. B_, N, C = x.shape
  179. qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
  180. q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
  181. q = q * self.scale
  182. attn = (q @ k.transpose(-2, -1))
  183. relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
  184. relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
  185. self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
  186. relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
  187. attn = attn + relative_position_bias.unsqueeze(0)
  188. if mask is not None:
  189. nW = mask.shape[0]
  190. attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
  191. attn = attn.view(-1, self.num_heads, N, N)
  192. attn = self.softmax(attn)
  193. else:
  194. attn = self.softmax(attn)
  195. attn = self.attn_drop(attn)
  196. x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
  197. x = self.proj(x)
  198. x = self.proj_drop(x)
  199. return x
  200. def extra_repr(self) -> str:
  201. return f'dim={self.dim}, window_size={self.window_size}, ' \
  202. f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
  203. def flops(self, N):
  204. # calculate flops for 1 window with token length of N
  205. flops = 0
  206. # qkv = self.qkv(x)
  207. flops += N * self.dim * 3 * self.dim
  208. # attn = (q @ k.transpose(-2, -1))
  209. flops += self.num_heads * N * (self.dim // self.num_heads) * N
  210. # x = (attn @ v)
  211. flops += self.num_heads * N * N * (self.dim // self.num_heads)
  212. # x = self.proj(x)
  213. flops += N * self.dim * self.dim
  214. return flops
  215. class SwinTransformerBlock(nn.Module):
  216. r""" Swin Transformer Block.
  217. Args:
  218. dim (int): Number of input channels.
  219. input_resolution (tuple[int]): Input resulotion.
  220. num_heads (int): Number of attention heads.
  221. window_size (int): Window size.
  222. shift_size (int): Shift size for SW-MSA.
  223. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  224. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  225. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
  226. drop (float, optional): Dropout rate. Default: 0.0
  227. attn_drop (float, optional): Attention dropout rate. Default: 0.0
  228. drop_path (float, optional): Stochastic depth rate. Default: 0.0
  229. act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
  230. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
  231. mlp_fc2_bias (bool): Whether to add bias in fc2 of Mlp. Default: True
  232. init_std: Initialization std. Default: 0.02
  233. pretrained_window_size (int): Window size in pre-training.
  234. is_moe (bool): If True, this block is a MoE block.
  235. num_local_experts (int): number of local experts in each device (GPU). Default: 1
  236. top_value (int): the value of k in top-k gating. Default: 1
  237. capacity_factor (float): the capacity factor in MoE. Default: 1.25
  238. cosine_router (bool): Whether to use cosine router. Default: False
  239. normalize_gate (bool): Whether to normalize the gating score in top-k gating. Default: False
  240. use_bpr (bool): Whether to use batch-prioritized-routing. Default: True
  241. is_gshard_loss (bool): If True, use Gshard balance loss.
  242. If False, use the load loss and importance loss in "arXiv:1701.06538". Default: False
  243. gate_noise (float): the noise ratio in top-k gating. Default: 1.0
  244. cosine_router_dim (int): Projection dimension in cosine router.
  245. cosine_router_init_t (float): Initialization temperature in cosine router.
  246. moe_drop (float): Dropout rate in MoE. Default: 0.0
  247. """
  248. def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
  249. mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
  250. act_layer=nn.GELU, norm_layer=nn.LayerNorm, mlp_fc2_bias=True, init_std=0.02, pretrained_window_size=0,
  251. is_moe=False, num_local_experts=1, top_value=1, capacity_factor=1.25, cosine_router=False,
  252. normalize_gate=False, use_bpr=True, is_gshard_loss=True, gate_noise=1.0,
  253. cosine_router_dim=256, cosine_router_init_t=0.5, moe_drop=0.0):
  254. super().__init__()
  255. self.dim = dim
  256. self.input_resolution = input_resolution
  257. self.num_heads = num_heads
  258. self.window_size = window_size
  259. self.shift_size = shift_size
  260. self.mlp_ratio = mlp_ratio
  261. self.is_moe = is_moe
  262. self.capacity_factor = capacity_factor
  263. self.top_value = top_value
  264. if min(self.input_resolution) <= self.window_size:
  265. # if window size is larger than input resolution, we don't partition windows
  266. self.shift_size = 0
  267. self.window_size = min(self.input_resolution)
  268. assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
  269. self.norm1 = norm_layer(dim)
  270. self.attn = WindowAttention(
  271. dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
  272. qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
  273. pretrained_window_size=to_2tuple(pretrained_window_size))
  274. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  275. self.norm2 = norm_layer(dim)
  276. mlp_hidden_dim = int(dim * mlp_ratio)
  277. if self.is_moe:
  278. self.mlp = MoEMlp(in_features=dim,
  279. hidden_features=mlp_hidden_dim,
  280. num_local_experts=num_local_experts,
  281. top_value=top_value,
  282. capacity_factor=capacity_factor,
  283. cosine_router=cosine_router,
  284. normalize_gate=normalize_gate,
  285. use_bpr=use_bpr,
  286. is_gshard_loss=is_gshard_loss,
  287. gate_noise=gate_noise,
  288. cosine_router_dim=cosine_router_dim,
  289. cosine_router_init_t=cosine_router_init_t,
  290. moe_drop=moe_drop,
  291. mlp_fc2_bias=mlp_fc2_bias,
  292. init_std=init_std)
  293. else:
  294. self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop,
  295. mlp_fc2_bias=mlp_fc2_bias)
  296. if self.shift_size > 0:
  297. # calculate attention mask for SW-MSA
  298. H, W = self.input_resolution
  299. img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
  300. h_slices = (slice(0, -self.window_size),
  301. slice(-self.window_size, -self.shift_size),
  302. slice(-self.shift_size, None))
  303. w_slices = (slice(0, -self.window_size),
  304. slice(-self.window_size, -self.shift_size),
  305. slice(-self.shift_size, None))
  306. cnt = 0
  307. for h in h_slices:
  308. for w in w_slices:
  309. img_mask[:, h, w, :] = cnt
  310. cnt += 1
  311. mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
  312. mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
  313. attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
  314. attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
  315. else:
  316. attn_mask = None
  317. self.register_buffer("attn_mask", attn_mask)
  318. def forward(self, x):
  319. H, W = self.input_resolution
  320. B, L, C = x.shape
  321. assert L == H * W, "input feature has wrong size"
  322. shortcut = x
  323. x = self.norm1(x)
  324. x = x.view(B, H, W, C)
  325. # cyclic shift
  326. if self.shift_size > 0:
  327. shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
  328. else:
  329. shifted_x = x
  330. # partition windows
  331. x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
  332. x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
  333. # W-MSA/SW-MSA
  334. attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
  335. # merge windows
  336. attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
  337. shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
  338. # reverse cyclic shift
  339. if self.shift_size > 0:
  340. x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
  341. else:
  342. x = shifted_x
  343. x = x.view(B, H * W, C)
  344. x = shortcut + self.drop_path(x)
  345. # FFN
  346. shortcut = x
  347. x = self.norm2(x)
  348. if self.is_moe:
  349. x, l_aux = self.mlp(x)
  350. x = shortcut + self.drop_path(x)
  351. return x, l_aux
  352. else:
  353. x = shortcut + self.drop_path(self.mlp(x))
  354. return x
  355. def extra_repr(self) -> str:
  356. return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
  357. f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
  358. def flops(self):
  359. flops = 0
  360. H, W = self.input_resolution
  361. # norm1
  362. flops += self.dim * H * W
  363. # W-MSA/SW-MSA
  364. nW = H * W / self.window_size / self.window_size
  365. flops += nW * self.attn.flops(self.window_size * self.window_size)
  366. # mlp
  367. if self.is_moe:
  368. flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio * self.capacity_factor * self.top_value
  369. else:
  370. flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
  371. # norm2
  372. flops += self.dim * H * W
  373. return flops
  374. class PatchMerging(nn.Module):
  375. r""" Patch Merging Layer.
  376. Args:
  377. input_resolution (tuple[int]): Resolution of input feature.
  378. dim (int): Number of input channels.
  379. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
  380. """
  381. def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
  382. super().__init__()
  383. self.input_resolution = input_resolution
  384. self.dim = dim
  385. self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
  386. self.norm = norm_layer(4 * dim)
  387. def forward(self, x):
  388. """
  389. x: B, H*W, C
  390. """
  391. H, W = self.input_resolution
  392. B, L, C = x.shape
  393. assert L == H * W, "input feature has wrong size"
  394. assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
  395. x = x.view(B, H, W, C)
  396. x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
  397. x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
  398. x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
  399. x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
  400. x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
  401. x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
  402. x = self.norm(x)
  403. x = self.reduction(x)
  404. return x
  405. def extra_repr(self) -> str:
  406. return f"input_resolution={self.input_resolution}, dim={self.dim}"
  407. def flops(self):
  408. H, W = self.input_resolution
  409. flops = H * W * self.dim
  410. flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
  411. return flops
  412. class BasicLayer(nn.Module):
  413. """ A basic Swin Transformer layer for one stage.
  414. Args:
  415. dim (int): Number of input channels.
  416. input_resolution (tuple[int]): Input resolution.
  417. depth (int): Number of blocks.
  418. num_heads (int): Number of attention heads.
  419. window_size (int): Local window size.
  420. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  421. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  422. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
  423. drop (float, optional): Dropout rate. Default: 0.0
  424. attn_drop (float, optional): Attention dropout rate. Default: 0.0
  425. drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
  426. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
  427. downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
  428. mlp_fc2_bias (bool): Whether to add bias in fc2 of Mlp. Default: True
  429. init_std: Initialization std. Default: 0.02
  430. use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
  431. pretrained_window_size (int): Local window size in pre-training.
  432. moe_blocks (tuple(int)): The index of each MoE block.
  433. num_local_experts (int): number of local experts in each device (GPU). Default: 1
  434. top_value (int): the value of k in top-k gating. Default: 1
  435. capacity_factor (float): the capacity factor in MoE. Default: 1.25
  436. cosine_router (bool): Whether to use cosine router Default: False
  437. normalize_gate (bool): Whether to normalize the gating score in top-k gating. Default: False
  438. use_bpr (bool): Whether to use batch-prioritized-routing. Default: True
  439. is_gshard_loss (bool): If True, use Gshard balance loss.
  440. If False, use the load loss and importance loss in "arXiv:1701.06538". Default: False
  441. gate_noise (float): the noise ratio in top-k gating. Default: 1.0
  442. cosine_router_dim (int): Projection dimension in cosine router.
  443. cosine_router_init_t (float): Initialization temperature in cosine router.
  444. moe_drop (float): Dropout rate in MoE. Default: 0.0
  445. """
  446. def __init__(self, dim, input_resolution, depth, num_heads, window_size,
  447. mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
  448. drop_path=0., norm_layer=nn.LayerNorm, downsample=None,
  449. mlp_fc2_bias=True, init_std=0.02, use_checkpoint=False, pretrained_window_size=0,
  450. moe_block=[-1], num_local_experts=1, top_value=1, capacity_factor=1.25, cosine_router=False,
  451. normalize_gate=False, use_bpr=True, is_gshard_loss=True,
  452. cosine_router_dim=256, cosine_router_init_t=0.5, gate_noise=1.0, moe_drop=0.0):
  453. super().__init__()
  454. self.dim = dim
  455. self.input_resolution = input_resolution
  456. self.depth = depth
  457. self.use_checkpoint = use_checkpoint
  458. # build blocks
  459. self.blocks = nn.ModuleList([
  460. SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
  461. num_heads=num_heads, window_size=window_size,
  462. shift_size=0 if (i % 2 == 0) else window_size // 2,
  463. mlp_ratio=mlp_ratio,
  464. qkv_bias=qkv_bias, qk_scale=qk_scale,
  465. drop=drop, attn_drop=attn_drop,
  466. drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
  467. norm_layer=norm_layer,
  468. mlp_fc2_bias=mlp_fc2_bias,
  469. init_std=init_std,
  470. pretrained_window_size=pretrained_window_size,
  471. is_moe=True if i in moe_block else False,
  472. num_local_experts=num_local_experts,
  473. top_value=top_value,
  474. capacity_factor=capacity_factor,
  475. cosine_router=cosine_router,
  476. normalize_gate=normalize_gate,
  477. use_bpr=use_bpr,
  478. is_gshard_loss=is_gshard_loss,
  479. gate_noise=gate_noise,
  480. cosine_router_dim=cosine_router_dim,
  481. cosine_router_init_t=cosine_router_init_t,
  482. moe_drop=moe_drop)
  483. for i in range(depth)])
  484. # patch merging layer
  485. if downsample is not None:
  486. self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
  487. else:
  488. self.downsample = None
  489. def forward(self, x):
  490. l_aux = 0.0
  491. for blk in self.blocks:
  492. if self.use_checkpoint:
  493. out = checkpoint.checkpoint(blk, x)
  494. else:
  495. out = blk(x)
  496. if isinstance(out, tuple):
  497. x = out[0]
  498. cur_l_aux = out[1]
  499. l_aux = cur_l_aux + l_aux
  500. else:
  501. x = out
  502. if self.downsample is not None:
  503. x = self.downsample(x)
  504. return x, l_aux
  505. def extra_repr(self) -> str:
  506. return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
  507. def flops(self):
  508. flops = 0
  509. for blk in self.blocks:
  510. flops += blk.flops()
  511. if self.downsample is not None:
  512. flops += self.downsample.flops()
  513. return flops
  514. class PatchEmbed(nn.Module):
  515. r""" Image to Patch Embedding
  516. Args:
  517. img_size (int): Image size. Default: 224.
  518. patch_size (int): Patch token size. Default: 4.
  519. in_chans (int): Number of input image channels. Default: 3.
  520. embed_dim (int): Number of linear projection output channels. Default: 96.
  521. norm_layer (nn.Module, optional): Normalization layer. Default: None
  522. """
  523. def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
  524. super().__init__()
  525. img_size = to_2tuple(img_size)
  526. patch_size = to_2tuple(patch_size)
  527. patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
  528. self.img_size = img_size
  529. self.patch_size = patch_size
  530. self.patches_resolution = patches_resolution
  531. self.num_patches = patches_resolution[0] * patches_resolution[1]
  532. self.in_chans = in_chans
  533. self.embed_dim = embed_dim
  534. self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
  535. if norm_layer is not None:
  536. self.norm = norm_layer(embed_dim)
  537. else:
  538. self.norm = None
  539. def forward(self, x):
  540. B, C, H, W = x.shape
  541. # FIXME look at relaxing size constraints
  542. assert H == self.img_size[0] and W == self.img_size[1], \
  543. f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
  544. x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
  545. if self.norm is not None:
  546. x = self.norm(x)
  547. return x
  548. def flops(self):
  549. Ho, Wo = self.patches_resolution
  550. flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
  551. if self.norm is not None:
  552. flops += Ho * Wo * self.embed_dim
  553. return flops
  554. class SwinTransformerMoE(nn.Module):
  555. r""" Swin Transformer
  556. A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
  557. https://arxiv.org/pdf/2103.14030
  558. Args:
  559. img_size (int | tuple(int)): Input image size. Default 224
  560. patch_size (int | tuple(int)): Patch size. Default: 4
  561. in_chans (int): Number of input image channels. Default: 3
  562. num_classes (int): Number of classes for classification head. Default: 1000
  563. embed_dim (int): Patch embedding dimension. Default: 96
  564. depths (tuple(int)): Depth of each Swin Transformer layer.
  565. num_heads (tuple(int)): Number of attention heads in different layers.
  566. window_size (int): Window size. Default: 7
  567. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
  568. qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
  569. qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
  570. drop_rate (float): Dropout rate. Default: 0
  571. attn_drop_rate (float): Attention dropout rate. Default: 0
  572. drop_path_rate (float): Stochastic depth rate. Default: 0.1
  573. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
  574. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
  575. patch_norm (bool): If True, add normalization after patch embedding. Default: True
  576. mlp_fc2_bias (bool): Whether to add bias in fc2 of Mlp. Default: True
  577. init_std: Initialization std. Default: 0.02
  578. use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
  579. pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer.
  580. moe_blocks (tuple(tuple(int))): The index of each MoE block in each layer.
  581. num_local_experts (int): number of local experts in each device (GPU). Default: 1
  582. top_value (int): the value of k in top-k gating. Default: 1
  583. capacity_factor (float): the capacity factor in MoE. Default: 1.25
  584. cosine_router (bool): Whether to use cosine router Default: False
  585. normalize_gate (bool): Whether to normalize the gating score in top-k gating. Default: False
  586. use_bpr (bool): Whether to use batch-prioritized-routing. Default: True
  587. is_gshard_loss (bool): If True, use Gshard balance loss.
  588. If False, use the load loss and importance loss in "arXiv:1701.06538". Default: False
  589. gate_noise (float): the noise ratio in top-k gating. Default: 1.0
  590. cosine_router_dim (int): Projection dimension in cosine router.
  591. cosine_router_init_t (float): Initialization temperature in cosine router.
  592. moe_drop (float): Dropout rate in MoE. Default: 0.0
  593. aux_loss_weight (float): auxiliary loss weight. Default: 0.01
  594. """
  def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
               embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
               window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
               drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
               norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
               mlp_fc2_bias=True, init_std=0.02, use_checkpoint=False, pretrained_window_sizes=[0, 0, 0, 0],
               moe_blocks=[[-1], [-1], [-1], [-1]], num_local_experts=1, top_value=1, capacity_factor=1.25,
               cosine_router=False, normalize_gate=False, use_bpr=True, is_gshard_loss=True, gate_noise=1.0,
               cosine_router_dim=256, cosine_router_init_t=0.5, moe_drop=0.0, aux_loss_weight=0.01, **kwargs):
      """Assemble the network: patch embedding, stages, final norm, pooling and head.

      The arguments are described in the class docstring. Extra keyword
      arguments are accepted through ``**kwargs`` and are not used here.
      """
      super().__init__()

      # Plain bookkeeping attributes.
      self._ddp_params_and_buffers_to_ignore = []
      self.num_classes = num_classes
      self.num_layers = len(depths)
      self.embed_dim = embed_dim
      self.ape = ape
      self.patch_norm = patch_norm
      # The channel width doubles at every stage after the first one.
      self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
      self.mlp_ratio = mlp_ratio
      self.init_std = init_std
      self.aux_loss_weight = aux_loss_weight
      self.num_local_experts = num_local_experts

      # A positive num_local_experts is multiplied by the world size; otherwise
      # its negation divides the world size and is stored as sharded_count.
      world_size = dist.get_world_size()
      if num_local_experts > 0:
          self.global_experts = num_local_experts * world_size
          self.sharded_count = 1.0 / num_local_experts
      else:
          self.global_experts = world_size // (-num_local_experts)
          self.sharded_count = -num_local_experts

      # split image into non-overlapping patches
      patch_norm_layer = norm_layer if self.patch_norm else None
      self.patch_embed = PatchEmbed(
          img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
          norm_layer=patch_norm_layer)
      num_patches = self.patch_embed.num_patches
      patches_resolution = self.patch_embed.patches_resolution
      self.patches_resolution = patches_resolution

      # absolute position embedding
      if self.ape:
          self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
          trunc_normal_(self.absolute_pos_embed, std=self.init_std)

      self.pos_drop = nn.Dropout(p=drop_rate)

      # stochastic depth decay rule: one rate per block, rising linearly
      dpr = torch.linspace(0, drop_path_rate, sum(depths)).tolist()

      # build layers
      self.layers = nn.ModuleList()
      last_stage = self.num_layers - 1
      block_offset = 0  # index into dpr of the current stage's first block
      for stage_idx, stage_depth in enumerate(depths):
          scale = 2 ** stage_idx
          stage = BasicLayer(
              dim=int(embed_dim * scale),
              input_resolution=(patches_resolution[0] // scale,
                                patches_resolution[1] // scale),
              depth=stage_depth,
              num_heads=num_heads[stage_idx],
              window_size=window_size,
              mlp_ratio=self.mlp_ratio,
              qkv_bias=qkv_bias, qk_scale=qk_scale,
              drop=drop_rate, attn_drop=attn_drop_rate,
              drop_path=dpr[block_offset:block_offset + stage_depth],
              norm_layer=norm_layer,
              # every stage except the last gets a PatchMerging downsample
              downsample=PatchMerging if stage_idx < last_stage else None,
              mlp_fc2_bias=mlp_fc2_bias,
              init_std=init_std,
              use_checkpoint=use_checkpoint,
              pretrained_window_size=pretrained_window_sizes[stage_idx],
              moe_block=moe_blocks[stage_idx],
              num_local_experts=num_local_experts,
              top_value=top_value,
              capacity_factor=capacity_factor,
              cosine_router=cosine_router,
              normalize_gate=normalize_gate,
              use_bpr=use_bpr,
              is_gshard_loss=is_gshard_loss,
              gate_noise=gate_noise,
              cosine_router_dim=cosine_router_dim,
              cosine_router_init_t=cosine_router_init_t,
              moe_drop=moe_drop)
          self.layers.append(stage)
          block_offset += stage_depth

      self.norm = norm_layer(self.num_features)
      self.avgpool = nn.AdaptiveAvgPool1d(1)
      # num_classes <= 0 replaces the classifier with an identity mapping
      if num_classes > 0:
          self.head = nn.Linear(self.num_features, num_classes)
      else:
          self.head = nn.Identity()

      self.apply(self._init_weights)
  def _init_weights(self, m):
      """Initialise a single submodule; passed to ``self.apply`` by ``__init__``.

      * ``nn.Linear``: truncated-normal weight with std ``self.init_std``,
        zero bias when a bias exists.
      * ``nn.LayerNorm``: zero bias, unit weight.
      * ``MoEMlp``: delegated to the module's own ``_init_weights``.

      Any other module type is left untouched.
      """
      if isinstance(m, nn.Linear):
          trunc_normal_(m.weight, std=self.init_std)
          if m.bias is not None:
              nn.init.constant_(m.bias, 0)
          return

      if isinstance(m, nn.LayerNorm):
          nn.init.constant_(m.bias, 0)
          nn.init.constant_(m.weight, 1.0)
          return

      if isinstance(m, MoEMlp):
          m._init_weights()
  @torch.jit.ignore
  def no_weight_decay(self):
      """Return the set of parameter names this model lists for no weight decay."""
      exempt_names = {'absolute_pos_embed'}
      return exempt_names
  @torch.jit.ignore
  def no_weight_decay_keywords(self):
      """Return the set of name keywords this model lists for no weight decay."""
      exempt_keywords = {
          "cpb_mlp",
          'relative_position_bias_table',
          'fc1_bias',
          'fc2_bias',
          'temperature',
          'cosine_projector',
          'sim_matrix',
      }
      return exempt_keywords
  def forward_features(self, x):
      """Run the backbone and return ``(pooled_features, summed_aux_loss)``.

      The input goes through the patch embedding, the optional absolute
      position embedding, dropout and every stage in ``self.layers``. Each
      stage returns ``(tokens, aux_loss)``; the auxiliary losses are summed.
      The final tokens are normalised, average-pooled and flattened from
      dim 1.
      """
      tokens = self.patch_embed(x)
      if self.ape:
          tokens = tokens + self.absolute_pos_embed
      tokens = self.pos_drop(tokens)

      aux_total = 0.0
      for stage in self.layers:
          tokens, stage_aux = stage(tokens)
          aux_total = stage_aux + aux_total

      tokens = self.norm(tokens)  # B L C
      pooled = self.avgpool(tokens.transpose(1, 2))  # B C 1
      return torch.flatten(pooled, 1), aux_total
  def forward(self, x):
      """Return ``(head_output, weighted_aux_loss)`` for input ``x``.

      The auxiliary loss from ``forward_features`` is multiplied by
      ``self.aux_loss_weight`` before being returned.
      """
      features, aux_loss = self.forward_features(x)
      logits = self.head(features)
      return logits, aux_loss * self.aux_loss_weight
  def add_param_to_skip_allreduce(self, param_name):
      """Append ``param_name`` to ``self._ddp_params_and_buffers_to_ignore``."""
      ignored = self._ddp_params_and_buffers_to_ignore
      ignored.append(param_name)
  def flops(self):
      """Return the model's FLOP estimate.

      Sums the patch-embedding cost, every stage's ``flops()``, a term for the
      final norm and the cost of the classification head.
      """
      # Seed the sum with the patch-embedding cost so terms accumulate in the
      # order: patch embed, then each stage.
      total = sum((stage.flops() for stage in self.layers), self.patch_embed.flops())
      height, width = self.patches_resolution[0], self.patches_resolution[1]
      norm_cost = self.num_features * height * width // (2 ** self.num_layers)
      head_cost = self.num_features * self.num_classes
      return total + norm_cost + head_cost