123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109 |
- # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import paddle
- import paddle.nn as nn
- import paddle.nn.functional as F
- from ppdet.core.workspace import register, create, load_config
- from ppdet.modeling import ops
- from ppdet.utils.checkpoint import load_pretrain_weight
- from ppdet.utils.logger import setup_logger
# Module-level logger; DistillModel uses it to report which pretrain
# weights it loads for the student and for the teacher.
logger = setup_logger(__name__)
class DistillModel(nn.Layer):
    """Pair a trainable student detector with a frozen teacher for distillation.

    Both detectors are built with ``create`` from their configured
    ``architecture`` and initialized from their ``pretrain_weights``. The
    distillation loss object is built from ``slim_cfg.distill_loss``. Every
    teacher parameter is marked non-trainable, and ``parameters()`` exposes
    only the student's parameters. Anything built from ``parameters()``
    therefore never sees the teacher.

    Args:
        cfg: ppdet config of the student; ``cfg.architecture`` and
            ``cfg.pretrain_weights`` are read.
        slim_cfg: distillation config spec handed to ``load_config``; the
            loaded config's ``architecture``, ``distill_loss`` and
            ``pretrain_weights`` entries are read.
    """

    def __init__(self, cfg, slim_cfg):
        super(DistillModel, self).__init__()

        # Student: the model that is actually trained.
        self.student_model = create(cfg.architecture)
        logger.debug('Load student model pretrain_weights:{}'.format(
            cfg.pretrain_weights))
        load_pretrain_weight(self.student_model, cfg.pretrain_weights)

        # Teacher and distillation loss come from the separate slim config.
        slim_cfg = load_config(slim_cfg)
        self.teacher_model = create(slim_cfg.architecture)
        self.distill_loss = create(slim_cfg.distill_loss)
        logger.debug('Load teacher model pretrain_weights:{}'.format(
            slim_cfg.pretrain_weights))
        load_pretrain_weight(self.teacher_model, slim_cfg.pretrain_weights)

        # Freeze the teacher: it only supplies targets for the student.
        for param in self.teacher_model.parameters():
            param.trainable = False

    def parameters(self, include_sublayers=True):
        """Return the student's parameters only (the teacher is excluded).

        Accepts the same ``include_sublayers`` argument (and default) as
        ``nn.Layer.parameters``, so callers that use the base-class
        signature keep working. The flag is forwarded to the student.
        """
        return self.student_model.parameters(
            include_sublayers=include_sublayers)

    def forward(self, inputs):
        """Compute distillation training losses, or run student inference.

        Outside training mode this simply returns ``self.student_model(inputs)``.

        In training mode both models are run on ``inputs``. Each is expected
        to return a dict holding a ``'loss'`` entry. The student's dict is
        returned, extended with:

        * ``'distill_loss'``: result of ``self.distill_loss(teacher, student)``;
        * ``'teacher_loss'``: the teacher's ``'loss'`` (reported only).

        The student's ``'loss'`` is increased by the distillation loss.
        """
        if not self.training:
            return self.student_model(inputs)

        teacher_loss = self.teacher_model(inputs)
        student_loss = self.student_model(inputs)
        # The distill loss receives the models themselves, not their outputs.
        # It reads state on them (e.g. yolo_head.loss.distill_pairs for
        # DistillYOLOv3Loss), presumably filled in by the two forward passes
        # above, so it must be evaluated after both.
        loss = self.distill_loss(self.teacher_model, self.student_model)
        student_loss['distill_loss'] = loss
        student_loss['teacher_loss'] = teacher_loss['loss']
        # Only the distillation term is added; the teacher loss is not optimized.
        student_loss['loss'] += student_loss['distill_loss']
        return student_loss
@register
class DistillYOLOv3Loss(nn.Layer):
    """Distillation loss pulling a student YOLOv3 head towards a teacher's.

    Per-level head outputs are read from
    ``<model>.yolo_head.loss.distill_pairs``. Each pair is indexed as
    ``(x, y, w, h, obj, cls)``, following the argument names used below.
    For every level three terms are computed:

    * box regression: x/y via sigmoid cross-entropy against the teacher's
      sigmoid, w/h via L1, weighted by the teacher's objectness probability;
    * classification: sigmoid cross-entropy against the teacher's sigmoid
      class scores, weighted by the teacher's objectness probability;
    * objectness: sigmoid cross-entropy against the hard mask
      ``teacher_obj > 0``. Assuming ``obj`` holds logits, as the sigmoid
      calls suggest, this means teacher probability > 0.5.

    All terms over all levels are summed and multiplied by ``weight``.

    Args:
        weight (float): scale applied to the total distillation loss.
    """

    def __init__(self, weight=1000):
        super(DistillYOLOv3Loss, self).__init__()
        self.weight = weight

    def obj_weighted_reg(self, sx, sy, sw, sh, tx, ty, tw, th, tobj):
        """Box-regression distillation term weighted by teacher objectness."""
        loss_x = ops.sigmoid_cross_entropy_with_logits(sx, F.sigmoid(tx))
        loss_y = ops.sigmoid_cross_entropy_with_logits(sy, F.sigmoid(ty))
        loss_w = paddle.abs(sw - tw)
        loss_h = paddle.abs(sh - th)
        loss = paddle.add_n([loss_x, loss_y, loss_w, loss_h])
        weighted_loss = paddle.mean(paddle.multiply(loss, F.sigmoid(tobj)))
        return weighted_loss

    def obj_weighted_cls(self, scls, tcls, tobj):
        """Classification distillation term weighted by teacher objectness."""
        loss = ops.sigmoid_cross_entropy_with_logits(scls, F.sigmoid(tcls))
        weighted_loss = paddle.mean(paddle.multiply(loss, F.sigmoid(tobj)))
        return weighted_loss

    def obj_loss(self, sobj, tobj):
        """Objectness term: student obj logits vs. the hard teacher mask."""
        obj_mask = paddle.cast(tobj > 0., dtype="float32")
        # The mask is a fixed target; never back-propagate through it.
        obj_mask.stop_gradient = True
        loss = paddle.mean(
            ops.sigmoid_cross_entropy_with_logits(sobj, obj_mask))
        return loss

    def forward(self, teacher_model, student_model):
        """Return the weighted distillation loss between the two models.

        Raises:
            ValueError: if the models expose different numbers of distill
                pairs, or none at all.
        """
        teacher_distill_pairs = list(
            teacher_model.yolo_head.loss.distill_pairs)
        student_distill_pairs = list(
            student_model.yolo_head.loss.distill_pairs)
        # zip() would silently drop unmatched levels, and add_n([]) fails
        # with an opaque error, so validate the pairing up front.
        if len(student_distill_pairs) != len(teacher_distill_pairs):
            raise ValueError(
                "DistillYOLOv3Loss: teacher and student must provide the "
                "same number of distill pairs, got {} (teacher) vs {} "
                "(student)".format(
                    len(teacher_distill_pairs), len(student_distill_pairs)))
        if not student_distill_pairs:
            raise ValueError(
                "DistillYOLOv3Loss: no distill pairs found on the YOLOv3 "
                "heads of the teacher/student models")

        distill_reg_loss, distill_cls_loss, distill_obj_loss = [], [], []
        for s_pair, t_pair in zip(student_distill_pairs,
                                  teacher_distill_pairs):
            sx, sy, sw, sh, sobj, scls = s_pair[:6]
            tx, ty, tw, th, tobj, tcls = t_pair[:6]
            distill_reg_loss.append(
                self.obj_weighted_reg(sx, sy, sw, sh, tx, ty, tw, th, tobj))
            distill_cls_loss.append(self.obj_weighted_cls(scls, tcls, tobj))
            distill_obj_loss.append(self.obj_loss(sobj, tobj))

        distill_reg_loss = paddle.add_n(distill_reg_loss)
        distill_cls_loss = paddle.add_n(distill_cls_loss)
        distill_obj_loss = paddle.add_n(distill_obj_loss)
        loss = (distill_reg_loss + distill_cls_loss + distill_obj_loss
                ) * self.weight
        return loss
|