# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import Normal, Constant

from ppdet.modeling.layers import ConvNormLayer, MaskMatrixNMS, DropBlock
from ppdet.core.workspace import register

from six.moves import zip
import numpy as np

__all__ = ['SOLOv2Head']


@register
class SOLOv2MaskHead(nn.Layer):
    """
    MaskHead of SOLOv2.
    The code of this function is based on:
        https://github.com/WXinlong/SOLO/blob/master/mmdet/models/mask_heads/mask_feat_head.py

    Args:
        in_channels (int): The channel number of the input Tensor.
        mid_channels (int): The channel number of the intermediate Tensors.
        out_channels (int): The channel number of the output Tensor.
        start_level (int): The position where the input starts.
        end_level (int): The position where the input ends.
        use_dcn_in_tower (bool): Whether to use DCN in the tower or not.
        norm_type (str): Normalization type, 'gn' by default.
    """
    __shared__ = ['norm_type']

    def __init__(self,
                 in_channels=256,
                 mid_channels=128,
                 out_channels=256,
                 start_level=0,
                 end_level=3,
                 use_dcn_in_tower=False,
                 norm_type='gn'):
        super(SOLOv2MaskHead, self).__init__()
        assert start_level >= 0 and end_level >= start_level
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.use_dcn_in_tower = use_dcn_in_tower
        self.range_level = end_level - start_level + 1
        self.use_dcn = True if self.use_dcn_in_tower else False
        self.convs_all_levels = []
        self.norm_type = norm_type
        for i in range(start_level, end_level + 1):
            conv_feat_name = 'mask_feat_head.convs_all_levels.{}'.format(i)
            conv_pre_feat = nn.Sequential()
            if i == start_level:
                conv_pre_feat.add_sublayer(
                    conv_feat_name + '.conv' + str(i),
                    ConvNormLayer(
                        ch_in=self.in_channels,
                        ch_out=self.mid_channels,
                        filter_size=3,
                        stride=1,
                        use_dcn=self.use_dcn,
                        norm_type=self.norm_type))
                self.add_sublayer('conv_pre_feat' + str(i), conv_pre_feat)
                self.convs_all_levels.append(conv_pre_feat)
            else:
                for j in range(i):
                    ch_in = 0
                    if j == 0:
                        ch_in = self.in_channels + 2 if i == end_level else self.in_channels
                    else:
                        ch_in = self.mid_channels
                    conv_pre_feat.add_sublayer(
                        conv_feat_name + '.conv' + str(j),
                        ConvNormLayer(
                            ch_in=ch_in,
                            ch_out=self.mid_channels,
                            filter_size=3,
                            stride=1,
                            use_dcn=self.use_dcn,
                            norm_type=self.norm_type))
                    conv_pre_feat.add_sublayer(
                        conv_feat_name + '.conv' + str(j) + 'act', nn.ReLU())
                    conv_pre_feat.add_sublayer(
                        'upsample' + str(i) + str(j),
                        nn.Upsample(
                            scale_factor=2, mode='bilinear'))
                self.add_sublayer('conv_pre_feat' + str(i), conv_pre_feat)
                self.convs_all_levels.append(conv_pre_feat)

        conv_pred_name = 'mask_feat_head.conv_pred.0'
        self.conv_pred = self.add_sublayer(
            conv_pred_name,
            ConvNormLayer(
                ch_in=self.mid_channels,
                ch_out=self.out_channels,
                filter_size=1,
                stride=1,
                use_dcn=self.use_dcn,
                norm_type=self.norm_type))

    def forward(self, inputs):
        """
        Get SOLOv2MaskHead output.

        Args:
            inputs (list[Tensor]): feature maps from each neck level, each with shape [N, C, H, W].
        Returns:
            ins_pred (Tensor): output mask feature of SOLOv2MaskHead.
        """
        feat_all_level = F.relu(self.convs_all_levels[0](inputs[0]))
        for i in range(1, self.range_level):
            input_p = inputs[i]
            if i == (self.range_level - 1):
                input_feat = input_p
                x_range = paddle.linspace(
                    -1, 1, paddle.shape(input_feat)[-1], dtype='float32')
                y_range = paddle.linspace(
                    -1, 1, paddle.shape(input_feat)[-2], dtype='float32')
                y, x = paddle.meshgrid([y_range, x_range])
                x = paddle.unsqueeze(x, [0, 1])
                y = paddle.unsqueeze(y, [0, 1])
                y = paddle.expand(
                    y, shape=[paddle.shape(input_feat)[0], 1, -1, -1])
                x = paddle.expand(
                    x, shape=[paddle.shape(input_feat)[0], 1, -1, -1])
                coord_feat = paddle.concat([x, y], axis=1)
                input_p = paddle.concat([input_p, coord_feat], axis=1)
            feat_all_level = paddle.add(feat_all_level,
                                        self.convs_all_levels[i](input_p))
        ins_pred = F.relu(self.conv_pred(feat_all_level))

        return ins_pred
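
# A minimal usage sketch (not part of the original file): feeding four dummy
# FPN-style feature maps through SOLOv2MaskHead. The shapes and the
# 1/4-resolution base level are illustrative assumptions only.
#
#   import paddle
#   mask_head = SOLOv2MaskHead(in_channels=256, mid_channels=128,
#                              out_channels=256, start_level=0, end_level=3)
#   feats = [paddle.rand([1, 256, 200, 304]),   # stride 4
#            paddle.rand([1, 256, 100, 152]),   # stride 8
#            paddle.rand([1, 256, 50, 76]),     # stride 16
#            paddle.rand([1, 256, 25, 38])]     # stride 32
#   ins_feat = mask_head(feats)                 # [1, 256, 200, 304]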


@register
class SOLOv2Head(nn.Layer):
    """
    Head block for SOLOv2 network

    Args:
        num_classes (int): Number of output classes.
        in_channels (int): Number of input channels.
        seg_feat_channels (int): Number of filters in the kernel & category branch convolutions.
        stacked_convs (int): Number of stacked convolution operations.
        num_grids (list[int]): List of feature map grid sizes.
        kernel_out_channels (int): Number of output channels in the kernel branch.
        dcn_v2_stages (list): Which stages use DCNv2 in the tower. Values are in [0, stacked_convs).
        segm_strides (list[int]): List of segmentation area strides.
        solov2_loss (object): SOLOv2Loss instance.
        score_threshold (float): Threshold of category score.
        mask_threshold (float): Threshold used to binarize predicted masks.
        mask_nms (object): MaskMatrixNMS instance.
        norm_type (str): Normalization type, 'gn' by default.
        drop_block (bool): Whether to use DropBlock in the towers during training.
    """
    __inject__ = ['solov2_loss', 'mask_nms']
    __shared__ = ['norm_type', 'num_classes']

    def __init__(self,
                 num_classes=80,
                 in_channels=256,
                 seg_feat_channels=256,
                 stacked_convs=4,
                 num_grids=[40, 36, 24, 16, 12],
                 kernel_out_channels=256,
                 dcn_v2_stages=[],
                 segm_strides=[8, 8, 16, 32, 32],
                 solov2_loss=None,
                 score_threshold=0.1,
                 mask_threshold=0.5,
                 mask_nms=None,
                 norm_type='gn',
                 drop_block=False):
        super(SOLOv2Head, self).__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.seg_num_grids = num_grids
        self.cate_out_channels = self.num_classes
        self.seg_feat_channels = seg_feat_channels
        self.stacked_convs = stacked_convs
        self.kernel_out_channels = kernel_out_channels
        self.dcn_v2_stages = dcn_v2_stages
        self.segm_strides = segm_strides
        self.solov2_loss = solov2_loss
        self.mask_nms = mask_nms
        self.score_threshold = score_threshold
        self.mask_threshold = mask_threshold
        self.norm_type = norm_type
        self.drop_block = drop_block

        self.kernel_pred_convs = []
        self.cate_pred_convs = []
        for i in range(self.stacked_convs):
            use_dcn = True if i in self.dcn_v2_stages else False
            ch_in = self.in_channels + 2 if i == 0 else self.seg_feat_channels
            kernel_conv = self.add_sublayer(
                'bbox_head.kernel_convs.' + str(i),
                ConvNormLayer(
                    ch_in=ch_in,
                    ch_out=self.seg_feat_channels,
                    filter_size=3,
                    stride=1,
                    use_dcn=use_dcn,
                    norm_type=self.norm_type))
            self.kernel_pred_convs.append(kernel_conv)
            ch_in = self.in_channels if i == 0 else self.seg_feat_channels
            cate_conv = self.add_sublayer(
                'bbox_head.cate_convs.' + str(i),
                ConvNormLayer(
                    ch_in=ch_in,
                    ch_out=self.seg_feat_channels,
                    filter_size=3,
                    stride=1,
                    use_dcn=use_dcn,
                    norm_type=self.norm_type))
            self.cate_pred_convs.append(cate_conv)

        self.solo_kernel = self.add_sublayer(
            'bbox_head.solo_kernel',
            nn.Conv2D(
                self.seg_feat_channels,
                self.kernel_out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                weight_attr=ParamAttr(initializer=Normal(
                    mean=0., std=0.01)),
                bias_attr=True))
        self.solo_cate = self.add_sublayer(
            'bbox_head.solo_cate',
            nn.Conv2D(
                self.seg_feat_channels,
                self.cate_out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                weight_attr=ParamAttr(initializer=Normal(
                    mean=0., std=0.01)),
                bias_attr=ParamAttr(initializer=Constant(
                    value=float(-np.log((1 - 0.01) / 0.01))))))

        if self.drop_block and self.training:
            self.drop_block_fun = DropBlock(
                block_size=3, keep_prob=0.9, name='solo_cate.dropblock')
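
    # Suppress non-peak grid cells on the (sigmoid) category heatmap: a 2x2 max
    # pool marks local maxima and everything else is zeroed, which replaces NMS
    # over grid points at inference time.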
    def _points_nms(self, heat, kernel_size=2):
        hmax = F.max_pool2d(heat, kernel_size=kernel_size, stride=1, padding=1)
        keep = paddle.cast((hmax[:, :, :-1, :-1] == heat), 'float32')
        return heat * keep
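
    # Resize the FPN outputs to the five head inputs: the highest-resolution
    # level is downsampled by half and the lowest-resolution level is resized to
    # match the level above it, so each head level roughly matches its entry in
    # segm_strides.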
    def _split_feats(self, feats):
        return (F.interpolate(
            feats[0],
            scale_factor=0.5,
            align_corners=False,
            align_mode=0,
            mode='bilinear'), feats[1], feats[2], feats[3], F.interpolate(
                feats[4],
                size=paddle.shape(feats[3])[-2:],
                mode='bilinear',
                align_corners=False,
                align_mode=0))

    def forward(self, input):
        """
        Get SOLOv2 head output

        Args:
            input (list): List of Tensors, output of backbone or neck stages
        Returns:
            cate_pred_list (list): Tensors of each category branch layer
            kernel_pred_list (list): Tensors of each kernel branch layer
        """
        feats = self._split_feats(input)
        cate_pred_list = []
        kernel_pred_list = []
        for idx in range(len(self.seg_num_grids)):
            cate_pred, kernel_pred = self._get_output_single(feats[idx], idx)
            cate_pred_list.append(cate_pred)
            kernel_pred_list.append(kernel_pred)

        return cate_pred_list, kernel_pred_list

    def _get_output_single(self, input, idx):
        ins_kernel_feat = input
        # CoordConv
        x_range = paddle.linspace(
            -1, 1, paddle.shape(ins_kernel_feat)[-1], dtype='float32')
        y_range = paddle.linspace(
            -1, 1, paddle.shape(ins_kernel_feat)[-2], dtype='float32')
        y, x = paddle.meshgrid([y_range, x_range])
        x = paddle.unsqueeze(x, [0, 1])
        y = paddle.unsqueeze(y, [0, 1])
        y = paddle.expand(
            y, shape=[paddle.shape(ins_kernel_feat)[0], 1, -1, -1])
        x = paddle.expand(
            x, shape=[paddle.shape(ins_kernel_feat)[0], 1, -1, -1])
        coord_feat = paddle.concat([x, y], axis=1)
        ins_kernel_feat = paddle.concat([ins_kernel_feat, coord_feat], axis=1)

        # kernel branch
        kernel_feat = ins_kernel_feat
        seg_num_grid = self.seg_num_grids[idx]
        kernel_feat = F.interpolate(
            kernel_feat,
            size=[seg_num_grid, seg_num_grid],
            mode='bilinear',
            align_corners=False,
            align_mode=0)
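        # The category branch works on the resized features without the two
        # appended coordinate channels; only the kernel branch keeps them.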
        cate_feat = kernel_feat[:, :-2, :, :]

        for kernel_layer in self.kernel_pred_convs:
            kernel_feat = F.relu(kernel_layer(kernel_feat))
        if self.drop_block and self.training:
            kernel_feat = self.drop_block_fun(kernel_feat)
        kernel_pred = self.solo_kernel(kernel_feat)
        # cate branch
        for cate_layer in self.cate_pred_convs:
            cate_feat = F.relu(cate_layer(cate_feat))
        if self.drop_block and self.training:
            cate_feat = self.drop_block_fun(cate_feat)
        cate_pred = self.solo_cate(cate_feat)

        if not self.training:
            cate_pred = self._points_nms(F.sigmoid(cate_pred), kernel_size=2)
            cate_pred = paddle.transpose(cate_pred, [0, 2, 3, 1])
        return cate_pred, kernel_pred

    def get_loss(self, cate_preds, kernel_preds, ins_pred, ins_labels,
                 cate_labels, grid_order_list, fg_num):
        """
        Get loss of network of SOLOv2.

        Args:
            cate_preds (list): Tensor list of category branch output.
            kernel_preds (list): Tensor list of kernel branch output.
            ins_pred (list): Tensor list of instance branch output.
            ins_labels (list): List of instance labels per batch.
            cate_labels (list): List of category labels per batch.
            grid_order_list (list): List of positive grid indices per level.
            fg_num (int): Number of positive samples in a mini-batch.
        Returns:
            loss_ins (Tensor): The instance loss Tensor of SOLOv2 network.
            loss_cate (Tensor): The category loss Tensor of SOLOv2 network.
        """
        batch_size = paddle.shape(grid_order_list[0])[0]
        ins_pred_list = []
        for kernel_preds_level, grid_orders_level in zip(kernel_preds,
                                                         grid_order_list):
            if grid_orders_level.shape[1] == 0:
                ins_pred_list.append(None)
                continue
            grid_orders_level = paddle.reshape(grid_orders_level, [-1])
            reshape_pred = paddle.reshape(
                kernel_preds_level,
                shape=(paddle.shape(kernel_preds_level)[0],
                       paddle.shape(kernel_preds_level)[1], -1))
            reshape_pred = paddle.transpose(reshape_pred, [0, 2, 1])
            reshape_pred = paddle.reshape(
                reshape_pred, shape=(-1, paddle.shape(reshape_pred)[2]))
            gathered_pred = paddle.gather(reshape_pred, index=grid_orders_level)
            gathered_pred = paddle.reshape(
                gathered_pred,
                shape=[batch_size, -1, paddle.shape(gathered_pred)[1]])
            cur_ins_pred = ins_pred
            cur_ins_pred = paddle.reshape(
                cur_ins_pred,
                shape=(paddle.shape(cur_ins_pred)[0],
                       paddle.shape(cur_ins_pred)[1], -1))
            ins_pred_conv = paddle.matmul(gathered_pred, cur_ins_pred)
            cur_ins_pred = paddle.reshape(
                ins_pred_conv,
                shape=(-1, paddle.shape(ins_pred)[-2],
                       paddle.shape(ins_pred)[-1]))
            ins_pred_list.append(cur_ins_pred)

        num_ins = paddle.sum(fg_num)
        cate_preds = [
            paddle.reshape(
                paddle.transpose(cate_pred, [0, 2, 3, 1]),
                shape=(-1, self.cate_out_channels)) for cate_pred in cate_preds
        ]
        flatten_cate_preds = paddle.concat(cate_preds)
        new_cate_labels = []
        for cate_label in cate_labels:
            new_cate_labels.append(paddle.reshape(cate_label, shape=[-1]))
        cate_labels = paddle.concat(new_cate_labels)

        loss_ins, loss_cate = self.solov2_loss(
            ins_pred_list, ins_labels, flatten_cate_preds, cate_labels, num_ins)

        return {'loss_ins': loss_ins, 'loss_cate': loss_cate}

    def get_prediction(self, cate_preds, kernel_preds, seg_pred, im_shape,
                       scale_factor):
        """
        Get prediction result of SOLOv2 network

        Args:
            cate_preds (list): List of Variables, output of category branch.
            kernel_preds (list): List of Variables, output of kernel branch.
            seg_pred (list): List of Variables, output of mask head stages.
            im_shape (Variables): [h, w] for input images.
            scale_factor (Variables): [scale, scale] for input images.
        Returns:
            seg_masks (Tensor): The predicted segmentation masks.
            cate_labels (Tensor): The predicted category label of each segmentation.
            cate_scores (Tensor): The predicted score of each segmentation.
            bbox_num (Tensor): The number of predicted instances.
        """
        num_levels = len(cate_preds)
        featmap_size = paddle.shape(seg_pred)[-2:]
        seg_masks_list = []
        cate_labels_list = []
        cate_scores_list = []
        cate_preds = [cate_pred * 1.0 for cate_pred in cate_preds]
        kernel_preds = [kernel_pred * 1.0 for kernel_pred in kernel_preds]
        # Currently only supports batch size == 1
        for idx in range(1):
            cate_pred_list = [
                paddle.reshape(
                    cate_preds[i][idx], shape=(-1, self.cate_out_channels))
                for i in range(num_levels)
            ]
            seg_pred_list = seg_pred
            kernel_pred_list = [
                paddle.reshape(
                    paddle.transpose(kernel_preds[i][idx], [1, 2, 0]),
                    shape=(-1, self.kernel_out_channels))
                for i in range(num_levels)
            ]
            cate_pred_list = paddle.concat(cate_pred_list, axis=0)
            kernel_pred_list = paddle.concat(kernel_pred_list, axis=0)

            seg_masks, cate_labels, cate_scores = self.get_seg_single(
                cate_pred_list, seg_pred_list, kernel_pred_list, featmap_size,
                im_shape[idx], scale_factor[idx][0])
        bbox_num = paddle.shape(cate_labels)[0]
        return seg_masks, cate_labels, cate_scores, bbox_num

    def get_seg_single(self, cate_preds, seg_preds, kernel_preds, featmap_size,
                       im_shape, scale_factor):
        """
        The code of this function is based on:
            https://github.com/WXinlong/SOLO/blob/master/mmdet/models/anchor_heads/solov2_head.py#L385
        """
        h = paddle.cast(im_shape[0], 'int32')[0]
        w = paddle.cast(im_shape[1], 'int32')[0]
        upsampled_size_out = [featmap_size[0] * 4, featmap_size[1] * 4]

        y = paddle.zeros(shape=paddle.shape(cate_preds), dtype='float32')
        inds = paddle.where(cate_preds > self.score_threshold, cate_preds, y)
        inds = paddle.nonzero(inds)
        cate_preds = paddle.reshape(cate_preds, shape=[-1])
        # Prevent empty and increase fake data
        ind_a = paddle.cast(paddle.shape(kernel_preds)[0], 'int64')
        ind_b = paddle.zeros(shape=[1], dtype='int64')
        inds_end = paddle.unsqueeze(paddle.concat([ind_a, ind_b]), 0)
        inds = paddle.concat([inds, inds_end])
        kernel_preds_end = paddle.ones(
            shape=[1, self.kernel_out_channels], dtype='float32')
        kernel_preds = paddle.concat([kernel_preds, kernel_preds_end])
        cate_preds = paddle.concat(
            [cate_preds, paddle.zeros(
                shape=[1], dtype='float32')])

        # cate_labels & kernel_preds
        cate_labels = inds[:, 1]
        kernel_preds = paddle.gather(kernel_preds, index=inds[:, 0])
        cate_score_idx = paddle.add(inds[:, 0] * self.cate_out_channels,
                                    cate_labels)
        cate_scores = paddle.gather(cate_preds, index=cate_score_idx)

        size_trans = np.power(self.seg_num_grids, 2)
        strides = []
        for _ind in range(len(self.segm_strides)):
            strides.append(
                paddle.full(
                    shape=[int(size_trans[_ind])],
                    fill_value=self.segm_strides[_ind],
                    dtype="int32"))
        strides = paddle.concat(strides)
        strides = paddle.concat(
            [strides, paddle.zeros(
                shape=[1], dtype='int32')])
        strides = paddle.gather(strides, index=inds[:, 0])

        # mask encoding.
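        # Each gathered kernel is reshaped to a [C, 1, 1] filter and applied to
        # the unified mask feature with F.conv2d, i.e. a dynamic 1x1 convolution
        # that produces one mask logit map per kept instance.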
        kernel_preds = paddle.unsqueeze(kernel_preds, [2, 3])
        seg_preds = F.conv2d(seg_preds, kernel_preds)
        seg_preds = F.sigmoid(paddle.squeeze(seg_preds, [0]))
        seg_masks = seg_preds > self.mask_threshold
        seg_masks = paddle.cast(seg_masks, 'float32')
        sum_masks = paddle.sum(seg_masks, axis=[1, 2])

        y = paddle.zeros(shape=paddle.shape(sum_masks), dtype='float32')
        keep = paddle.where(sum_masks > strides, sum_masks, y)
        keep = paddle.nonzero(keep)
        keep = paddle.squeeze(keep, axis=[1])
        # Prevent empty and increase fake data
        keep_other = paddle.concat(
            [keep, paddle.cast(paddle.shape(sum_masks)[0] - 1, 'int64')])
        keep_scores = paddle.concat(
            [keep, paddle.cast(paddle.shape(sum_masks)[0], 'int64')])
        cate_scores_end = paddle.zeros(shape=[1], dtype='float32')
        cate_scores = paddle.concat([cate_scores, cate_scores_end])

        seg_masks = paddle.gather(seg_masks, index=keep_other)
        seg_preds = paddle.gather(seg_preds, index=keep_other)
        sum_masks = paddle.gather(sum_masks, index=keep_other)
        cate_labels = paddle.gather(cate_labels, index=keep_other)
        cate_scores = paddle.gather(cate_scores, index=keep_scores)

        # mask scoring.
        seg_mul = paddle.cast(seg_preds * seg_masks, 'float32')
        seg_scores = paddle.sum(seg_mul, axis=[1, 2]) / sum_masks
        cate_scores *= seg_scores
        # Matrix NMS
        seg_preds, cate_scores, cate_labels = self.mask_nms(
            seg_preds, seg_masks, cate_labels, cate_scores, sum_masks=sum_masks)
        ori_shape = im_shape[:2] / scale_factor + 0.5
        ori_shape = paddle.cast(ori_shape, 'int32')
        seg_preds = F.interpolate(
            paddle.unsqueeze(seg_preds, 0),
            size=upsampled_size_out,
            mode='bilinear',
            align_corners=False,
            align_mode=0)
        seg_preds = paddle.slice(
            seg_preds, axes=[2, 3], starts=[0, 0], ends=[h, w])
        seg_masks = paddle.squeeze(
            F.interpolate(
                seg_preds,
                size=ori_shape[:2],
                mode='bilinear',
                align_corners=False,
                align_mode=0),
            axis=[0])
        seg_masks = paddle.cast(seg_masks > self.mask_threshold, 'uint8')
        return seg_masks, cate_labels, cate_scores
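
# A minimal, commented-out inference sketch (not part of the original file):
# wiring SOLOv2Head to five dummy FPN levels. The feature shapes and the
# MaskMatrixNMS configuration are illustrative assumptions only; full
# prediction additionally needs the SOLOv2MaskHead output via get_prediction().
#
#   import paddle
#   head = SOLOv2Head(num_classes=80, in_channels=256,
#                     mask_nms=MaskMatrixNMS())
#   head.eval()
#   fpn_feats = [paddle.rand([1, 256, 200, 304]),  # stride 4
#                paddle.rand([1, 256, 100, 152]),  # stride 8
#                paddle.rand([1, 256, 50, 76]),    # stride 16
#                paddle.rand([1, 256, 25, 38]),    # stride 32
#                paddle.rand([1, 256, 13, 19])]    # stride 64
#   cate_preds, kernel_preds = head(fpn_feats)
#   # cate_preds[i]: [1, S_i, S_i, 80], kernel_preds[i]: [1, 256, S_i, S_i]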