123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228 |
- # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- from itertools import cycle, islice
- from collections import abc
- import paddle
- import paddle.nn as nn
- from ppdet.core.workspace import register, serializable
- __all__ = ['HrHRNetLoss', 'KeyPointMSELoss']
@register
@serializable
class KeyPointMSELoss(nn.Layer):
    # Per-joint MSE between predicted and ground-truth heatmaps, averaged over
    # the joints and returned under the 'loss' key.

    def __init__(self, use_target_weight=True, loss_scale=0.5):
        """
        KeyPointMSELoss layer

        Args:
            use_target_weight (bool): whether to multiply both heatmaps by the
                per-joint target weight before computing the MSE
            loss_scale (float): factor applied to every per-joint MSE term
        """
        super(KeyPointMSELoss, self).__init__()
        self.criterion = nn.MSELoss(reduction='mean')
        self.use_target_weight = use_target_weight
        self.loss_scale = loss_scale

    def forward(self, output, records):
        """
        Compute the keypoint heatmap loss.

        Args:
            output (Tensor): predictions; axis 0 is read as the batch size and
                axis 1 as the number of joints, the rest is flattened
            records (dict): must provide 'target' (reshaped like ``output``)
                and 'target_weight' (indexed as ``[:, joint]``)

        Returns:
            dict: ``{'loss': summed per-joint loss / number of joints}``
        """
        gt = records['target']
        weights = records['target_weight']
        n_batch = output.shape[0]
        n_joints = output.shape[1]

        # Flatten everything after the joint axis, then cut one slice per joint.
        flat_shape = (n_batch, n_joints, -1)
        pred_slices = output.reshape(flat_shape).split(n_joints, 1)
        gt_slices = gt.reshape(flat_shape).split(n_joints, 1)

        total = 0
        for joint, (pred_slice, gt_slice) in enumerate(
                zip(pred_slices, gt_slices)):
            pred_j = pred_slice.squeeze()
            gt_j = gt_slice.squeeze()
            if not self.use_target_weight:
                total += self.loss_scale * self.criterion(pred_j, gt_j)
                continue
            joint_weight = weights[:, joint]
            total += self.loss_scale * self.criterion(
                pred_j.multiply(joint_weight), gt_j.multiply(joint_weight))

        return {'loss': total / n_joints}
@register
@serializable
class HrHRNetLoss(nn.Layer):
    # Combines two heatmap losses (1x and 2x targets) with the associative
    # embedding loss through a ZipLoss.

    def __init__(self, num_joints, swahr):
        """
        HrHRNetLoss layer

        Args:
            num_joints (int): number of keypoints
            swahr (bool): when truthy the heatmap loss is HeatMapSWAHRLoss,
                otherwise HeatMapLoss
        """
        super(HrHRNetLoss, self).__init__()
        self.heatmaploss = (HeatMapSWAHRLoss(num_joints)
                            if swahr else HeatMapLoss())
        self.aeloss = AELoss()
        # The same heatmap loss object is applied to both heatmap targets.
        self.ziploss = ZipLoss(
            [self.heatmaploss, self.heatmaploss, self.aeloss])

    def forward(self, inputs, records):
        """
        Compute heatmap, pull and push losses plus their total.

        Args:
            inputs: predictions handed to ZipLoss together with the targets
            records (dict): must provide 'heatmap_gt1x', 'mask_1x',
                'heatmap_gt2x', 'mask_2x' and 'tagmap'

        Returns:
            dict: keys 'heatmap_loss', 'pull_loss', 'push_loss' and 'loss'
        """
        targets = [
            [records['heatmap_gt1x'], records['mask_1x']],
            [records['heatmap_gt2x'], records['mask_2x']],
            records['tagmap'],
        ]
        zipped = self.ziploss(inputs, targets)
        ae_terms = zipped[2]
        return {
            'heatmap_loss': zipped[0] + zipped[1],
            'pull_loss': ae_terms[0],
            'push_loss': ae_terms[1],
            'loss': recursive_sum(zipped),
        }
class HeatMapLoss(object):
    """Masked squared-error heatmap loss, clipped to [0, 2] before averaging."""

    def __init__(self, loss_factor=1.0):
        """
        Args:
            loss_factor (float): multiplier applied to the averaged loss
        """
        super(HeatMapLoss, self).__init__()
        self.loss_factor = loss_factor

    def __call__(self, preds, targets):
        """
        Args:
            preds (Tensor): predicted heatmaps
            targets: pair ``(heatmap, mask)``; the mask is cast to float and
                gets a new axis 1 before being multiplied in

        Returns:
            Tensor: mean clipped masked squared error times ``loss_factor``
        """
        heatmap, mask = targets
        squared_error = (preds - heatmap)**2
        masked_error = squared_error * mask.cast('float').unsqueeze(1)
        averaged = paddle.clip(masked_error, min=0, max=2).mean()
        return averaged * self.loss_factor
class HeatMapSWAHRLoss(object):
    """Heatmap loss that rescales the ground truth with a predicted scale map.

    The result is a weighted masked squared error between the predicted and
    rescaled ground-truth heatmaps plus a regularizer on the scale map.
    """

    def __init__(self, num_joints, loss_factor=1.0):
        """
        Args:
            num_joints (int): number of keypoints (stored, not used in the
                loss computation itself)
            loss_factor (float): multiplier applied to the final loss
        """
        super(HeatMapSWAHRLoss, self).__init__()
        self.loss_factor = loss_factor
        self.num_joints = num_joints

    def __call__(self, preds, targets):
        """
        Args:
            preds: indexable; ``preds[0]`` is the predicted heatmaps and
                ``preds[1]`` the predicted scale maps
            targets: pair ``(heatmaps_gt, mask)``

        Returns:
            Tensor: ``loss_factor * (weighted heatmap loss + regularizer)``
        """
        heatmaps_gt, mask = targets
        heatmaps_pred = preds[0]
        scalemaps_pred = preds[1]

        # Rescale the ground truth only where it is positive; other positions
        # keep the original ground-truth value.
        log_gt = paddle.log(heatmaps_gt + 1e-10)
        rescaled = 0.5 * heatmaps_gt * (
            1 + (1 + (scalemaps_pred - 1.) * log_gt)**2)
        heatmaps_scaled_gt = paddle.where(heatmaps_gt > 0, rescaled,
                                          heatmaps_gt)

        # Penalise scale-map deviation from 1 on positive ground-truth pixels.
        positive = (heatmaps_gt > 0).astype(float)
        regularizer_loss = paddle.mean(
            paddle.pow((scalemaps_pred - 1.) * positive, 2))

        omega = 0.01
        # thres = 2**(-1/omega), threshold for positive weight
        gt_pow = heatmaps_scaled_gt**omega
        hm_weight = (gt_pow * paddle.abs(1 - heatmaps_pred) +
                     paddle.abs(heatmaps_pred) * (1 - gt_pow))

        valid = mask.cast('float').unsqueeze(1)
        weighted_error = (
            (heatmaps_pred - heatmaps_scaled_gt)**2 * valid) * hm_weight
        heatmap_term = weighted_error.mean()
        return self.loss_factor * (heatmap_term + 1.0 * regularizer_loss)
class AELoss(object):
    """Associative-embedding loss returning a ``(pull, push)`` pair.

    ``pull`` measures the spread of each person's embeddings around their
    mean; ``push`` penalises different persons having close mean embeddings.
    """

    def __init__(self, pull_factor=0.001, push_factor=0.001):
        """
        Args:
            pull_factor (float): weight of the batch-averaged pull term
            push_factor (float): weight of the batch-averaged push term
        """
        super(AELoss, self).__init__()
        self.pull_factor = pull_factor
        self.push_factor = push_factor

    def apply_single(self, pred, tagmap):
        """
        Compute the unweighted pull and push terms for one sample.

        Args:
            pred (Tensor): embedding map indexed by the first three columns
                of the selected ``tagmap`` rows (via ``paddle.gather_nd``)
            tagmap (Tensor): indexed as ``[person, keypoint, 4]``; column 3
                marks usable entries (> 0) -- axis meaning inferred from the
                variable names, confirm against the data pipeline

        Returns:
            tuple: ``(pull, push)``; both are ``paddle.zeros([1])`` when no
            usable entry exists, and ``push`` is zero for fewer than two
            persons
        """
        # Early exits: nothing flagged as valid in this sample.
        if tagmap.numpy()[:, :, 3].sum() == 0:
            return (paddle.zeros([1]), paddle.zeros([1]))
        nonzero = paddle.nonzero(tagmap[:, :, 3] > 0)
        if nonzero.shape[0] == 0:
            return (paddle.zeros([1]), paddle.zeros([1]))
        p_inds = paddle.unique(nonzero[:, 0])
        num_person = p_inds.shape[0]
        if num_person == 0:
            return (paddle.zeros([1]), paddle.zeros([1]))

        pull = 0
        embs_all = []
        person_unvalid = 0
        for person_idx in p_inds.numpy():
            valid_single = tagmap[person_idx.item()]
            validkpts = paddle.nonzero(valid_single[:, 3] > 0)
            valid_single = paddle.index_select(valid_single, validkpts)
            emb = paddle.gather_nd(pred, valid_single[:, :3])
            # A person with a single embedding has zero spread; it is left
            # out of the divisor used to average the pull term below.
            if emb.shape[0] == 1:
                person_unvalid += 1
            mean = paddle.mean(emb, axis=0)
            embs_all.append(mean)
            pull += paddle.mean(paddle.pow(emb - mean, 2), axis=0)
        pull /= max(num_person - person_unvalid, 1)

        if num_person < 2:
            return pull, paddle.zeros([1])

        # Pairwise squared differences between the per-person means.
        embs_all = paddle.stack(embs_all)
        A = embs_all.expand([num_person, num_person])
        B = A.transpose([1, 0])
        diff = paddle.pow(A - B, 2)
        push = paddle.exp(-diff)
        # The diagonal contributes exp(0) == 1 per person; subtract it.
        push = paddle.sum(push) - num_person
        push /= 2 * num_person * (num_person - 1)
        return pull, push

    def __call__(self, preds, tagmaps):
        """
        Average the per-sample pull/push terms over the batch.

        Args:
            preds (Tensor): batched embedding maps; axis 0 is the batch
            tagmaps (Tensor): batched tag maps, sliced in step with ``preds``

        Returns:
            tuple: ``(pull_factor * mean pull, push_factor * mean push)``
        """
        bs = preds.shape[0]
        losses = [
            self.apply_single(preds[i:i + 1].squeeze(),
                              tagmaps[i:i + 1].squeeze()) for i in range(bs)
        ]
        pull = self.pull_factor * sum(loss[0] for loss in losses) / len(losses)
        push = self.push_factor * sum(loss[1] for loss in losses) / len(losses)
        return pull, push
class ZipLoss(object):
    """Apply a list of loss callables position-wise to inputs and targets.

    The i-th callable receives the i-th input and the i-th target. Sequences
    shorter than the longest one are cycled to match its length.
    """

    def __init__(self, loss_funcs):
        """
        Args:
            loss_funcs (list): callables invoked as ``fn(input, target)``
        """
        super(ZipLoss, self).__init__()
        self.loss_funcs = loss_funcs

    @staticmethod
    def _zip_repeat(*seqs):
        """Zip sequences, cycling the shorter ones up to the longest length."""
        longest = max(map(len, seqs))
        return zip(*(islice(cycle(seq), longest) for seq in seqs))

    def __call__(self, inputs, targets):
        """
        Args:
            inputs (sequence): predictions, at most as many as ``targets``
            targets (sequence): one target per loss callable

        Returns:
            tuple: one result per loss callable, in order

        Raises:
            AssertionError: if the number of targets differs from the number
                of loss callables, or there are more inputs than targets
        """
        # Raised explicitly (same exception type as the former assert) so the
        # check still runs when Python is started with -O.
        if not (len(self.loss_funcs) == len(targets) >= len(inputs)):
            raise AssertionError(
                'ZipLoss expects len(loss_funcs) == len(targets) >= '
                'len(inputs), got {}, {} and {}'.format(
                    len(self.loss_funcs), len(targets), len(inputs)))

        return tuple(
            fn(x, y)
            for x, y, fn in self._zip_repeat(inputs, targets, self.loss_funcs))
def recursive_sum(inputs):
    """Sum all leaves of an arbitrarily nested sequence.

    Args:
        inputs: a leaf value, or a ``collections.abc.Sequence`` (list, tuple,
            ...) whose items are leaves or further sequences

    Returns:
        The sum of every leaf; a non-sequence input is returned unchanged and
        an empty sequence yields 0.
    """
    # A str is a Sequence whose items are again one-character strs, so
    # descending into it would never terminate; treat it as a leaf.
    if isinstance(inputs, abc.Sequence) and not isinstance(inputs, str):
        return sum(recursive_sum(item) for item in inputs)
    return inputs
|