swin_transformer.py 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is based on https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
  16. The copyright of microsoft/Swin-Transformer is as follows:
  17. MIT License [see LICENSE for details]
  18. """
  19. import paddle
  20. import paddle.nn as nn
  21. import paddle.nn.functional as F
  22. from paddle.nn.initializer import TruncatedNormal, Constant, Assign
  23. from ppdet.modeling.shape_spec import ShapeSpec
  24. from ppdet.core.workspace import register, serializable
  25. import numpy as np
  26. # Common initializations
  27. ones_ = Constant(value=1.)
  28. zeros_ = Constant(value=0.)
  29. trunc_normal_ = TruncatedNormal(std=.02)
  30. # Common Functions
  31. def to_2tuple(x):
  32. return tuple([x] * 2)
  33. def add_parameter(layer, datas, name=None):
  34. parameter = layer.create_parameter(
  35. shape=(datas.shape), default_initializer=Assign(datas))
  36. if name:
  37. layer.add_parameter(name, parameter)
  38. return parameter
  39. # Common Layers
  40. def drop_path(x, drop_prob=0., training=False):
  41. """
  42. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  43. the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  44. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
  45. """
  46. if drop_prob == 0. or not training:
  47. return x
  48. keep_prob = paddle.to_tensor(1 - drop_prob)
  49. shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
  50. random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
  51. random_tensor = paddle.floor(random_tensor) # binarize
  52. output = x.divide(keep_prob) * random_tensor
  53. return output
  54. class DropPath(nn.Layer):
  55. def __init__(self, drop_prob=None):
  56. super(DropPath, self).__init__()
  57. self.drop_prob = drop_prob
  58. def forward(self, x):
  59. return drop_path(x, self.drop_prob, self.training)
  60. class Identity(nn.Layer):
  61. def __init__(self):
  62. super(Identity, self).__init__()
  63. def forward(self, input):
  64. return input
  65. class Mlp(nn.Layer):
  66. def __init__(self,
  67. in_features,
  68. hidden_features=None,
  69. out_features=None,
  70. act_layer=nn.GELU,
  71. drop=0.):
  72. super().__init__()
  73. out_features = out_features or in_features
  74. hidden_features = hidden_features or in_features
  75. self.fc1 = nn.Linear(in_features, hidden_features)
  76. self.act = act_layer()
  77. self.fc2 = nn.Linear(hidden_features, out_features)
  78. self.drop = nn.Dropout(drop)
  79. def forward(self, x):
  80. x = self.fc1(x)
  81. x = self.act(x)
  82. x = self.drop(x)
  83. x = self.fc2(x)
  84. x = self.drop(x)
  85. return x
  86. def window_partition(x, window_size):
  87. """
  88. Args:
  89. x: (B, H, W, C)
  90. window_size (int): window size
  91. Returns:
  92. windows: (num_windows*B, window_size, window_size, C)
  93. """
  94. B, H, W, C = x.shape
  95. x = x.reshape(
  96. [B, H // window_size, window_size, W // window_size, window_size, C])
  97. windows = x.transpose([0, 1, 3, 2, 4, 5]).reshape(
  98. [-1, window_size, window_size, C])
  99. return windows
  100. def window_reverse(windows, window_size, H, W):
  101. """
  102. Args:
  103. windows: (num_windows*B, window_size, window_size, C)
  104. window_size (int): Window size
  105. H (int): Height of image
  106. W (int): Width of image
  107. Returns:
  108. x: (B, H, W, C)
  109. """
  110. B = int(windows.shape[0] / (H * W / window_size / window_size))
  111. x = windows.reshape(
  112. [B, H // window_size, W // window_size, window_size, window_size, -1])
  113. x = x.transpose([0, 1, 3, 2, 4, 5]).reshape([B, H, W, -1])
  114. return x
  115. class WindowAttention(nn.Layer):
  116. """ Window based multi-head self attention (W-MSA) module with relative position bias.
  117. It supports both of shifted and non-shifted window.
  118. Args:
  119. dim (int): Number of input channels.
  120. window_size (tuple[int]): The height and width of the window.
  121. num_heads (int): Number of attention heads.
  122. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  123. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
  124. attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
  125. proj_drop (float, optional): Dropout ratio of output. Default: 0.0
  126. """
  127. def __init__(self,
  128. dim,
  129. window_size,
  130. num_heads,
  131. qkv_bias=True,
  132. qk_scale=None,
  133. attn_drop=0.,
  134. proj_drop=0.):
  135. super().__init__()
  136. self.dim = dim
  137. self.window_size = window_size # Wh, Ww
  138. self.num_heads = num_heads
  139. head_dim = dim // num_heads
  140. self.scale = qk_scale or head_dim**-0.5
  141. # define a parameter table of relative position bias
  142. self.relative_position_bias_table = add_parameter(
  143. self,
  144. paddle.zeros(((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
  145. num_heads))) # 2*Wh-1 * 2*Ww-1, nH
  146. # get pair-wise relative position index for each token inside the window
  147. coords_h = paddle.arange(self.window_size[0])
  148. coords_w = paddle.arange(self.window_size[1])
  149. coords = paddle.stack(paddle.meshgrid(
  150. [coords_h, coords_w])) # 2, Wh, Ww
  151. coords_flatten = paddle.flatten(coords, 1) # 2, Wh*Ww
  152. coords_flatten_1 = coords_flatten.unsqueeze(axis=2)
  153. coords_flatten_2 = coords_flatten.unsqueeze(axis=1)
  154. relative_coords = coords_flatten_1 - coords_flatten_2
  155. relative_coords = relative_coords.transpose(
  156. [1, 2, 0]) # Wh*Ww, Wh*Ww, 2
  157. relative_coords[:, :, 0] += self.window_size[
  158. 0] - 1 # shift to start from 0
  159. relative_coords[:, :, 1] += self.window_size[1] - 1
  160. relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
  161. self.relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
  162. self.register_buffer("relative_position_index",
  163. self.relative_position_index)
  164. self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
  165. self.attn_drop = nn.Dropout(attn_drop)
  166. self.proj = nn.Linear(dim, dim)
  167. self.proj_drop = nn.Dropout(proj_drop)
  168. trunc_normal_(self.relative_position_bias_table)
  169. self.softmax = nn.Softmax(axis=-1)
  170. def forward(self, x, mask=None):
  171. """ Forward function.
  172. Args:
  173. x: input features with shape of (num_windows*B, N, C)
  174. mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
  175. """
  176. B_, N, C = x.shape
  177. qkv = self.qkv(x).reshape(
  178. [B_, N, 3, self.num_heads, C // self.num_heads]).transpose(
  179. [2, 0, 3, 1, 4])
  180. q, k, v = qkv[0], qkv[1], qkv[2]
  181. q = q * self.scale
  182. attn = paddle.mm(q, k.transpose([0, 1, 3, 2]))
  183. index = self.relative_position_index.reshape([-1])
  184. relative_position_bias = paddle.index_select(
  185. self.relative_position_bias_table, index)
  186. relative_position_bias = relative_position_bias.reshape([
  187. self.window_size[0] * self.window_size[1],
  188. self.window_size[0] * self.window_size[1], -1
  189. ]) # Wh*Ww,Wh*Ww,nH
  190. relative_position_bias = relative_position_bias.transpose(
  191. [2, 0, 1]) # nH, Wh*Ww, Wh*Ww
  192. attn = attn + relative_position_bias.unsqueeze(0)
  193. if mask is not None:
  194. nW = mask.shape[0]
  195. attn = attn.reshape([B_ // nW, nW, self.num_heads, N, N
  196. ]) + mask.unsqueeze(1).unsqueeze(0)
  197. attn = attn.reshape([-1, self.num_heads, N, N])
  198. attn = self.softmax(attn)
  199. else:
  200. attn = self.softmax(attn)
  201. attn = self.attn_drop(attn)
  202. # x = (attn @ v).transpose(1, 2).reshape([B_, N, C])
  203. x = paddle.mm(attn, v).transpose([0, 2, 1, 3]).reshape([B_, N, C])
  204. x = self.proj(x)
  205. x = self.proj_drop(x)
  206. return x
  207. class SwinTransformerBlock(nn.Layer):
  208. """ Swin Transformer Block.
  209. Args:
  210. dim (int): Number of input channels.
  211. num_heads (int): Number of attention heads.
  212. window_size (int): Window size.
  213. shift_size (int): Shift size for SW-MSA.
  214. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  215. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  216. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
  217. drop (float, optional): Dropout rate. Default: 0.0
  218. attn_drop (float, optional): Attention dropout rate. Default: 0.0
  219. drop_path (float, optional): Stochastic depth rate. Default: 0.0
  220. act_layer (nn.Layer, optional): Activation layer. Default: nn.GELU
  221. norm_layer (nn.Layer, optional): Normalization layer. Default: nn.LayerNorm
  222. """
  223. def __init__(self,
  224. dim,
  225. num_heads,
  226. window_size=7,
  227. shift_size=0,
  228. mlp_ratio=4.,
  229. qkv_bias=True,
  230. qk_scale=None,
  231. drop=0.,
  232. attn_drop=0.,
  233. drop_path=0.,
  234. act_layer=nn.GELU,
  235. norm_layer=nn.LayerNorm):
  236. super().__init__()
  237. self.dim = dim
  238. self.num_heads = num_heads
  239. self.window_size = window_size
  240. self.shift_size = shift_size
  241. self.mlp_ratio = mlp_ratio
  242. assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
  243. self.norm1 = norm_layer(dim)
  244. self.attn = WindowAttention(
  245. dim,
  246. window_size=to_2tuple(self.window_size),
  247. num_heads=num_heads,
  248. qkv_bias=qkv_bias,
  249. qk_scale=qk_scale,
  250. attn_drop=attn_drop,
  251. proj_drop=drop)
  252. self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
  253. self.norm2 = norm_layer(dim)
  254. mlp_hidden_dim = int(dim * mlp_ratio)
  255. self.mlp = Mlp(in_features=dim,
  256. hidden_features=mlp_hidden_dim,
  257. act_layer=act_layer,
  258. drop=drop)
  259. self.H = None
  260. self.W = None
  261. def forward(self, x, mask_matrix):
  262. """ Forward function.
  263. Args:
  264. x: Input feature, tensor size (B, H*W, C).
  265. H, W: Spatial resolution of the input feature.
  266. mask_matrix: Attention mask for cyclic shift.
  267. """
  268. B, L, C = x.shape
  269. H, W = self.H, self.W
  270. assert L == H * W, "input feature has wrong size"
  271. shortcut = x
  272. x = self.norm1(x)
  273. x = x.reshape([B, H, W, C])
  274. # pad feature maps to multiples of window size
  275. pad_l = pad_t = 0
  276. pad_r = (self.window_size - W % self.window_size) % self.window_size
  277. pad_b = (self.window_size - H % self.window_size) % self.window_size
  278. x = F.pad(x, [0, pad_l, 0, pad_b, 0, pad_r, 0, pad_t])
  279. _, Hp, Wp, _ = x.shape
  280. # cyclic shift
  281. if self.shift_size > 0:
  282. shifted_x = paddle.roll(
  283. x, shifts=(-self.shift_size, -self.shift_size), axis=(1, 2))
  284. attn_mask = mask_matrix
  285. else:
  286. shifted_x = x
  287. attn_mask = None
  288. # partition windows
  289. x_windows = window_partition(
  290. shifted_x, self.window_size) # nW*B, window_size, window_size, C
  291. x_windows = x_windows.reshape(
  292. [-1, self.window_size * self.window_size,
  293. C]) # nW*B, window_size*window_size, C
  294. # W-MSA/SW-MSA
  295. attn_windows = self.attn(
  296. x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
  297. # merge windows
  298. attn_windows = attn_windows.reshape(
  299. [-1, self.window_size, self.window_size, C])
  300. shifted_x = window_reverse(attn_windows, self.window_size, Hp,
  301. Wp) # B H' W' C
  302. # reverse cyclic shift
  303. if self.shift_size > 0:
  304. x = paddle.roll(
  305. shifted_x,
  306. shifts=(self.shift_size, self.shift_size),
  307. axis=(1, 2))
  308. else:
  309. x = shifted_x
  310. if pad_r > 0 or pad_b > 0:
  311. x = x[:, :H, :W, :]
  312. x = x.reshape([B, H * W, C])
  313. # FFN
  314. x = shortcut + self.drop_path(x)
  315. x = x + self.drop_path(self.mlp(self.norm2(x)))
  316. return x
  317. class PatchMerging(nn.Layer):
  318. r""" Patch Merging Layer.
  319. Args:
  320. dim (int): Number of input channels.
  321. norm_layer (nn.Layer, optional): Normalization layer. Default: nn.LayerNorm
  322. """
  323. def __init__(self, dim, norm_layer=nn.LayerNorm):
  324. super().__init__()
  325. self.dim = dim
  326. self.reduction = nn.Linear(4 * dim, 2 * dim, bias_attr=False)
  327. self.norm = norm_layer(4 * dim)
  328. def forward(self, x, H, W):
  329. """ Forward function.
  330. Args:
  331. x: Input feature, tensor size (B, H*W, C).
  332. H, W: Spatial resolution of the input feature.
  333. """
  334. B, L, C = x.shape
  335. assert L == H * W, "input feature has wrong size"
  336. x = x.reshape([B, H, W, C])
  337. # padding
  338. pad_input = (H % 2 == 1) or (W % 2 == 1)
  339. if pad_input:
  340. x = F.pad(x, [0, 0, 0, W % 2, 0, H % 2])
  341. x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
  342. x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
  343. x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
  344. x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
  345. x = paddle.concat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
  346. x = x.reshape([B, H * W // 4, 4 * C]) # B H/2*W/2 4*C
  347. x = self.norm(x)
  348. x = self.reduction(x)
  349. return x
  350. class BasicLayer(nn.Layer):
  351. """ A basic Swin Transformer layer for one stage.
  352. Args:
  353. dim (int): Number of input channels.
  354. input_resolution (tuple[int]): Input resolution.
  355. depth (int): Number of blocks.
  356. num_heads (int): Number of attention heads.
  357. window_size (int): Local window size.
  358. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  359. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  360. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
  361. drop (float, optional): Dropout rate. Default: 0.0
  362. attn_drop (float, optional): Attention dropout rate. Default: 0.0
  363. drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
  364. norm_layer (nn.Layer, optional): Normalization layer. Default: nn.LayerNorm
  365. downsample (nn.Layer | None, optional): Downsample layer at the end of the layer. Default: None
  366. """
  367. def __init__(self,
  368. dim,
  369. depth,
  370. num_heads,
  371. window_size=7,
  372. mlp_ratio=4.,
  373. qkv_bias=True,
  374. qk_scale=None,
  375. drop=0.,
  376. attn_drop=0.,
  377. drop_path=0.,
  378. norm_layer=nn.LayerNorm,
  379. downsample=None):
  380. super().__init__()
  381. self.window_size = window_size
  382. self.shift_size = window_size // 2
  383. self.depth = depth
  384. # build blocks
  385. self.blocks = nn.LayerList([
  386. SwinTransformerBlock(
  387. dim=dim,
  388. num_heads=num_heads,
  389. window_size=window_size,
  390. shift_size=0 if (i % 2 == 0) else window_size // 2,
  391. mlp_ratio=mlp_ratio,
  392. qkv_bias=qkv_bias,
  393. qk_scale=qk_scale,
  394. drop=drop,
  395. attn_drop=attn_drop,
  396. drop_path=drop_path[i]
  397. if isinstance(drop_path, np.ndarray) else drop_path,
  398. norm_layer=norm_layer) for i in range(depth)
  399. ])
  400. # patch merging layer
  401. if downsample is not None:
  402. self.downsample = downsample(dim=dim, norm_layer=norm_layer)
  403. else:
  404. self.downsample = None
  405. def forward(self, x, H, W):
  406. """ Forward function.
  407. Args:
  408. x: Input feature, tensor size (B, H*W, C).
  409. H, W: Spatial resolution of the input feature.
  410. """
  411. # calculate attention mask for SW-MSA
  412. Hp = int(np.ceil(H / self.window_size)) * self.window_size
  413. Wp = int(np.ceil(W / self.window_size)) * self.window_size
  414. img_mask = paddle.fluid.layers.zeros(
  415. [1, Hp, Wp, 1], dtype='float32') # 1 Hp Wp 1
  416. h_slices = (slice(0, -self.window_size),
  417. slice(-self.window_size, -self.shift_size),
  418. slice(-self.shift_size, None))
  419. w_slices = (slice(0, -self.window_size),
  420. slice(-self.window_size, -self.shift_size),
  421. slice(-self.shift_size, None))
  422. cnt = 0
  423. for h in h_slices:
  424. for w in w_slices:
  425. try:
  426. img_mask[:, h, w, :] = cnt
  427. except:
  428. pass
  429. cnt += 1
  430. mask_windows = window_partition(
  431. img_mask, self.window_size) # nW, window_size, window_size, 1
  432. mask_windows = mask_windows.reshape(
  433. [-1, self.window_size * self.window_size])
  434. attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
  435. huns = -100.0 * paddle.ones_like(attn_mask)
  436. attn_mask = huns * (attn_mask != 0).astype("float32")
  437. for blk in self.blocks:
  438. blk.H, blk.W = H, W
  439. x = blk(x, attn_mask)
  440. if self.downsample is not None:
  441. x_down = self.downsample(x, H, W)
  442. Wh, Ww = (H + 1) // 2, (W + 1) // 2
  443. return x, H, W, x_down, Wh, Ww
  444. else:
  445. return x, H, W, x, H, W
  446. class PatchEmbed(nn.Layer):
  447. """ Image to Patch Embedding
  448. Args:
  449. patch_size (int): Patch token size. Default: 4.
  450. in_chans (int): Number of input image channels. Default: 3.
  451. embed_dim (int): Number of linear projection output channels. Default: 96.
  452. norm_layer (nn.Layer, optional): Normalization layer. Default: None
  453. """
  454. def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
  455. super().__init__()
  456. patch_size = to_2tuple(patch_size)
  457. self.patch_size = patch_size
  458. self.in_chans = in_chans
  459. self.embed_dim = embed_dim
  460. self.proj = nn.Conv2D(
  461. in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
  462. if norm_layer is not None:
  463. self.norm = norm_layer(embed_dim)
  464. else:
  465. self.norm = None
  466. def forward(self, x):
  467. B, C, H, W = x.shape
  468. # assert [H, W] == self.img_size[:2], "Input image size ({H}*{W}) doesn't match model ({}*{}).".format(H, W, self.img_size[0], self.img_size[1])
  469. if W % self.patch_size[1] != 0:
  470. x = F.pad(x, [0, self.patch_size[1] - W % self.patch_size[1], 0, 0])
  471. if H % self.patch_size[0] != 0:
  472. x = F.pad(x, [0, 0, 0, self.patch_size[0] - H % self.patch_size[0]])
  473. x = self.proj(x)
  474. if self.norm is not None:
  475. _, _, Wh, Ww = x.shape
  476. x = x.flatten(2).transpose([0, 2, 1])
  477. x = self.norm(x)
  478. x = x.transpose([0, 2, 1]).reshape([-1, self.embed_dim, Wh, Ww])
  479. return x
  480. @register
  481. @serializable
  482. class SwinTransformer(nn.Layer):
  483. """ Swin Transformer
  484. A PaddlePaddle impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
  485. https://arxiv.org/pdf/2103.14030
  486. Args:
  487. img_size (int | tuple(int)): Input image size. Default 224
  488. patch_size (int | tuple(int)): Patch size. Default: 4
  489. in_chans (int): Number of input image channels. Default: 3
  490. num_classes (int): Number of classes for classification head. Default: 1000
  491. embed_dim (int): Patch embedding dimension. Default: 96
  492. depths (tuple(int)): Depth of each Swin Transformer layer.
  493. num_heads (tuple(int)): Number of attention heads in different layers.
  494. window_size (int): Window size. Default: 7
  495. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
  496. qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
  497. qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
  498. drop_rate (float): Dropout rate. Default: 0
  499. attn_drop_rate (float): Attention dropout rate. Default: 0
  500. drop_path_rate (float): Stochastic depth rate. Default: 0.1
  501. norm_layer (nn.Layer): Normalization layer. Default: nn.LayerNorm.
  502. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
  503. patch_norm (bool): If True, add normalization after patch embedding. Default: True
  504. """
  505. def __init__(self,
  506. pretrain_img_size=224,
  507. patch_size=4,
  508. in_chans=3,
  509. embed_dim=96,
  510. depths=[2, 2, 6, 2],
  511. num_heads=[3, 6, 12, 24],
  512. window_size=7,
  513. mlp_ratio=4.,
  514. qkv_bias=True,
  515. qk_scale=None,
  516. drop_rate=0.,
  517. attn_drop_rate=0.,
  518. drop_path_rate=0.2,
  519. norm_layer=nn.LayerNorm,
  520. ape=False,
  521. patch_norm=True,
  522. out_indices=(0, 1, 2, 3),
  523. frozen_stages=-1,
  524. pretrained=None):
  525. super(SwinTransformer, self).__init__()
  526. self.pretrain_img_size = pretrain_img_size
  527. self.num_layers = len(depths)
  528. self.embed_dim = embed_dim
  529. self.ape = ape
  530. self.patch_norm = patch_norm
  531. self.out_indices = out_indices
  532. self.frozen_stages = frozen_stages
  533. # split image into non-overlapping patches
  534. self.patch_embed = PatchEmbed(
  535. patch_size=patch_size,
  536. in_chans=in_chans,
  537. embed_dim=embed_dim,
  538. norm_layer=norm_layer if self.patch_norm else None)
  539. # absolute position embedding
  540. if self.ape:
  541. pretrain_img_size = to_2tuple(pretrain_img_size)
  542. patch_size = to_2tuple(patch_size)
  543. patches_resolution = [
  544. pretrain_img_size[0] // patch_size[0],
  545. pretrain_img_size[1] // patch_size[1]
  546. ]
  547. self.absolute_pos_embed = add_parameter(
  548. self,
  549. paddle.zeros((1, embed_dim, patches_resolution[0],
  550. patches_resolution[1])))
  551. trunc_normal_(self.absolute_pos_embed)
  552. self.pos_drop = nn.Dropout(p=drop_rate)
  553. # stochastic depth
  554. dpr = np.linspace(0, drop_path_rate,
  555. sum(depths)) # stochastic depth decay rule
  556. # build layers
  557. self.layers = nn.LayerList()
  558. for i_layer in range(self.num_layers):
  559. layer = BasicLayer(
  560. dim=int(embed_dim * 2**i_layer),
  561. depth=depths[i_layer],
  562. num_heads=num_heads[i_layer],
  563. window_size=window_size,
  564. mlp_ratio=mlp_ratio,
  565. qkv_bias=qkv_bias,
  566. qk_scale=qk_scale,
  567. drop=drop_rate,
  568. attn_drop=attn_drop_rate,
  569. drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
  570. norm_layer=norm_layer,
  571. downsample=PatchMerging
  572. if (i_layer < self.num_layers - 1) else None)
  573. self.layers.append(layer)
  574. num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
  575. self.num_features = num_features
  576. # add a norm layer for each output
  577. for i_layer in out_indices:
  578. layer = norm_layer(num_features[i_layer])
  579. layer_name = f'norm{i_layer}'
  580. self.add_sublayer(layer_name, layer)
  581. self.apply(self._init_weights)
  582. self._freeze_stages()
  583. if pretrained:
  584. if 'http' in pretrained: #URL
  585. path = paddle.utils.download.get_weights_path_from_url(
  586. pretrained)
  587. else: #model in local path
  588. path = pretrained
  589. self.set_state_dict(paddle.load(path))
  590. def _freeze_stages(self):
  591. if self.frozen_stages >= 0:
  592. self.patch_embed.eval()
  593. for param in self.patch_embed.parameters():
  594. param.stop_gradient = True
  595. if self.frozen_stages >= 1 and self.ape:
  596. self.absolute_pos_embed.stop_gradient = True
  597. if self.frozen_stages >= 2:
  598. self.pos_drop.eval()
  599. for i in range(0, self.frozen_stages - 1):
  600. m = self.layers[i]
  601. m.eval()
  602. for param in m.parameters():
  603. param.stop_gradient = True
  604. def _init_weights(self, m):
  605. if isinstance(m, nn.Linear):
  606. trunc_normal_(m.weight)
  607. if isinstance(m, nn.Linear) and m.bias is not None:
  608. zeros_(m.bias)
  609. elif isinstance(m, nn.LayerNorm):
  610. zeros_(m.bias)
  611. ones_(m.weight)
  612. def forward(self, x):
  613. """Forward function."""
  614. x = self.patch_embed(x['image'])
  615. _, _, Wh, Ww = x.shape
  616. if self.ape:
  617. # interpolate the position embedding to the corresponding size
  618. absolute_pos_embed = F.interpolate(
  619. self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
  620. x = (x + absolute_pos_embed).flatten(2).transpose([0, 2, 1])
  621. else:
  622. x = x.flatten(2).transpose([0, 2, 1])
  623. x = self.pos_drop(x)
  624. outs = []
  625. for i in range(self.num_layers):
  626. layer = self.layers[i]
  627. x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
  628. if i in self.out_indices:
  629. norm_layer = getattr(self, f'norm{i}')
  630. x_out = norm_layer(x_out)
  631. out = x_out.reshape((-1, H, W, self.num_features[i])).transpose(
  632. (0, 3, 1, 2))
  633. outs.append(out)
  634. return tuple(outs)
  635. @property
  636. def out_shape(self):
  637. out_strides = [4, 8, 16, 32]
  638. return [
  639. ShapeSpec(
  640. channels=self.num_features[i], stride=out_strides[i])
  641. for i in self.out_indices
  642. ]