distill.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from ppdet.core.workspace import register, create, load_config
from ppdet.modeling import ops
from ppdet.utils.checkpoint import load_pretrain_weight
from ppdet.utils.logger import setup_logger

logger = setup_logger(__name__)


class DistillModel(nn.Layer):
    def __init__(self, cfg, slim_cfg):
        super(DistillModel, self).__init__()

        # Build the student from the main config and load its pretrained weights.
        self.student_model = create(cfg.architecture)
        logger.debug('Load student model pretrain_weights:{}'.format(
            cfg.pretrain_weights))
        load_pretrain_weight(self.student_model, cfg.pretrain_weights)

        # Build the teacher and the distillation loss from the slim config.
        slim_cfg = load_config(slim_cfg)
        self.teacher_model = create(slim_cfg.architecture)
        self.distill_loss = create(slim_cfg.distill_loss)
        logger.debug('Load teacher model pretrain_weights:{}'.format(
            slim_cfg.pretrain_weights))
        load_pretrain_weight(self.teacher_model, slim_cfg.pretrain_weights)

        # Freeze the teacher: only the student is updated during training.
        for param in self.teacher_model.parameters():
            param.trainable = False

    def parameters(self):
        # Expose only the student's parameters to the optimizer.
        return self.student_model.parameters()

    def forward(self, inputs):
        if self.training:
            # Run both models, then add the distillation term to the student loss.
            teacher_loss = self.teacher_model(inputs)
            student_loss = self.student_model(inputs)
            loss = self.distill_loss(self.teacher_model, self.student_model)
            student_loss['distill_loss'] = loss
            student_loss['teacher_loss'] = teacher_loss['loss']
            student_loss['loss'] += student_loss['distill_loss']
            return student_loss
        else:
            # At inference time only the student is used.
            return self.student_model(inputs)
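

# --- Usage sketch (illustrative; not part of the original module) -------------
# A minimal example of how DistillModel might be driven by hand. The config
# paths, optimizer settings, and the `loader` object are assumptions; in
# PaddleDetection the Trainer normally builds and runs this loop.
#
#     cfg = load_config('configs/yolov3/yolov3_mobilenet_v1_270e_coco.yml')
#     model = DistillModel(cfg,
#                          'configs/slim/distill/yolov3_mobilenet_v1_coco_distill.yml')
#     opt = paddle.optimizer.Momentum(learning_rate=0.001,
#                                     parameters=model.parameters())
#     model.train()
#     for data in loader:              # batches of training samples
#         outputs = model(data)        # student loss dict + distill/teacher terms
#         outputs['loss'].backward()   # updates reach the student only
#         opt.step()
#         opt.clear_grad()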


@register
class DistillYOLOv3Loss(nn.Layer):
    def __init__(self, weight=1000):
        super(DistillYOLOv3Loss, self).__init__()
        self.weight = weight

    def obj_weighted_reg(self, sx, sy, sw, sh, tx, ty, tw, th, tobj):
        # Box regression distillation: BCE on the x/y logits against the
        # teacher's sigmoided outputs, L1 on w/h, weighted by the teacher's
        # objectness.
        loss_x = ops.sigmoid_cross_entropy_with_logits(sx, F.sigmoid(tx))
        loss_y = ops.sigmoid_cross_entropy_with_logits(sy, F.sigmoid(ty))
        loss_w = paddle.abs(sw - tw)
        loss_h = paddle.abs(sh - th)
        loss = paddle.add_n([loss_x, loss_y, loss_w, loss_h])
        weighted_loss = paddle.mean(loss * F.sigmoid(tobj))
        return weighted_loss

    def obj_weighted_cls(self, scls, tcls, tobj):
        # Classification distillation, weighted by the teacher's objectness.
        loss = ops.sigmoid_cross_entropy_with_logits(scls, F.sigmoid(tcls))
        weighted_loss = paddle.mean(paddle.multiply(loss, F.sigmoid(tobj)))
        return weighted_loss

    def obj_loss(self, sobj, tobj):
        # Objectness distillation against a binarized teacher objectness mask.
        obj_mask = paddle.cast(tobj > 0., dtype="float32")
        obj_mask.stop_gradient = True
        loss = paddle.mean(
            ops.sigmoid_cross_entropy_with_logits(sobj, obj_mask))
        return loss

    def forward(self, teacher_model, student_model):
        # distill_pairs are collected by the YOLOv3 head loss during the forward
        # pass; each pair holds the raw (x, y, w, h, obj, cls) predictions for
        # one output level.
        teacher_distill_pairs = teacher_model.yolo_head.loss.distill_pairs
        student_distill_pairs = student_model.yolo_head.loss.distill_pairs
        distill_reg_loss, distill_cls_loss, distill_obj_loss = [], [], []
        for s_pair, t_pair in zip(student_distill_pairs, teacher_distill_pairs):
            distill_reg_loss.append(
                self.obj_weighted_reg(s_pair[0], s_pair[1], s_pair[2], s_pair[
                    3], t_pair[0], t_pair[1], t_pair[2], t_pair[3], t_pair[4]))
            distill_cls_loss.append(
                self.obj_weighted_cls(s_pair[5], t_pair[5], t_pair[4]))
            distill_obj_loss.append(self.obj_loss(s_pair[4], t_pair[4]))
        # Sum each term over all output levels, then apply the global weight.
        distill_reg_loss = paddle.add_n(distill_reg_loss)
        distill_cls_loss = paddle.add_n(distill_cls_loss)
        distill_obj_loss = paddle.add_n(distill_obj_loss)
        loss = (distill_reg_loss + distill_cls_loss + distill_obj_loss
                ) * self.weight
        return loss
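

# --- Illustrative sketch (not part of the original module) --------------------
# Exercises DistillYOLOv3Loss's per-level terms on random tensors so the tensor
# shapes and the objectness weighting are easy to see. The [2, 3, 13, 13]
# shape (batch, anchors, grid_h, grid_w) is an assumption for illustration
# only; real inputs come from the YOLOv3 head's distill_pairs, and the class
# logits would normally carry an extra num_classes dimension.
if __name__ == '__main__':
    loss_fn = DistillYOLOv3Loss(weight=1000)
    shape = [2, 3, 13, 13]
    sx, sy, sw, sh, sobj, scls = (paddle.randn(shape) for _ in range(6))
    tx, ty, tw, th, tobj, tcls = (paddle.randn(shape) for _ in range(6))
    print('reg :', float(loss_fn.obj_weighted_reg(sx, sy, sw, sh,
                                                  tx, ty, tw, th, tobj)))
    print('cls :', float(loss_fn.obj_weighted_cls(scls, tcls, tobj)))
    print('obj :', float(loss_fn.obj_loss(sobj, tobj)))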