yolox.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. #!/usr/bin/env python
  2. # -*- encoding: utf-8 -*-
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. import torch.nn as nn
  5. from .yolo_head import YOLOXHead
  6. from .yolo_pafpn import YOLOPAFPN
  7. class YOLOX(nn.Module):
  8. """
  9. YOLOX model module. The module list is defined by create_yolov3_modules function.
  10. The network returns loss values from three YOLO layers during training
  11. and detection results during test.
  12. """
  13. def __init__(self, backbone=None, head=None):
  14. super().__init__()
  15. if backbone is None:
  16. backbone = YOLOPAFPN()
  17. if head is None:
  18. head = YOLOXHead(80)
  19. self.backbone = backbone
  20. self.head = head
  21. def forward(self, x, targets=None):
  22. # fpn output content features of [dark3, dark4, dark5]
  23. fpn_outs = self.backbone(x)
  24. if self.training:
  25. assert targets is not None
  26. loss, iou_loss, conf_loss, cls_loss, l1_loss, num_fg = self.head(
  27. fpn_outs, targets, x
  28. )
  29. outputs = {
  30. "total_loss": loss,
  31. "iou_loss": iou_loss,
  32. "l1_loss": l1_loss,
  33. "conf_loss": conf_loss,
  34. "cls_loss": cls_loss,
  35. "num_fg": num_fg,
  36. }
  37. else:
  38. outputs = self.head(fpn_outs)
  39. return outputs