s2anet_head.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. #
  15. # The code is based on https://github.com/csuhan/s2anet/blob/master/mmdet/models/anchor_heads_rotated/s2anet_head.py
  16. import paddle
  17. from paddle import ParamAttr
  18. import paddle.nn as nn
  19. import paddle.nn.functional as F
  20. from paddle.nn.initializer import Normal, Constant
  21. from ppdet.core.workspace import register
  22. from ppdet.modeling import ops
  23. from ppdet.modeling import bbox_utils
  24. from ppdet.modeling.proposal_generator.target_layer import RBoxAssigner
  25. import numpy as np
  26. class S2ANetAnchorGenerator(nn.Layer):
  27. """
  28. AnchorGenerator by paddle
  29. """
  30. def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None):
  31. super(S2ANetAnchorGenerator, self).__init__()
  32. self.base_size = base_size
  33. self.scales = paddle.to_tensor(scales)
  34. self.ratios = paddle.to_tensor(ratios)
  35. self.scale_major = scale_major
  36. self.ctr = ctr
  37. self.base_anchors = self.gen_base_anchors()
  38. @property
  39. def num_base_anchors(self):
  40. return self.base_anchors.shape[0]
  41. def gen_base_anchors(self):
  42. w = self.base_size
  43. h = self.base_size
  44. if self.ctr is None:
  45. x_ctr = 0.5 * (w - 1)
  46. y_ctr = 0.5 * (h - 1)
  47. else:
  48. x_ctr, y_ctr = self.ctr
  49. h_ratios = paddle.sqrt(self.ratios)
  50. w_ratios = 1 / h_ratios
  51. if self.scale_major:
  52. ws = (w * w_ratios[:] * self.scales[:]).reshape([-1])
  53. hs = (h * h_ratios[:] * self.scales[:]).reshape([-1])
  54. else:
  55. ws = (w * self.scales[:] * w_ratios[:]).reshape([-1])
  56. hs = (h * self.scales[:] * h_ratios[:]).reshape([-1])
  57. base_anchors = paddle.stack(
  58. [
  59. x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1),
  60. x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1)
  61. ],
  62. axis=-1)
  63. base_anchors = paddle.round(base_anchors)
  64. return base_anchors
  65. def _meshgrid(self, x, y, row_major=True):
  66. yy, xx = paddle.meshgrid(y, x)
  67. yy = yy.reshape([-1])
  68. xx = xx.reshape([-1])
  69. if row_major:
  70. return xx, yy
  71. else:
  72. return yy, xx
  73. def forward(self, featmap_size, stride=16):
  74. # featmap_size*stride project it to original area
  75. feat_h = featmap_size[0]
  76. feat_w = featmap_size[1]
  77. shift_x = paddle.arange(0, feat_w, 1, 'int32') * stride
  78. shift_y = paddle.arange(0, feat_h, 1, 'int32') * stride
  79. shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
  80. shifts = paddle.stack([shift_xx, shift_yy, shift_xx, shift_yy], axis=-1)
  81. all_anchors = self.base_anchors[:, :] + shifts[:, :]
  82. all_anchors = all_anchors.reshape([feat_h * feat_w, 4])
  83. return all_anchors
  84. def valid_flags(self, featmap_size, valid_size):
  85. feat_h, feat_w = featmap_size
  86. valid_h, valid_w = valid_size
  87. assert valid_h <= feat_h and valid_w <= feat_w
  88. valid_x = paddle.zeros([feat_w], dtype='int32')
  89. valid_y = paddle.zeros([feat_h], dtype='int32')
  90. valid_x[:valid_w] = 1
  91. valid_y[:valid_h] = 1
  92. valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
  93. valid = valid_xx & valid_yy
  94. valid = paddle.reshape(valid, [-1, 1])
  95. valid = paddle.expand(valid, [-1, self.num_base_anchors]).reshape([-1])
  96. return valid
  97. class AlignConv(nn.Layer):
  98. def __init__(self, in_channels, out_channels, kernel_size=3, groups=1):
  99. super(AlignConv, self).__init__()
  100. self.kernel_size = kernel_size
  101. self.align_conv = paddle.vision.ops.DeformConv2D(
  102. in_channels,
  103. out_channels,
  104. kernel_size=self.kernel_size,
  105. padding=(self.kernel_size - 1) // 2,
  106. groups=groups,
  107. weight_attr=ParamAttr(initializer=Normal(0, 0.01)),
  108. bias_attr=None)
  109. @paddle.no_grad()
  110. def get_offset(self, anchors, featmap_size, stride):
  111. """
  112. Args:
  113. anchors: [M,5] xc,yc,w,h,angle
  114. featmap_size: (feat_h, feat_w)
  115. stride: 8
  116. Returns:
  117. """
  118. anchors = paddle.reshape(anchors, [-1, 5]) # (NA,5)
  119. dtype = anchors.dtype
  120. feat_h = featmap_size[0]
  121. feat_w = featmap_size[1]
  122. pad = (self.kernel_size - 1) // 2
  123. idx = paddle.arange(-pad, pad + 1, dtype=dtype)
  124. yy, xx = paddle.meshgrid(idx, idx)
  125. xx = paddle.reshape(xx, [-1])
  126. yy = paddle.reshape(yy, [-1])
  127. # get sampling locations of default conv
  128. xc = paddle.arange(0, feat_w, dtype=dtype)
  129. yc = paddle.arange(0, feat_h, dtype=dtype)
  130. yc, xc = paddle.meshgrid(yc, xc)
  131. xc = paddle.reshape(xc, [-1, 1])
  132. yc = paddle.reshape(yc, [-1, 1])
  133. x_conv = xc + xx
  134. y_conv = yc + yy
  135. # get sampling locations of anchors
  136. # x_ctr, y_ctr, w, h, a = np.unbind(anchors, dim=1)
  137. x_ctr = anchors[:, 0]
  138. y_ctr = anchors[:, 1]
  139. w = anchors[:, 2]
  140. h = anchors[:, 3]
  141. a = anchors[:, 4]
  142. x_ctr = paddle.reshape(x_ctr, [-1, 1])
  143. y_ctr = paddle.reshape(y_ctr, [-1, 1])
  144. w = paddle.reshape(w, [-1, 1])
  145. h = paddle.reshape(h, [-1, 1])
  146. a = paddle.reshape(a, [-1, 1])
  147. x_ctr = x_ctr / stride
  148. y_ctr = y_ctr / stride
  149. w_s = w / stride
  150. h_s = h / stride
  151. cos, sin = paddle.cos(a), paddle.sin(a)
  152. dw, dh = w_s / self.kernel_size, h_s / self.kernel_size
  153. x, y = dw * xx, dh * yy
  154. xr = cos * x - sin * y
  155. yr = sin * x + cos * y
  156. x_anchor, y_anchor = xr + x_ctr, yr + y_ctr
  157. # get offset filed
  158. offset_x = x_anchor - x_conv
  159. offset_y = y_anchor - y_conv
  160. offset = paddle.stack([offset_y, offset_x], axis=-1)
  161. offset = paddle.reshape(
  162. offset, [feat_h * feat_w, self.kernel_size * self.kernel_size * 2])
  163. offset = paddle.transpose(offset, [1, 0])
  164. offset = paddle.reshape(
  165. offset,
  166. [1, self.kernel_size * self.kernel_size * 2, feat_h, feat_w])
  167. return offset
  168. def forward(self, x, refine_anchors, featmap_size, stride):
  169. offset = self.get_offset(refine_anchors, featmap_size, stride)
  170. x = F.relu(self.align_conv(x, offset))
  171. return x
  172. @register
  173. class S2ANetHead(nn.Layer):
  174. """
  175. S2Anet head
  176. Args:
  177. stacked_convs (int): number of stacked_convs
  178. feat_in (int): input channels of feat
  179. feat_out (int): output channels of feat
  180. num_classes (int): num_classes
  181. anchor_strides (list): stride of anchors
  182. anchor_scales (list): scale of anchors
  183. anchor_ratios (list): ratios of anchors
  184. target_means (list): target_means
  185. target_stds (list): target_stds
  186. align_conv_type (str): align_conv_type ['Conv', 'AlignConv']
  187. align_conv_size (int): kernel size of align_conv
  188. use_sigmoid_cls (bool): use sigmoid_cls or not
  189. reg_loss_weight (list): loss weight for regression
  190. """
  191. __shared__ = ['num_classes']
  192. __inject__ = ['anchor_assign']
  193. def __init__(self,
  194. stacked_convs=2,
  195. feat_in=256,
  196. feat_out=256,
  197. num_classes=15,
  198. anchor_strides=[8, 16, 32, 64, 128],
  199. anchor_scales=[4],
  200. anchor_ratios=[1.0],
  201. target_means=0.0,
  202. target_stds=1.0,
  203. align_conv_type='AlignConv',
  204. align_conv_size=3,
  205. use_sigmoid_cls=True,
  206. anchor_assign=RBoxAssigner().__dict__,
  207. reg_loss_weight=[1.0, 1.0, 1.0, 1.0, 1.1],
  208. cls_loss_weight=[1.1, 1.05],
  209. reg_loss_type='l1'):
  210. super(S2ANetHead, self).__init__()
  211. self.stacked_convs = stacked_convs
  212. self.feat_in = feat_in
  213. self.feat_out = feat_out
  214. self.anchor_list = None
  215. self.anchor_scales = anchor_scales
  216. self.anchor_ratios = anchor_ratios
  217. self.anchor_strides = anchor_strides
  218. self.anchor_strides = paddle.to_tensor(anchor_strides)
  219. self.anchor_base_sizes = list(anchor_strides)
  220. self.means = paddle.ones(shape=[5]) * target_means
  221. self.stds = paddle.ones(shape=[5]) * target_stds
  222. assert align_conv_type in ['AlignConv', 'Conv', 'DCN']
  223. self.align_conv_type = align_conv_type
  224. self.align_conv_size = align_conv_size
  225. self.use_sigmoid_cls = use_sigmoid_cls
  226. self.cls_out_channels = num_classes if self.use_sigmoid_cls else 1
  227. self.sampling = False
  228. self.anchor_assign = anchor_assign
  229. self.reg_loss_weight = reg_loss_weight
  230. self.cls_loss_weight = cls_loss_weight
  231. self.alpha = 1.0
  232. self.beta = 1.0
  233. self.reg_loss_type = reg_loss_type
  234. self.s2anet_head_out = None
  235. # anchor
  236. self.anchor_generators = []
  237. for anchor_base in self.anchor_base_sizes:
  238. self.anchor_generators.append(
  239. S2ANetAnchorGenerator(anchor_base, anchor_scales,
  240. anchor_ratios))
  241. self.anchor_generators = nn.LayerList(self.anchor_generators)
  242. self.fam_cls_convs = nn.Sequential()
  243. self.fam_reg_convs = nn.Sequential()
  244. for i in range(self.stacked_convs):
  245. chan_in = self.feat_in if i == 0 else self.feat_out
  246. self.fam_cls_convs.add_sublayer(
  247. 'fam_cls_conv_{}'.format(i),
  248. nn.Conv2D(
  249. in_channels=chan_in,
  250. out_channels=self.feat_out,
  251. kernel_size=3,
  252. padding=1,
  253. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  254. bias_attr=ParamAttr(initializer=Constant(0))))
  255. self.fam_cls_convs.add_sublayer('fam_cls_conv_{}_act'.format(i),
  256. nn.ReLU())
  257. self.fam_reg_convs.add_sublayer(
  258. 'fam_reg_conv_{}'.format(i),
  259. nn.Conv2D(
  260. in_channels=chan_in,
  261. out_channels=self.feat_out,
  262. kernel_size=3,
  263. padding=1,
  264. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  265. bias_attr=ParamAttr(initializer=Constant(0))))
  266. self.fam_reg_convs.add_sublayer('fam_reg_conv_{}_act'.format(i),
  267. nn.ReLU())
  268. self.fam_reg = nn.Conv2D(
  269. self.feat_out,
  270. 5,
  271. 1,
  272. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  273. bias_attr=ParamAttr(initializer=Constant(0)))
  274. prior_prob = 0.01
  275. bias_init = float(-np.log((1 - prior_prob) / prior_prob))
  276. self.fam_cls = nn.Conv2D(
  277. self.feat_out,
  278. self.cls_out_channels,
  279. 1,
  280. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  281. bias_attr=ParamAttr(initializer=Constant(bias_init)))
  282. if self.align_conv_type == "AlignConv":
  283. self.align_conv = AlignConv(self.feat_out, self.feat_out,
  284. self.align_conv_size)
  285. elif self.align_conv_type == "Conv":
  286. self.align_conv = nn.Conv2D(
  287. self.feat_out,
  288. self.feat_out,
  289. self.align_conv_size,
  290. padding=(self.align_conv_size - 1) // 2,
  291. bias_attr=ParamAttr(initializer=Constant(0)))
  292. elif self.align_conv_type == "DCN":
  293. self.align_conv_offset = nn.Conv2D(
  294. self.feat_out,
  295. 2 * self.align_conv_size**2,
  296. 1,
  297. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  298. bias_attr=ParamAttr(initializer=Constant(0)))
  299. self.align_conv = paddle.vision.ops.DeformConv2D(
  300. self.feat_out,
  301. self.feat_out,
  302. self.align_conv_size,
  303. padding=(self.align_conv_size - 1) // 2,
  304. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  305. bias_attr=False)
  306. self.or_conv = nn.Conv2D(
  307. self.feat_out,
  308. self.feat_out,
  309. kernel_size=3,
  310. padding=1,
  311. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  312. bias_attr=ParamAttr(initializer=Constant(0)))
  313. # ODM
  314. self.odm_cls_convs = nn.Sequential()
  315. self.odm_reg_convs = nn.Sequential()
  316. for i in range(self.stacked_convs):
  317. ch_in = self.feat_out
  318. # ch_in = int(self.feat_out / 8) if i == 0 else self.feat_out
  319. self.odm_cls_convs.add_sublayer(
  320. 'odm_cls_conv_{}'.format(i),
  321. nn.Conv2D(
  322. in_channels=ch_in,
  323. out_channels=self.feat_out,
  324. kernel_size=3,
  325. stride=1,
  326. padding=1,
  327. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  328. bias_attr=ParamAttr(initializer=Constant(0))))
  329. self.odm_cls_convs.add_sublayer('odm_cls_conv_{}_act'.format(i),
  330. nn.ReLU())
  331. self.odm_reg_convs.add_sublayer(
  332. 'odm_reg_conv_{}'.format(i),
  333. nn.Conv2D(
  334. in_channels=self.feat_out,
  335. out_channels=self.feat_out,
  336. kernel_size=3,
  337. stride=1,
  338. padding=1,
  339. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  340. bias_attr=ParamAttr(initializer=Constant(0))))
  341. self.odm_reg_convs.add_sublayer('odm_reg_conv_{}_act'.format(i),
  342. nn.ReLU())
  343. self.odm_cls = nn.Conv2D(
  344. self.feat_out,
  345. self.cls_out_channels,
  346. 3,
  347. padding=1,
  348. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  349. bias_attr=ParamAttr(initializer=Constant(bias_init)))
  350. self.odm_reg = nn.Conv2D(
  351. self.feat_out,
  352. 5,
  353. 3,
  354. padding=1,
  355. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  356. bias_attr=ParamAttr(initializer=Constant(0)))
  357. self.featmap_sizes = []
  358. self.base_anchors_list = []
  359. self.refine_anchor_list = []
  360. def forward(self, feats):
  361. fam_reg_branch_list = []
  362. fam_cls_branch_list = []
  363. odm_reg_branch_list = []
  364. odm_cls_branch_list = []
  365. self.featmap_sizes_list = []
  366. self.base_anchors_list = []
  367. self.refine_anchor_list = []
  368. for feat_idx in range(len(feats)):
  369. feat = feats[feat_idx]
  370. fam_cls_feat = self.fam_cls_convs(feat)
  371. fam_cls = self.fam_cls(fam_cls_feat)
  372. # [N, CLS, H, W] --> [N, H, W, CLS]
  373. fam_cls = fam_cls.transpose([0, 2, 3, 1])
  374. fam_cls_reshape = paddle.reshape(
  375. fam_cls, [fam_cls.shape[0], -1, self.cls_out_channels])
  376. fam_cls_branch_list.append(fam_cls_reshape)
  377. fam_reg_feat = self.fam_reg_convs(feat)
  378. fam_reg = self.fam_reg(fam_reg_feat)
  379. # [N, 5, H, W] --> [N, H, W, 5]
  380. fam_reg = fam_reg.transpose([0, 2, 3, 1])
  381. fam_reg_reshape = paddle.reshape(fam_reg, [fam_reg.shape[0], -1, 5])
  382. fam_reg_branch_list.append(fam_reg_reshape)
  383. # prepare anchor
  384. featmap_size = (paddle.shape(feat)[2], paddle.shape(feat)[3])
  385. self.featmap_sizes_list.append(featmap_size)
  386. init_anchors = self.anchor_generators[feat_idx](
  387. featmap_size, self.anchor_strides[feat_idx])
  388. init_anchors = paddle.to_tensor(init_anchors, dtype='float32')
  389. NA = featmap_size[0] * featmap_size[1]
  390. init_anchors = paddle.reshape(init_anchors, [NA, 4])
  391. init_anchors = self.rect2rbox(init_anchors)
  392. self.base_anchors_list.append(init_anchors)
  393. if self.training:
  394. refine_anchor = self.bbox_decode(fam_reg.detach(), init_anchors)
  395. else:
  396. refine_anchor = self.bbox_decode(fam_reg, init_anchors)
  397. self.refine_anchor_list.append(refine_anchor)
  398. if self.align_conv_type == 'AlignConv':
  399. align_feat = self.align_conv(feat,
  400. refine_anchor.clone(),
  401. featmap_size,
  402. self.anchor_strides[feat_idx])
  403. elif self.align_conv_type == 'DCN':
  404. align_offset = self.align_conv_offset(feat)
  405. align_feat = self.align_conv(feat, align_offset)
  406. elif self.align_conv_type == 'Conv':
  407. align_feat = self.align_conv(feat)
  408. or_feat = self.or_conv(align_feat)
  409. odm_reg_feat = or_feat
  410. odm_cls_feat = or_feat
  411. odm_reg_feat = self.odm_reg_convs(odm_reg_feat)
  412. odm_cls_feat = self.odm_cls_convs(odm_cls_feat)
  413. odm_cls_score = self.odm_cls(odm_cls_feat)
  414. # [N, CLS, H, W] --> [N, H, W, CLS]
  415. odm_cls_score = odm_cls_score.transpose([0, 2, 3, 1])
  416. odm_cls_score_shape = odm_cls_score.shape
  417. odm_cls_score_reshape = paddle.reshape(odm_cls_score, [
  418. odm_cls_score_shape[0], odm_cls_score_shape[1] *
  419. odm_cls_score_shape[2], self.cls_out_channels
  420. ])
  421. odm_cls_branch_list.append(odm_cls_score_reshape)
  422. odm_bbox_pred = self.odm_reg(odm_reg_feat)
  423. # [N, 5, H, W] --> [N, H, W, 5]
  424. odm_bbox_pred = odm_bbox_pred.transpose([0, 2, 3, 1])
  425. odm_bbox_pred_reshape = paddle.reshape(odm_bbox_pred, [-1, 5])
  426. odm_bbox_pred_reshape = paddle.unsqueeze(
  427. odm_bbox_pred_reshape, axis=0)
  428. odm_reg_branch_list.append(odm_bbox_pred_reshape)
  429. self.s2anet_head_out = (fam_cls_branch_list, fam_reg_branch_list,
  430. odm_cls_branch_list, odm_reg_branch_list)
  431. return self.s2anet_head_out
  432. def get_prediction(self, nms_pre=2000):
  433. refine_anchors = self.refine_anchor_list
  434. fam_cls_branch_list = self.s2anet_head_out[0]
  435. fam_reg_branch_list = self.s2anet_head_out[1]
  436. odm_cls_branch_list = self.s2anet_head_out[2]
  437. odm_reg_branch_list = self.s2anet_head_out[3]
  438. pred_scores, pred_bboxes = self.get_bboxes(
  439. odm_cls_branch_list, odm_reg_branch_list, refine_anchors, nms_pre,
  440. self.cls_out_channels, self.use_sigmoid_cls)
  441. return pred_scores, pred_bboxes
  442. def smooth_l1_loss(self, pred, label, delta=1.0 / 9.0):
  443. """
  444. Args:
  445. pred: pred score
  446. label: label
  447. delta: delta
  448. Returns: loss
  449. """
  450. assert pred.shape == label.shape and label.numel() > 0
  451. assert delta > 0
  452. diff = paddle.abs(pred - label)
  453. loss = paddle.where(diff < delta, 0.5 * diff * diff / delta,
  454. diff - 0.5 * delta)
  455. return loss
  456. def get_fam_loss(self, fam_target, s2anet_head_out, reg_loss_type='gwd'):
  457. (labels, label_weights, bbox_targets, bbox_weights, bbox_gt_bboxes,
  458. pos_inds, neg_inds) = fam_target
  459. fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list = s2anet_head_out
  460. fam_cls_losses = []
  461. fam_bbox_losses = []
  462. st_idx = 0
  463. num_total_samples = len(pos_inds) + len(
  464. neg_inds) if self.sampling else len(pos_inds)
  465. num_total_samples = max(1, num_total_samples)
  466. for idx, feat_size in enumerate(self.featmap_sizes_list):
  467. feat_anchor_num = feat_size[0] * feat_size[1]
  468. # step1: get data
  469. feat_labels = labels[st_idx:st_idx + feat_anchor_num]
  470. feat_label_weights = label_weights[st_idx:st_idx + feat_anchor_num]
  471. feat_bbox_targets = bbox_targets[st_idx:st_idx + feat_anchor_num, :]
  472. feat_bbox_weights = bbox_weights[st_idx:st_idx + feat_anchor_num, :]
  473. # step2: calc cls loss
  474. feat_labels = feat_labels.reshape(-1)
  475. feat_label_weights = feat_label_weights.reshape(-1)
  476. fam_cls_score = fam_cls_branch_list[idx]
  477. fam_cls_score = paddle.squeeze(fam_cls_score, axis=0)
  478. fam_cls_score1 = fam_cls_score
  479. feat_labels = paddle.to_tensor(feat_labels)
  480. feat_labels_one_hot = paddle.nn.functional.one_hot(
  481. feat_labels, self.cls_out_channels + 1)
  482. feat_labels_one_hot = feat_labels_one_hot[:, 1:]
  483. feat_labels_one_hot.stop_gradient = True
  484. num_total_samples = paddle.to_tensor(
  485. num_total_samples, dtype='float32', stop_gradient=True)
  486. fam_cls = F.sigmoid_focal_loss(
  487. fam_cls_score1,
  488. feat_labels_one_hot,
  489. normalizer=num_total_samples,
  490. reduction='none')
  491. feat_label_weights = feat_label_weights.reshape(
  492. feat_label_weights.shape[0], 1)
  493. feat_label_weights = np.repeat(
  494. feat_label_weights, self.cls_out_channels, axis=1)
  495. feat_label_weights = paddle.to_tensor(
  496. feat_label_weights, stop_gradient=True)
  497. fam_cls = fam_cls * feat_label_weights
  498. fam_cls_total = paddle.sum(fam_cls)
  499. fam_cls_losses.append(fam_cls_total)
  500. # step3: regression loss
  501. feat_bbox_targets = paddle.to_tensor(
  502. feat_bbox_targets, dtype='float32', stop_gradient=True)
  503. feat_bbox_targets = paddle.reshape(feat_bbox_targets, [-1, 5])
  504. fam_bbox_pred = fam_reg_branch_list[idx]
  505. fam_bbox_pred = paddle.squeeze(fam_bbox_pred, axis=0)
  506. fam_bbox_pred = paddle.reshape(fam_bbox_pred, [-1, 5])
  507. fam_bbox = self.smooth_l1_loss(fam_bbox_pred, feat_bbox_targets)
  508. loss_weight = paddle.to_tensor(
  509. self.reg_loss_weight, dtype='float32', stop_gradient=True)
  510. fam_bbox = paddle.multiply(fam_bbox, loss_weight)
  511. feat_bbox_weights = paddle.to_tensor(
  512. feat_bbox_weights, stop_gradient=True)
  513. if reg_loss_type == 'l1':
  514. fam_bbox = fam_bbox * feat_bbox_weights
  515. fam_bbox_total = paddle.sum(fam_bbox) / num_total_samples
  516. elif reg_loss_type == 'iou' or reg_loss_type == 'gwd':
  517. fam_bbox = paddle.sum(fam_bbox, axis=-1)
  518. feat_bbox_weights = paddle.sum(feat_bbox_weights, axis=-1)
  519. try:
  520. from rbox_iou_ops import rbox_iou
  521. except Exception as e:
  522. print("import custom_ops error, try install rbox_iou_ops " \
  523. "following ppdet/ext_op/README.md", e)
  524. sys.stdout.flush()
  525. sys.exit(-1)
  526. # calc iou
  527. fam_bbox_decode = self.delta2rbox(self.base_anchors_list[idx],
  528. fam_bbox_pred)
  529. bbox_gt_bboxes = paddle.to_tensor(
  530. bbox_gt_bboxes,
  531. dtype=fam_bbox_decode.dtype,
  532. place=fam_bbox_decode.place)
  533. bbox_gt_bboxes.stop_gradient = True
  534. iou = rbox_iou(fam_bbox_decode, bbox_gt_bboxes)
  535. iou = paddle.diag(iou)
  536. if reg_loss_type == 'gwd':
  537. bbox_gt_bboxes_level = bbox_gt_bboxes[st_idx:st_idx +
  538. feat_anchor_num, :]
  539. fam_bbox_total = self.gwd_loss(fam_bbox_decode,
  540. bbox_gt_bboxes_level)
  541. fam_bbox_total = fam_bbox_total * feat_bbox_weights
  542. fam_bbox_total = paddle.sum(
  543. fam_bbox_total) / num_total_samples
  544. fam_bbox_losses.append(fam_bbox_total)
  545. st_idx += feat_anchor_num
  546. fam_cls_loss = paddle.add_n(fam_cls_losses)
  547. fam_cls_loss_weight = paddle.to_tensor(
  548. self.cls_loss_weight[0], dtype='float32', stop_gradient=True)
  549. fam_cls_loss = fam_cls_loss * fam_cls_loss_weight
  550. fam_reg_loss = paddle.add_n(fam_bbox_losses)
  551. return fam_cls_loss, fam_reg_loss
  552. def get_odm_loss(self, odm_target, s2anet_head_out, reg_loss_type='gwd'):
  553. (labels, label_weights, bbox_targets, bbox_weights, bbox_gt_bboxes,
  554. pos_inds, neg_inds) = odm_target
  555. fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list = s2anet_head_out
  556. odm_cls_losses = []
  557. odm_bbox_losses = []
  558. st_idx = 0
  559. num_total_samples = len(pos_inds) + len(
  560. neg_inds) if self.sampling else len(pos_inds)
  561. num_total_samples = max(1, num_total_samples)
  562. for idx, feat_size in enumerate(self.featmap_sizes_list):
  563. feat_anchor_num = feat_size[0] * feat_size[1]
  564. # step1: get data
  565. feat_labels = labels[st_idx:st_idx + feat_anchor_num]
  566. feat_label_weights = label_weights[st_idx:st_idx + feat_anchor_num]
  567. feat_bbox_targets = bbox_targets[st_idx:st_idx + feat_anchor_num, :]
  568. feat_bbox_weights = bbox_weights[st_idx:st_idx + feat_anchor_num, :]
  569. # step2: calc cls loss
  570. feat_labels = feat_labels.reshape(-1)
  571. feat_label_weights = feat_label_weights.reshape(-1)
  572. odm_cls_score = odm_cls_branch_list[idx]
  573. odm_cls_score = paddle.squeeze(odm_cls_score, axis=0)
  574. odm_cls_score1 = odm_cls_score
  575. feat_labels = paddle.to_tensor(feat_labels)
  576. feat_labels_one_hot = paddle.nn.functional.one_hot(
  577. feat_labels, self.cls_out_channels + 1)
  578. feat_labels_one_hot = feat_labels_one_hot[:, 1:]
  579. feat_labels_one_hot.stop_gradient = True
  580. num_total_samples = paddle.to_tensor(
  581. num_total_samples, dtype='float32', stop_gradient=True)
  582. odm_cls = F.sigmoid_focal_loss(
  583. odm_cls_score1,
  584. feat_labels_one_hot,
  585. normalizer=num_total_samples,
  586. reduction='none')
  587. feat_label_weights = feat_label_weights.reshape(
  588. feat_label_weights.shape[0], 1)
  589. feat_label_weights = np.repeat(
  590. feat_label_weights, self.cls_out_channels, axis=1)
  591. feat_label_weights = paddle.to_tensor(feat_label_weights)
  592. feat_label_weights.stop_gradient = True
  593. odm_cls = odm_cls * feat_label_weights
  594. odm_cls_total = paddle.sum(odm_cls)
  595. odm_cls_losses.append(odm_cls_total)
  596. # # step3: regression loss
  597. feat_bbox_targets = paddle.to_tensor(
  598. feat_bbox_targets, dtype='float32')
  599. feat_bbox_targets = paddle.reshape(feat_bbox_targets, [-1, 5])
  600. feat_bbox_targets.stop_gradient = True
  601. odm_bbox_pred = odm_reg_branch_list[idx]
  602. odm_bbox_pred = paddle.squeeze(odm_bbox_pred, axis=0)
  603. odm_bbox_pred = paddle.reshape(odm_bbox_pred, [-1, 5])
  604. odm_bbox = self.smooth_l1_loss(odm_bbox_pred, feat_bbox_targets)
  605. loss_weight = paddle.to_tensor(
  606. self.reg_loss_weight, dtype='float32', stop_gradient=True)
  607. odm_bbox = paddle.multiply(odm_bbox, loss_weight)
  608. feat_bbox_weights = paddle.to_tensor(
  609. feat_bbox_weights, stop_gradient=True)
  610. if reg_loss_type == 'l1':
  611. odm_bbox = odm_bbox * feat_bbox_weights
  612. odm_bbox_total = paddle.sum(odm_bbox) / num_total_samples
  613. elif reg_loss_type == 'iou' or reg_loss_type == 'gwd':
  614. odm_bbox = paddle.sum(odm_bbox, axis=-1)
  615. feat_bbox_weights = paddle.sum(feat_bbox_weights, axis=-1)
  616. try:
  617. from rbox_iou_ops import rbox_iou
  618. except Exception as e:
  619. print("import custom_ops error, try install rbox_iou_ops " \
  620. "following ppdet/ext_op/README.md", e)
  621. sys.stdout.flush()
  622. sys.exit(-1)
  623. # calc iou
  624. odm_bbox_decode = self.delta2rbox(self.refine_anchor_list[idx],
  625. odm_bbox_pred)
  626. bbox_gt_bboxes = paddle.to_tensor(
  627. bbox_gt_bboxes,
  628. dtype=odm_bbox_decode.dtype,
  629. place=odm_bbox_decode.place)
  630. bbox_gt_bboxes.stop_gradient = True
  631. iou = rbox_iou(odm_bbox_decode, bbox_gt_bboxes)
  632. iou = paddle.diag(iou)
  633. if reg_loss_type == 'gwd':
  634. bbox_gt_bboxes_level = bbox_gt_bboxes[st_idx:st_idx +
  635. feat_anchor_num, :]
  636. odm_bbox_total = self.gwd_loss(odm_bbox_decode,
  637. bbox_gt_bboxes_level)
  638. odm_bbox_total = odm_bbox_total * feat_bbox_weights
  639. odm_bbox_total = paddle.sum(
  640. odm_bbox_total) / num_total_samples
  641. odm_bbox_losses.append(odm_bbox_total)
  642. st_idx += feat_anchor_num
  643. odm_cls_loss = paddle.add_n(odm_cls_losses)
  644. odm_cls_loss_weight = paddle.to_tensor(
  645. self.cls_loss_weight[1], dtype='float32', stop_gradient=True)
  646. odm_cls_loss = odm_cls_loss * odm_cls_loss_weight
  647. odm_reg_loss = paddle.add_n(odm_bbox_losses)
  648. return odm_cls_loss, odm_reg_loss
  def get_loss(self, inputs):
      """Compute the FAM and ODM losses, summed over the images of a batch.

      Args:
          inputs (dict): batch data. This method reads 'im_shape' (only its
              leading dimension, used as the number of images), 'gt_rbox',
              'gt_class' and 'is_crowd'. The last three are indexed per
              image and converted with ``.numpy()``.

      Returns:
          dict: the summed losses under the keys 'fam_cls_loss',
              'fam_reg_loss', 'odm_cls_loss' and 'odm_reg_loss'.
      """
      fam_cls_loss_lst = []
      fam_reg_loss_lst = []
      odm_cls_loss_lst = []
      odm_reg_loss_lst = []

      num_images = inputs['im_shape'].shape[0]
      for im_id in range(num_images):
          # data_format: (xc, yc, w, h, theta)
          gt_bboxes = inputs['gt_rbox'][im_id].numpy()
          gt_labels = inputs['gt_class'][im_id].numpy()
          is_crowd = inputs['is_crowd'][im_id].numpy()
          # Labels are shifted up by one before assignment
          # (presumably to reserve 0 for background - TODO confirm).
          gt_labels = gt_labels + 1

          # All base anchors flattened into a single array.
          anchors_list_all = np.concatenate(self.base_anchors_list)

          # Slice this image out of each of the four head outputs:
          # (fam_cls, fam_reg, odm_cls, odm_reg), one entry per element.
          im_s2anet_head_out = tuple(
              [feat[im_id] for feat in self.s2anet_head_out[branch]]
              for branch in range(4))

          # FAM: assign targets against the base anchors.
          im_fam_target = self.anchor_assign(anchors_list_all, gt_bboxes,
                                             gt_labels, is_crowd)
          if im_fam_target is not None:
              im_fam_cls_loss, im_fam_reg_loss = self.get_fam_loss(
                  im_fam_target, im_s2anet_head_out, self.reg_loss_type)
              fam_cls_loss_lst.append(im_fam_cls_loss)
              fam_reg_loss_lst.append(im_fam_reg_loss)

          # ODM: assign targets against the refined anchors, flattened to
          # rows of five values.
          np_refine_anchors_list = paddle.concat(
              self.refine_anchor_list).numpy()
          np_refine_anchors_list = np.concatenate(np_refine_anchors_list)
          np_refine_anchors_list = np_refine_anchors_list.reshape(-1, 5)
          im_odm_target = self.anchor_assign(np_refine_anchors_list,
                                             gt_bboxes, gt_labels, is_crowd)
          if im_odm_target is not None:
              im_odm_cls_loss, im_odm_reg_loss = self.get_odm_loss(
                  im_odm_target, im_s2anet_head_out, self.reg_loss_type)
              odm_cls_loss_lst.append(im_odm_cls_loss)
              odm_reg_loss_lst.append(im_odm_reg_loss)

      # NOTE(review): if anchor_assign returned None for every image these
      # lists are empty - confirm that callers guarantee at least one
      # target, or that paddle.add_n tolerates an empty list.
      fam_cls_loss = paddle.add_n(fam_cls_loss_lst)
      fam_reg_loss = paddle.add_n(fam_reg_loss_lst)
      odm_cls_loss = paddle.add_n(odm_cls_loss_lst)
      odm_reg_loss = paddle.add_n(odm_reg_loss_lst)
      return {
          'fam_cls_loss': fam_cls_loss,
          'fam_reg_loss': fam_reg_loss,
          'odm_cls_loss': odm_cls_loss,
          'odm_reg_loss': odm_reg_loss
      }
  def get_bboxes(self, cls_score_list, bbox_pred_list, mlvl_anchors, nms_pre,
                 cls_out_channels, use_sigmoid_cls):
      """Turn per-level predictions into class scores and decoded boxes.

      Args:
          cls_score_list: per-level classification outputs; each one is
              reshaped to rows of ``cls_out_channels`` values.
          bbox_pred_list: per-level box deltas; each one is transposed with
              ``[1, 2, 0]`` and reshaped to rows of five values.
          mlvl_anchors: per-level anchors, reshaped to rows of five values.
          nms_pre: a level with more rows than this keeps only its
              ``nms_pre`` best-scoring rows.
          cls_out_channels: number of score columns per row.
          use_sigmoid_cls: apply sigmoid to the scores when true, otherwise
              softmax over the last axis.

      Returns:
          tuple: ``(mlvl_scores, mlvl_bboxes)``, the per-level scores and
              decoded boxes concatenated over all levels.
      """
      assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors)

      mlvl_bboxes = []
      mlvl_scores = []
      for cls_score, bbox_pred, anchors in zip(cls_score_list, bbox_pred_list,
                                               mlvl_anchors):
          cls_score = paddle.reshape(cls_score, [-1, cls_out_channels])
          if use_sigmoid_cls:
              scores = F.sigmoid(cls_score)
          else:
              scores = F.softmax(cls_score, axis=-1)

          bbox_pred = paddle.transpose(bbox_pred, [1, 2, 0])
          bbox_pred = paddle.reshape(bbox_pred, [-1, 5])
          anchors = paddle.reshape(anchors, [-1, 5])

          if scores.shape[0] > nms_pre:
              # Rank rows by their best class score. With softmax, column 0
              # is left out of the ranking.
              if use_sigmoid_cls:
                  max_scores = paddle.max(scores, axis=1)
              else:
                  max_scores = paddle.max(scores[:, 1:], axis=1)
              # Only the indices are needed; the top-k values are discarded.
              _, topk_inds = paddle.topk(max_scores, nms_pre)
              anchors = paddle.gather(anchors, topk_inds)
              bbox_pred = paddle.gather(bbox_pred, topk_inds)
              scores = paddle.gather(scores, topk_inds)

          # bbox_pred already has rows of five values here, and delta2rbox
          # reshapes its inputs again, so it is passed through directly.
          mlvl_bboxes.append(self.delta2rbox(anchors, bbox_pred))
          mlvl_scores.append(scores)

      mlvl_bboxes = paddle.concat(mlvl_bboxes, axis=0)
      mlvl_scores = paddle.concat(mlvl_scores)
      return mlvl_scores, mlvl_bboxes
  def rect2rbox(self, bboxes):
      """Convert corner-format boxes to (centre, long side, short side, angle).

      :param bboxes: shape (n, 4) (xmin, ymin, xmax, ymax)
      :return: dbboxes: shape (n, 5) (x_ctr, y_ctr, w, h, angle)
      """
      bboxes = paddle.reshape(bboxes, [-1, 4])

      x_ctr = (bboxes[:, 2] + bboxes[:, 0]) / 2.0
      y_ctr = (bboxes[:, 3] + bboxes[:, 1]) / 2.0
      edges1 = paddle.abs(bboxes[:, 2] - bboxes[:, 0])
      edges2 = paddle.abs(bboxes[:, 3] - bboxes[:, 1])

      # w is always the longer edge and h the shorter one.
      rbox_w = paddle.maximum(edges1, edges2)
      rbox_h = paddle.minimum(edges1, edges2)

      # set angle: pi/2 where the x-extent is the shorter edge, 0 elsewhere
      inds = edges1 < edges2
      inds = paddle.cast(inds, 'int32')
      rboxes_angle = inds * np.pi / 2.0

      rboxes = paddle.stack(
          (x_ctr, y_ctr, rbox_w, rbox_h, rboxes_angle), axis=1)
      return rboxes
  759. # deltas to rbox
  def delta2rbox(self, rrois, deltas, wh_ratio_clip=1e-6):
      """Apply regression deltas to rotated boxes.

      The deltas are first de-normalised with ``self.stds`` and
      ``self.means``, then applied to the boxes.

      :param rrois: (cx, cy, w, h, theta)
      :param deltas: (dx, dy, dw, dh, dtheta)
      :param wh_ratio_clip: clip threshold of wh_ratio; dw and dh are
          clipped to +/- abs(log(wh_ratio_clip))
      :return: decoded boxes, (cx, cy, w, h, theta) stacked on the last axis
      """
      rrois = paddle.reshape(rrois, [-1, 5])
      deltas = paddle.reshape(deltas, [-1, 5])

      # fix dy2st bug denorm_deltas = deltas * self.stds + self.means
      scaled = paddle.multiply(deltas, self.stds)
      denorm = paddle.add(scaled, self.means)

      log_bound = np.abs(np.log(wh_ratio_clip))
      shift_x = denorm[:, 0]
      shift_y = denorm[:, 1]
      log_w = paddle.clip(denorm[:, 2], min=-log_bound, max=log_bound)
      log_h = paddle.clip(denorm[:, 3], min=-log_bound, max=log_bound)
      shift_angle = denorm[:, 4]

      ctr_x, ctr_y = rrois[:, 0], rrois[:, 1]
      box_w, box_h = rrois[:, 2], rrois[:, 3]
      box_angle = rrois[:, 4]

      # The centre offset is expressed in the box's own rotated frame.
      cos_a = paddle.cos(box_angle)
      sin_a = paddle.sin(box_angle)
      gx = shift_x * box_w * cos_a - shift_y * box_h * sin_a + ctr_x
      gy = shift_x * box_w * sin_a + shift_y * box_h * cos_a + ctr_y

      gw = box_w * log_w.exp()
      gh = box_h * log_h.exp()

      # Wrap the angle into a pi-wide window that starts at -pi/4.
      ga = np.pi * shift_angle + box_angle
      ga = (ga + np.pi / 4) % np.pi - np.pi / 4

      # NOTE(review): to_tensor is applied to values that are already
      # tensors - presumably another dy2st workaround; confirm before
      # removing.
      ga = paddle.to_tensor(ga)
      gw = paddle.to_tensor(gw, dtype='float32')
      gh = paddle.to_tensor(gh, dtype='float32')
      return paddle.stack([gx, gy, gw, gh, ga], axis=-1)
  def bbox_decode(self, bbox_preds, anchors):
      """Decode boxes from predicted deltas.

      Args:
          bbox_preds: [N,H,W,5]
          anchors: [H*W,5]
      Returns:
          bboxes: whatever ``self.delta2rbox`` returns for the flattened
              deltas; the result is not reshaped back to [N,H,W,5] here.
      """
      # The four-way unpack doubles as a check that bbox_preds is 4-D.
      _num_imgs, _height, _width, _ = bbox_preds.shape
      flat_deltas = paddle.reshape(bbox_preds, [-1, 5])
      return self.delta2rbox(anchors, flat_deltas)
  def trace(self, A):
      """Return the sum of the diagonal taken over the last two axes of A."""
      return paddle.sum(paddle.diagonal(A, axis1=-2, axis2=-1), axis=-1)
  def sqrt_newton_schulz_autograd(self, A, numIters):
      """Approximate a batched matrix square root by Newton-Schulz iteration.

      Args:
          A: batch of matrices; ``A.shape[0]`` is used as the batch size and
              ``A.shape[1]`` as the matrix dimension.
          numIters: number of iteration steps to run.

      Returns:
          The iterate rescaled by the square root of each matrix's norm and
          expanded to the shape of A.
      """
      batch, size = A.shape[0], A.shape[1]

      # Per-matrix Frobenius norm, shaped [batch, 1, 1] so that it can be
      # expanded against A.
      frob = paddle.sqrt(paddle.sum(paddle.sum(A * A, axis=1), axis=1))
      frob = frob.reshape([batch, 1, 1])
      Y = paddle.divide(A, paddle.expand_as(frob, A))

      def batched_eye():
          # One identity matrix per batch entry, with gradients enabled.
          unit = paddle.eye(size, size).reshape([1, size, size])
          tiled = paddle.concat([unit] * batch, axis=0)
          tiled.stop_gradient = False
          return tiled

      eye_batch = batched_eye()
      Z = batched_eye()

      for _ in range(numIters):
          T = 0.5 * (3.0 * eye_batch - Z.bmm(Y))
          Y = Y.bmm(T)
          Z = T.bmm(Z)

      # Undo the initial normalisation.
      sA = Y * paddle.sqrt(frob).reshape([batch, 1, 1])
      return paddle.expand_as(sA, A)
  def wasserstein_distance_sigma(self, sigma1, sigma2):
      """Return the covariance term of the Wasserstein distance.

      Computes ``trace(sigma1 @ sigma1 + sigma2 @ sigma2 - 2 * root)``, where
      ``root`` is the matrix square root of
      ``sigma1 @ (sigma2 @ sigma2) @ sigma1`` approximated with ten steps of
      ``self.sqrt_newton_schulz_autograd``.

      Args:
          sigma1: first batch of matrices.
          sigma2: second batch of matrices.

      Returns:
          The trace of the expression above, one value per matrix.
      """
      # ``self`` was missing from the signature although the body calls two
      # methods on it, so the method could not be invoked at all.
      cross = paddle.matmul(
          paddle.matmul(sigma1, paddle.matmul(sigma2, sigma2)), sigma1)
      root = self.sqrt_newton_schulz_autograd(cross, 10)
      wasserstein_distance_item2 = paddle.matmul(
          sigma1, sigma1) + paddle.matmul(sigma2, sigma2) - 2 * root
      return self.trace(wasserstein_distance_item2)
  def xywhr2xyrs(self, xywhr):
      """Split (x, y, w, h, r) boxes into centre, rotation and scale parts.

      Returns:
          tuple: ``(xy, R, S)``. ``xy`` holds the first two columns, ``R``
              is the 2x2 matrix [[cos r, -sin r], [sin r, cos r]] per box,
              and ``S`` is half of the diagonal matrix built from the
              clipped (w, h).
      """
      boxes = paddle.reshape(xywhr, [-1, 5])
      centers = boxes[:, :2]
      # Clip the sizes into [1e-7, 1e7].
      sizes = paddle.clip(boxes[:, 2:4], min=1e-7, max=1e7)
      angle = boxes[:, 4]

      cos_a, sin_a = paddle.cos(angle), paddle.sin(angle)
      rotation = paddle.stack((cos_a, -sin_a, sin_a, cos_a), axis=-1)
      rotation = rotation.reshape([-1, 2, 2])

      half_sizes = 0.5 * paddle.nn.functional.diag_embed(sizes)
      return centers, rotation, half_sizes
  def gwd_loss(self,
               pred,
               target,
               fun='log',
               tau=1.0,
               alpha=1.0,
               normalize=False):
      """Distance-based regression loss between predicted and target boxes.

      Both inputs are decomposed with ``self.xywhr2xyrs``. The loss combines
      the squared centre distance with a term built from the boxes' rotation
      and scale matrices (judging by the name, a Gaussian Wasserstein
      distance).

      Args:
          pred: predicted boxes, rows of (x, y, w, h, r).
          target: target boxes in the same layout.
          fun: when equal to 'log' the distance is passed through ``log1p``;
              any other value leaves it unchanged.
          tau: when >= 1.0 the result is ``1 - 1 / (tau + distance)``,
              otherwise the distance itself is returned.
          alpha: weight of the shape term; it enters squared.
          normalize: when true, divide the distance by a scale derived from
              the clipped widths and heights of both boxes.

      Returns:
          The per-box loss tensor.
      """
      xy_p, R_p, S_p = self.xywhr2xyrs(pred)
      xy_t, R_t, S_t = self.xywhr2xyrs(target)

      xy_distance = (xy_p - xy_t).square().sum(axis=-1)

      # Sigma = R * S^2 * R^T for each box.
      Sigma_p = R_p.matmul(S_p.square()).matmul(R_p.transpose([0, 2, 1]))
      Sigma_t = R_t.matmul(S_t.square()).matmul(R_t.transpose([0, 2, 1]))

      # Sum of the squared diagonal entries of S for both boxes.
      whr_distance = paddle.diagonal(
          S_p, axis1=-2, axis2=-1).square().sum(axis=-1)
      whr_distance = whr_distance + paddle.diagonal(
          S_t, axis1=-2, axis2=-1).square().sum(axis=-1)

      _t = Sigma_p.matmul(Sigma_t)
      _t_tr = paddle.diagonal(_t, axis1=-2, axis2=-1).sum(axis=-1)
      # Product of the diagonal entries of S_p and of S_t.
      _t_det_sqrt = paddle.diagonal(S_p, axis1=-2, axis2=-1).prod(axis=-1)
      _t_det_sqrt = _t_det_sqrt * paddle.diagonal(
          S_t, axis1=-2, axis2=-1).prod(axis=-1)
      # Clip at zero before the square root.
      whr_distance = whr_distance + (-2) * (
          (_t_tr + 2 * _t_det_sqrt).clip(0).sqrt())

      distance = (xy_distance + alpha * alpha * whr_distance).clip(0)

      if normalize:
          wh_p = pred[..., 2:4].clip(min=1e-7, max=1e7)
          wh_t = target[..., 2:4].clip(min=1e-7, max=1e7)
          # Paddle reductions take ``axis``; the previous ``dim=-1`` keyword
          # was inconsistent with every other reduction in this method.
          scale = ((wh_p.log() + wh_t.log()).sum(axis=-1) / 4).exp()
          distance = distance / scale

      if fun == 'log':
          distance = paddle.log1p(distance)

      if tau >= 1.0:
          return 1 - 1 / (tau + distance)

      return distance