# ppyoloe_head.py (16 KB)
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from ppdet.core.workspace import register
  18. from ..bbox_utils import batch_distance2bbox
  19. from ..losses import GIoULoss
  20. from ..initializer import bias_init_with_prob, constant_, normal_
  21. from ..assigners.utils import generate_anchors_for_grid_cell
  22. from ppdet.modeling.backbones.cspresnet import ConvBNLayer
  23. from ppdet.modeling.ops import get_static_shape, paddle_distributed_is_initialized, get_act_fn
  24. from ppdet.modeling.layers import MultiClassNMS
  25. __all__ = ['PPYOLOEHead']
  26. class ESEAttn(nn.Layer):
  27. def __init__(self, feat_channels, act='swish'):
  28. super(ESEAttn, self).__init__()
  29. self.fc = nn.Conv2D(feat_channels, feat_channels, 1)
  30. self.conv = ConvBNLayer(feat_channels, feat_channels, 1, act=act)
  31. self._init_weights()
  32. def _init_weights(self):
  33. normal_(self.fc.weight, std=0.001)
  34. def forward(self, feat, avg_feat):
  35. weight = F.sigmoid(self.fc(avg_feat))
  36. return self.conv(feat * weight)
  37. @register
  38. class PPYOLOEHead(nn.Layer):
  39. __shared__ = ['num_classes', 'eval_size', 'trt', 'exclude_nms']
  40. __inject__ = ['static_assigner', 'assigner', 'nms']
  41. def __init__(self,
  42. in_channels=[1024, 512, 256],
  43. num_classes=80,
  44. act='swish',
  45. fpn_strides=(32, 16, 8),
  46. grid_cell_scale=5.0,
  47. grid_cell_offset=0.5,
  48. reg_max=16,
  49. static_assigner_epoch=4,
  50. use_varifocal_loss=True,
  51. static_assigner='ATSSAssigner',
  52. assigner='TaskAlignedAssigner',
  53. nms='MultiClassNMS',
  54. eval_size=None,
  55. loss_weight={
  56. 'class': 1.0,
  57. 'iou': 2.5,
  58. 'dfl': 0.5,
  59. },
  60. trt=False,
  61. exclude_nms=False):
  62. super(PPYOLOEHead, self).__init__()
  63. assert len(in_channels) > 0, "len(in_channels) should > 0"
  64. self.in_channels = in_channels
  65. self.num_classes = num_classes
  66. self.fpn_strides = fpn_strides
  67. self.grid_cell_scale = grid_cell_scale
  68. self.grid_cell_offset = grid_cell_offset
  69. self.reg_max = reg_max
  70. self.iou_loss = GIoULoss()
  71. self.loss_weight = loss_weight
  72. self.use_varifocal_loss = use_varifocal_loss
  73. self.eval_size = eval_size
  74. self.static_assigner_epoch = static_assigner_epoch
  75. self.static_assigner = static_assigner
  76. self.assigner = assigner
  77. self.nms = nms
  78. if isinstance(self.nms, MultiClassNMS) and trt:
  79. self.nms.trt = trt
  80. self.exclude_nms = exclude_nms
  81. # stem
  82. self.stem_cls = nn.LayerList()
  83. self.stem_reg = nn.LayerList()
  84. act = get_act_fn(
  85. act, trt=trt) if act is None or isinstance(act,
  86. (str, dict)) else act
  87. for in_c in self.in_channels:
  88. self.stem_cls.append(ESEAttn(in_c, act=act))
  89. self.stem_reg.append(ESEAttn(in_c, act=act))
  90. # pred head
  91. self.pred_cls = nn.LayerList()
  92. self.pred_reg = nn.LayerList()
  93. for in_c in self.in_channels:
  94. self.pred_cls.append(
  95. nn.Conv2D(
  96. in_c, self.num_classes, 3, padding=1))
  97. self.pred_reg.append(
  98. nn.Conv2D(
  99. in_c, 4 * (self.reg_max + 1), 3, padding=1))
  100. # projection conv
  101. self.proj_conv = nn.Conv2D(self.reg_max + 1, 1, 1, bias_attr=False)
  102. self._init_weights()
  103. @classmethod
  104. def from_config(cls, cfg, input_shape):
  105. return {'in_channels': [i.channels for i in input_shape], }
  106. def _init_weights(self):
  107. bias_cls = bias_init_with_prob(0.01)
  108. for cls_, reg_ in zip(self.pred_cls, self.pred_reg):
  109. constant_(cls_.weight)
  110. constant_(cls_.bias, bias_cls)
  111. constant_(reg_.weight)
  112. constant_(reg_.bias, 1.0)
  113. self.proj = paddle.linspace(0, self.reg_max, self.reg_max + 1)
  114. self.proj_conv.weight.set_value(
  115. self.proj.reshape([1, self.reg_max + 1, 1, 1]))
  116. self.proj_conv.weight.stop_gradient = True
  117. if self.eval_size:
  118. anchor_points, stride_tensor = self._generate_anchors()
  119. self.anchor_points = anchor_points
  120. self.stride_tensor = stride_tensor
  121. def forward_train(self, feats, targets):
  122. anchors, anchor_points, num_anchors_list, stride_tensor = \
  123. generate_anchors_for_grid_cell(
  124. feats, self.fpn_strides, self.grid_cell_scale,
  125. self.grid_cell_offset)
  126. cls_score_list, reg_distri_list = [], []
  127. for i, feat in enumerate(feats):
  128. avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
  129. cls_logit = self.pred_cls[i](self.stem_cls[i](feat, avg_feat) +
  130. feat)
  131. reg_distri = self.pred_reg[i](self.stem_reg[i](feat, avg_feat))
  132. # cls and reg
  133. cls_score = F.sigmoid(cls_logit)
  134. cls_score_list.append(cls_score.flatten(2).transpose([0, 2, 1]))
  135. reg_distri_list.append(reg_distri.flatten(2).transpose([0, 2, 1]))
  136. cls_score_list = paddle.concat(cls_score_list, axis=1)
  137. reg_distri_list = paddle.concat(reg_distri_list, axis=1)
  138. return self.get_loss([
  139. cls_score_list, reg_distri_list, anchors, anchor_points,
  140. num_anchors_list, stride_tensor
  141. ], targets)
  142. def _generate_anchors(self, feats=None):
  143. # just use in eval time
  144. anchor_points = []
  145. stride_tensor = []
  146. for i, stride in enumerate(self.fpn_strides):
  147. if feats is not None:
  148. _, _, h, w = feats[i].shape
  149. else:
  150. h = int(self.eval_size[0] / stride)
  151. w = int(self.eval_size[1] / stride)
  152. shift_x = paddle.arange(end=w) + self.grid_cell_offset
  153. shift_y = paddle.arange(end=h) + self.grid_cell_offset
  154. shift_y, shift_x = paddle.meshgrid(shift_y, shift_x)
  155. anchor_point = paddle.cast(
  156. paddle.stack(
  157. [shift_x, shift_y], axis=-1), dtype='float32')
  158. anchor_points.append(anchor_point.reshape([-1, 2]))
  159. stride_tensor.append(
  160. paddle.full(
  161. [h * w, 1], stride, dtype='float32'))
  162. anchor_points = paddle.concat(anchor_points)
  163. stride_tensor = paddle.concat(stride_tensor)
  164. return anchor_points, stride_tensor
  165. def forward_eval(self, feats):
  166. if self.eval_size:
  167. anchor_points, stride_tensor = self.anchor_points, self.stride_tensor
  168. else:
  169. anchor_points, stride_tensor = self._generate_anchors(feats)
  170. cls_score_list, reg_dist_list = [], []
  171. for i, feat in enumerate(feats):
  172. b, _, h, w = feat.shape
  173. l = h * w
  174. avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
  175. cls_logit = self.pred_cls[i](self.stem_cls[i](feat, avg_feat) +
  176. feat)
  177. reg_dist = self.pred_reg[i](self.stem_reg[i](feat, avg_feat))
  178. reg_dist = reg_dist.reshape([-1, 4, self.reg_max + 1, l]).transpose(
  179. [0, 2, 1, 3])
  180. reg_dist = self.proj_conv(F.softmax(reg_dist, axis=1))
  181. # cls and reg
  182. cls_score = F.sigmoid(cls_logit)
  183. cls_score_list.append(cls_score.reshape([b, self.num_classes, l]))
  184. reg_dist_list.append(reg_dist.reshape([b, 4, l]))
  185. cls_score_list = paddle.concat(cls_score_list, axis=-1)
  186. reg_dist_list = paddle.concat(reg_dist_list, axis=-1)
  187. return cls_score_list, reg_dist_list, anchor_points, stride_tensor
  188. def forward(self, feats, targets=None):
  189. assert len(feats) == len(self.fpn_strides), \
  190. "The size of feats is not equal to size of fpn_strides"
  191. if self.training:
  192. return self.forward_train(feats, targets)
  193. else:
  194. return self.forward_eval(feats)
  195. @staticmethod
  196. def _focal_loss(score, label, alpha=0.25, gamma=2.0):
  197. weight = (score - label).pow(gamma)
  198. if alpha > 0:
  199. alpha_t = alpha * label + (1 - alpha) * (1 - label)
  200. weight *= alpha_t
  201. loss = F.binary_cross_entropy(
  202. score, label, weight=weight, reduction='sum')
  203. return loss
  204. @staticmethod
  205. def _varifocal_loss(pred_score, gt_score, label, alpha=0.75, gamma=2.0):
  206. weight = alpha * pred_score.pow(gamma) * (1 - label) + gt_score * label
  207. loss = F.binary_cross_entropy(
  208. pred_score, gt_score, weight=weight, reduction='sum')
  209. return loss
  210. def _bbox_decode(self, anchor_points, pred_dist):
  211. b, l, _ = get_static_shape(pred_dist)
  212. pred_dist = F.softmax(pred_dist.reshape([b, l, 4, self.reg_max + 1
  213. ])).matmul(self.proj)
  214. return batch_distance2bbox(anchor_points, pred_dist)
  215. def _bbox2distance(self, points, bbox):
  216. x1y1, x2y2 = paddle.split(bbox, 2, -1)
  217. lt = points - x1y1
  218. rb = x2y2 - points
  219. return paddle.concat([lt, rb], -1).clip(0, self.reg_max - 0.01)
  220. def _df_loss(self, pred_dist, target):
  221. target_left = paddle.cast(target, 'int64')
  222. target_right = target_left + 1
  223. weight_left = target_right.astype('float32') - target
  224. weight_right = 1 - weight_left
  225. loss_left = F.cross_entropy(
  226. pred_dist, target_left, reduction='none') * weight_left
  227. loss_right = F.cross_entropy(
  228. pred_dist, target_right, reduction='none') * weight_right
  229. return (loss_left + loss_right).mean(-1, keepdim=True)
  230. def _bbox_loss(self, pred_dist, pred_bboxes, anchor_points, assigned_labels,
  231. assigned_bboxes, assigned_scores, assigned_scores_sum):
  232. # select positive samples mask
  233. mask_positive = (assigned_labels != self.num_classes)
  234. num_pos = mask_positive.sum()
  235. # pos/neg loss
  236. if num_pos > 0:
  237. # l1 + iou
  238. bbox_mask = mask_positive.unsqueeze(-1).tile([1, 1, 4])
  239. pred_bboxes_pos = paddle.masked_select(pred_bboxes,
  240. bbox_mask).reshape([-1, 4])
  241. assigned_bboxes_pos = paddle.masked_select(
  242. assigned_bboxes, bbox_mask).reshape([-1, 4])
  243. bbox_weight = paddle.masked_select(
  244. assigned_scores.sum(-1), mask_positive).unsqueeze(-1)
  245. loss_l1 = F.l1_loss(pred_bboxes_pos, assigned_bboxes_pos)
  246. loss_iou = self.iou_loss(pred_bboxes_pos,
  247. assigned_bboxes_pos) * bbox_weight
  248. loss_iou = loss_iou.sum() / assigned_scores_sum
  249. dist_mask = mask_positive.unsqueeze(-1).tile(
  250. [1, 1, (self.reg_max + 1) * 4])
  251. pred_dist_pos = paddle.masked_select(
  252. pred_dist, dist_mask).reshape([-1, 4, self.reg_max + 1])
  253. assigned_ltrb = self._bbox2distance(anchor_points, assigned_bboxes)
  254. assigned_ltrb_pos = paddle.masked_select(
  255. assigned_ltrb, bbox_mask).reshape([-1, 4])
  256. loss_dfl = self._df_loss(pred_dist_pos,
  257. assigned_ltrb_pos) * bbox_weight
  258. loss_dfl = loss_dfl.sum() / assigned_scores_sum
  259. else:
  260. loss_l1 = paddle.zeros([1])
  261. loss_iou = paddle.zeros([1])
  262. loss_dfl = pred_dist.sum() * 0.
  263. return loss_l1, loss_iou, loss_dfl
  264. def get_loss(self, head_outs, gt_meta):
  265. pred_scores, pred_distri, anchors,\
  266. anchor_points, num_anchors_list, stride_tensor = head_outs
  267. anchor_points_s = anchor_points / stride_tensor
  268. pred_bboxes = self._bbox_decode(anchor_points_s, pred_distri)
  269. gt_labels = gt_meta['gt_class']
  270. gt_bboxes = gt_meta['gt_bbox']
  271. pad_gt_mask = gt_meta['pad_gt_mask']
  272. # label assignment
  273. if gt_meta['epoch_id'] < self.static_assigner_epoch:
  274. assigned_labels, assigned_bboxes, assigned_scores = \
  275. self.static_assigner(
  276. anchors,
  277. num_anchors_list,
  278. gt_labels,
  279. gt_bboxes,
  280. pad_gt_mask,
  281. bg_index=self.num_classes,
  282. pred_bboxes=pred_bboxes.detach() * stride_tensor)
  283. alpha_l = 0.25
  284. else:
  285. assigned_labels, assigned_bboxes, assigned_scores = \
  286. self.assigner(
  287. pred_scores.detach(),
  288. pred_bboxes.detach() * stride_tensor,
  289. anchor_points,
  290. num_anchors_list,
  291. gt_labels,
  292. gt_bboxes,
  293. pad_gt_mask,
  294. bg_index=self.num_classes)
  295. alpha_l = -1
  296. # rescale bbox
  297. assigned_bboxes /= stride_tensor
  298. # cls loss
  299. if self.use_varifocal_loss:
  300. one_hot_label = F.one_hot(assigned_labels,
  301. self.num_classes + 1)[..., :-1]
  302. loss_cls = self._varifocal_loss(pred_scores, assigned_scores,
  303. one_hot_label)
  304. else:
  305. loss_cls = self._focal_loss(pred_scores, assigned_scores, alpha_l)
  306. assigned_scores_sum = assigned_scores.sum()
  307. if paddle_distributed_is_initialized():
  308. paddle.distributed.all_reduce(assigned_scores_sum)
  309. assigned_scores_sum = paddle.clip(
  310. assigned_scores_sum / paddle.distributed.get_world_size(),
  311. min=1)
  312. loss_cls /= assigned_scores_sum
  313. loss_l1, loss_iou, loss_dfl = \
  314. self._bbox_loss(pred_distri, pred_bboxes, anchor_points_s,
  315. assigned_labels, assigned_bboxes, assigned_scores,
  316. assigned_scores_sum)
  317. loss = self.loss_weight['class'] * loss_cls + \
  318. self.loss_weight['iou'] * loss_iou + \
  319. self.loss_weight['dfl'] * loss_dfl
  320. out_dict = {
  321. 'loss': loss,
  322. 'loss_cls': loss_cls,
  323. 'loss_iou': loss_iou,
  324. 'loss_dfl': loss_dfl,
  325. 'loss_l1': loss_l1,
  326. }
  327. return out_dict
  328. def post_process(self, head_outs, img_shape, scale_factor):
  329. pred_scores, pred_dist, anchor_points, stride_tensor = head_outs
  330. pred_bboxes = batch_distance2bbox(anchor_points,
  331. pred_dist.transpose([0, 2, 1]))
  332. pred_bboxes *= stride_tensor
  333. # scale bbox to origin
  334. scale_y, scale_x = paddle.split(scale_factor, 2, axis=-1)
  335. scale_factor = paddle.concat(
  336. [scale_x, scale_y, scale_x, scale_y], axis=-1).reshape([-1, 1, 4])
  337. pred_bboxes /= scale_factor
  338. if self.exclude_nms:
  339. # `exclude_nms=True` just use in benchmark
  340. return pred_bboxes.sum(), pred_scores.sum()
  341. else:
  342. bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
  343. return bbox_pred, bbox_num