yolo_head.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from paddle import ParamAttr
  18. from paddle.regularizer import L2Decay
  19. from ppdet.core.workspace import register
  20. import math
  21. import numpy as np
  22. from ..initializer import bias_init_with_prob, constant_
  23. from ..backbones.csp_darknet import BaseConv, DWConv
  24. from ..losses import IouLoss
  25. from ppdet.modeling.assigners.simota_assigner import SimOTAAssigner
  26. from ppdet.modeling.bbox_utils import bbox_overlaps
  27. from ppdet.modeling.layers import MultiClassNMS
  28. __all__ = ['YOLOv3Head', 'YOLOXHead']
  29. def _de_sigmoid(x, eps=1e-7):
  30. x = paddle.clip(x, eps, 1. / eps)
  31. x = paddle.clip(1. / x - 1., eps, 1. / eps)
  32. x = -paddle.log(x)
  33. return x
  34. @register
  35. class YOLOv3Head(nn.Layer):
  36. __shared__ = ['num_classes', 'data_format']
  37. __inject__ = ['loss']
  38. def __init__(self,
  39. in_channels=[1024, 512, 256],
  40. anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  41. [59, 119], [116, 90], [156, 198], [373, 326]],
  42. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  43. num_classes=80,
  44. loss='YOLOv3Loss',
  45. iou_aware=False,
  46. iou_aware_factor=0.4,
  47. data_format='NCHW'):
  48. """
  49. Head for YOLOv3 network
  50. Args:
  51. num_classes (int): number of foreground classes
  52. anchors (list): anchors
  53. anchor_masks (list): anchor masks
  54. loss (object): YOLOv3Loss instance
  55. iou_aware (bool): whether to use iou_aware
  56. iou_aware_factor (float): iou aware factor
  57. data_format (str): data format, NCHW or NHWC
  58. """
  59. super(YOLOv3Head, self).__init__()
  60. assert len(in_channels) > 0, "in_channels length should > 0"
  61. self.in_channels = in_channels
  62. self.num_classes = num_classes
  63. self.loss = loss
  64. self.iou_aware = iou_aware
  65. self.iou_aware_factor = iou_aware_factor
  66. self.parse_anchor(anchors, anchor_masks)
  67. self.num_outputs = len(self.anchors)
  68. self.data_format = data_format
  69. self.yolo_outputs = []
  70. for i in range(len(self.anchors)):
  71. if self.iou_aware:
  72. num_filters = len(self.anchors[i]) * (self.num_classes + 6)
  73. else:
  74. num_filters = len(self.anchors[i]) * (self.num_classes + 5)
  75. name = 'yolo_output.{}'.format(i)
  76. conv = nn.Conv2D(
  77. in_channels=self.in_channels[i],
  78. out_channels=num_filters,
  79. kernel_size=1,
  80. stride=1,
  81. padding=0,
  82. data_format=data_format,
  83. bias_attr=ParamAttr(regularizer=L2Decay(0.)))
  84. conv.skip_quant = True
  85. yolo_output = self.add_sublayer(name, conv)
  86. self.yolo_outputs.append(yolo_output)
  87. def parse_anchor(self, anchors, anchor_masks):
  88. self.anchors = [[anchors[i] for i in mask] for mask in anchor_masks]
  89. self.mask_anchors = []
  90. anchor_num = len(anchors)
  91. for masks in anchor_masks:
  92. self.mask_anchors.append([])
  93. for mask in masks:
  94. assert mask < anchor_num, "anchor mask index overflow"
  95. self.mask_anchors[-1].extend(anchors[mask])
  96. def forward(self, feats, targets=None):
  97. assert len(feats) == len(self.anchors)
  98. yolo_outputs = []
  99. for i, feat in enumerate(feats):
  100. yolo_output = self.yolo_outputs[i](feat)
  101. if self.data_format == 'NHWC':
  102. yolo_output = paddle.transpose(yolo_output, [0, 3, 1, 2])
  103. yolo_outputs.append(yolo_output)
  104. if self.training:
  105. return self.loss(yolo_outputs, targets, self.anchors)
  106. else:
  107. if self.iou_aware:
  108. y = []
  109. for i, out in enumerate(yolo_outputs):
  110. na = len(self.anchors[i])
  111. ioup, x = out[:, 0:na, :, :], out[:, na:, :, :]
  112. b, c, h, w = x.shape
  113. no = c // na
  114. x = x.reshape((b, na, no, h * w))
  115. ioup = ioup.reshape((b, na, 1, h * w))
  116. obj = x[:, :, 4:5, :]
  117. ioup = F.sigmoid(ioup)
  118. obj = F.sigmoid(obj)
  119. obj_t = (obj**(1 - self.iou_aware_factor)) * (
  120. ioup**self.iou_aware_factor)
  121. obj_t = _de_sigmoid(obj_t)
  122. loc_t = x[:, :, :4, :]
  123. cls_t = x[:, :, 5:, :]
  124. y_t = paddle.concat([loc_t, obj_t, cls_t], axis=2)
  125. y_t = y_t.reshape((b, c, h, w))
  126. y.append(y_t)
  127. return y
  128. else:
  129. return yolo_outputs
  130. @classmethod
  131. def from_config(cls, cfg, input_shape):
  132. return {'in_channels': [i.channels for i in input_shape], }
  133. @register
  134. class YOLOXHead(nn.Layer):
  135. __shared__ = ['num_classes', 'width_mult', 'act', 'trt', 'exclude_nms']
  136. __inject__ = ['assigner', 'nms']
  137. def __init__(self,
  138. num_classes=80,
  139. width_mult=1.0,
  140. depthwise=False,
  141. in_channels=[256, 512, 1024],
  142. feat_channels=256,
  143. fpn_strides=(8, 16, 32),
  144. l1_epoch=285,
  145. act='silu',
  146. assigner=SimOTAAssigner(use_vfl=False),
  147. nms='MultiClassNMS',
  148. loss_weight={
  149. 'cls': 1.0,
  150. 'obj': 1.0,
  151. 'iou': 5.0,
  152. 'l1': 1.0,
  153. },
  154. trt=False,
  155. exclude_nms=False):
  156. super(YOLOXHead, self).__init__()
  157. self._dtype = paddle.framework.get_default_dtype()
  158. self.num_classes = num_classes
  159. assert len(in_channels) > 0, "in_channels length should > 0"
  160. self.in_channels = in_channels
  161. feat_channels = int(feat_channels * width_mult)
  162. self.fpn_strides = fpn_strides
  163. self.l1_epoch = l1_epoch
  164. self.assigner = assigner
  165. self.nms = nms
  166. if isinstance(self.nms, MultiClassNMS) and trt:
  167. self.nms.trt = trt
  168. self.exclude_nms = exclude_nms
  169. self.loss_weight = loss_weight
  170. self.iou_loss = IouLoss(loss_weight=1.0) # default loss_weight 2.5
  171. ConvBlock = DWConv if depthwise else BaseConv
  172. self.stem_conv = nn.LayerList()
  173. self.conv_cls = nn.LayerList()
  174. self.conv_reg = nn.LayerList() # reg [x,y,w,h] + obj
  175. for in_c in self.in_channels:
  176. self.stem_conv.append(BaseConv(in_c, feat_channels, 1, 1, act=act))
  177. self.conv_cls.append(
  178. nn.Sequential(* [
  179. ConvBlock(
  180. feat_channels, feat_channels, 3, 1, act=act), ConvBlock(
  181. feat_channels, feat_channels, 3, 1, act=act),
  182. nn.Conv2D(
  183. feat_channels,
  184. self.num_classes,
  185. 1,
  186. bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
  187. ]))
  188. self.conv_reg.append(
  189. nn.Sequential(* [
  190. ConvBlock(
  191. feat_channels, feat_channels, 3, 1, act=act),
  192. ConvBlock(
  193. feat_channels, feat_channels, 3, 1, act=act),
  194. nn.Conv2D(
  195. feat_channels,
  196. 4 + 1, # reg [x,y,w,h] + obj
  197. 1,
  198. bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
  199. ]))
  200. self._init_weights()
  201. @classmethod
  202. def from_config(cls, cfg, input_shape):
  203. return {'in_channels': [i.channels for i in input_shape], }
  204. def _init_weights(self):
  205. bias_cls = bias_init_with_prob(0.01)
  206. bias_reg = paddle.full([5], math.log(5.), dtype=self._dtype)
  207. bias_reg[:2] = 0.
  208. bias_reg[-1] = bias_cls
  209. for cls_, reg_ in zip(self.conv_cls, self.conv_reg):
  210. constant_(cls_[-1].weight)
  211. constant_(cls_[-1].bias, bias_cls)
  212. constant_(reg_[-1].weight)
  213. reg_[-1].bias.set_value(bias_reg)
  214. def _generate_anchor_point(self, feat_sizes, strides, offset=0.):
  215. anchor_points, stride_tensor = [], []
  216. num_anchors_list = []
  217. for feat_size, stride in zip(feat_sizes, strides):
  218. h, w = feat_size
  219. x = (paddle.arange(w) + offset) * stride
  220. y = (paddle.arange(h) + offset) * stride
  221. y, x = paddle.meshgrid(y, x)
  222. anchor_points.append(paddle.stack([x, y], axis=-1).reshape([-1, 2]))
  223. stride_tensor.append(
  224. paddle.full(
  225. [len(anchor_points[-1]), 1], stride, dtype=self._dtype))
  226. num_anchors_list.append(len(anchor_points[-1]))
  227. anchor_points = paddle.concat(anchor_points).astype(self._dtype)
  228. anchor_points.stop_gradient = True
  229. stride_tensor = paddle.concat(stride_tensor)
  230. stride_tensor.stop_gradient = True
  231. return anchor_points, stride_tensor, num_anchors_list
  232. def forward(self, feats, targets=None):
  233. assert len(feats) == len(self.fpn_strides), \
  234. "The size of feats is not equal to size of fpn_strides"
  235. feat_sizes = [[f.shape[-2], f.shape[-1]] for f in feats]
  236. cls_score_list, reg_pred_list = [], []
  237. obj_score_list = []
  238. for i, feat in enumerate(feats):
  239. feat = self.stem_conv[i](feat)
  240. cls_logit = self.conv_cls[i](feat)
  241. reg_pred = self.conv_reg[i](feat)
  242. # cls prediction
  243. cls_score = F.sigmoid(cls_logit)
  244. cls_score_list.append(cls_score.flatten(2).transpose([0, 2, 1]))
  245. # reg prediction
  246. reg_xywh, obj_logit = paddle.split(reg_pred, [4, 1], axis=1)
  247. reg_xywh = reg_xywh.flatten(2).transpose([0, 2, 1])
  248. reg_pred_list.append(reg_xywh)
  249. # obj prediction
  250. obj_score = F.sigmoid(obj_logit)
  251. obj_score_list.append(obj_score.flatten(2).transpose([0, 2, 1]))
  252. cls_score_list = paddle.concat(cls_score_list, axis=1)
  253. reg_pred_list = paddle.concat(reg_pred_list, axis=1)
  254. obj_score_list = paddle.concat(obj_score_list, axis=1)
  255. # bbox decode
  256. anchor_points, stride_tensor, _ =\
  257. self._generate_anchor_point(feat_sizes, self.fpn_strides)
  258. reg_xy, reg_wh = paddle.split(reg_pred_list, 2, axis=-1)
  259. reg_xy += (anchor_points / stride_tensor)
  260. reg_wh = paddle.exp(reg_wh) * 0.5
  261. bbox_pred_list = paddle.concat(
  262. [reg_xy - reg_wh, reg_xy + reg_wh], axis=-1)
  263. if self.training:
  264. anchor_points, stride_tensor, num_anchors_list =\
  265. self._generate_anchor_point(feat_sizes, self.fpn_strides, 0.5)
  266. yolox_losses = self.get_loss([
  267. cls_score_list, bbox_pred_list, obj_score_list, anchor_points,
  268. stride_tensor, num_anchors_list
  269. ], targets)
  270. return yolox_losses
  271. else:
  272. pred_scores = (cls_score_list * obj_score_list).sqrt()
  273. return pred_scores, bbox_pred_list, stride_tensor
  274. def get_loss(self, head_outs, targets):
  275. pred_cls, pred_bboxes, pred_obj,\
  276. anchor_points, stride_tensor, num_anchors_list = head_outs
  277. gt_labels = targets['gt_class']
  278. gt_bboxes = targets['gt_bbox']
  279. pred_scores = (pred_cls * pred_obj).sqrt()
  280. # label assignment
  281. center_and_strides = paddle.concat(
  282. [anchor_points, stride_tensor, stride_tensor], axis=-1)
  283. pos_num_list, label_list, bbox_target_list = [], [], []
  284. for pred_score, pred_bbox, gt_box, gt_label in zip(
  285. pred_scores.detach(),
  286. pred_bboxes.detach() * stride_tensor, gt_bboxes, gt_labels):
  287. pos_num, label, _, bbox_target = self.assigner(
  288. pred_score, center_and_strides, pred_bbox, gt_box, gt_label)
  289. pos_num_list.append(pos_num)
  290. label_list.append(label)
  291. bbox_target_list.append(bbox_target)
  292. labels = paddle.to_tensor(np.stack(label_list, axis=0))
  293. bbox_targets = paddle.to_tensor(np.stack(bbox_target_list, axis=0))
  294. bbox_targets /= stride_tensor # rescale bbox
  295. # 1. obj score loss
  296. mask_positive = (labels != self.num_classes)
  297. loss_obj = F.binary_cross_entropy(
  298. pred_obj,
  299. mask_positive.astype(pred_obj.dtype).unsqueeze(-1),
  300. reduction='sum')
  301. num_pos = sum(pos_num_list)
  302. if num_pos > 0:
  303. num_pos = paddle.to_tensor(num_pos, dtype=self._dtype).clip(min=1)
  304. loss_obj /= num_pos
  305. # 2. iou loss
  306. bbox_mask = mask_positive.unsqueeze(-1).tile([1, 1, 4])
  307. pred_bboxes_pos = paddle.masked_select(pred_bboxes,
  308. bbox_mask).reshape([-1, 4])
  309. assigned_bboxes_pos = paddle.masked_select(
  310. bbox_targets, bbox_mask).reshape([-1, 4])
  311. bbox_iou = bbox_overlaps(pred_bboxes_pos, assigned_bboxes_pos)
  312. bbox_iou = paddle.diag(bbox_iou)
  313. loss_iou = self.iou_loss(
  314. pred_bboxes_pos.split(
  315. 4, axis=-1),
  316. assigned_bboxes_pos.split(
  317. 4, axis=-1))
  318. loss_iou = loss_iou.sum() / num_pos
  319. # 3. cls loss
  320. cls_mask = mask_positive.unsqueeze(-1).tile(
  321. [1, 1, self.num_classes])
  322. pred_cls_pos = paddle.masked_select(
  323. pred_cls, cls_mask).reshape([-1, self.num_classes])
  324. assigned_cls_pos = paddle.masked_select(labels, mask_positive)
  325. assigned_cls_pos = F.one_hot(assigned_cls_pos,
  326. self.num_classes + 1)[..., :-1]
  327. assigned_cls_pos *= bbox_iou.unsqueeze(-1)
  328. loss_cls = F.binary_cross_entropy(
  329. pred_cls_pos, assigned_cls_pos, reduction='sum')
  330. loss_cls /= num_pos
  331. # 4. l1 loss
  332. if targets['epoch_id'] >= self.l1_epoch:
  333. loss_l1 = F.l1_loss(
  334. pred_bboxes_pos, assigned_bboxes_pos, reduction='sum')
  335. loss_l1 /= num_pos
  336. else:
  337. loss_l1 = paddle.zeros([1])
  338. loss_l1.stop_gradient = False
  339. else:
  340. loss_cls = paddle.zeros([1])
  341. loss_iou = paddle.zeros([1])
  342. loss_l1 = paddle.zeros([1])
  343. loss_cls.stop_gradient = False
  344. loss_iou.stop_gradient = False
  345. loss_l1.stop_gradient = False
  346. loss = self.loss_weight['obj'] * loss_obj + \
  347. self.loss_weight['cls'] * loss_cls + \
  348. self.loss_weight['iou'] * loss_iou
  349. if targets['epoch_id'] >= self.l1_epoch:
  350. loss += (self.loss_weight['l1'] * loss_l1)
  351. yolox_losses = {
  352. 'loss': loss,
  353. 'loss_cls': loss_cls,
  354. 'loss_obj': loss_obj,
  355. 'loss_iou': loss_iou,
  356. 'loss_l1': loss_l1,
  357. }
  358. return yolox_losses
  359. def post_process(self, head_outs, img_shape, scale_factor):
  360. pred_scores, pred_bboxes, stride_tensor = head_outs
  361. pred_scores = pred_scores.transpose([0, 2, 1])
  362. pred_bboxes *= stride_tensor
  363. # scale bbox to origin image
  364. scale_factor = scale_factor.flip(-1).tile([1, 2]).unsqueeze(1)
  365. pred_bboxes /= scale_factor
  366. if self.exclude_nms:
  367. # `exclude_nms=True` just use in benchmark
  368. return pred_bboxes.sum(), pred_scores.sum()
  369. else:
  370. bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
  371. return bbox_pred, bbox_num