keypoint_hrhrnet_head.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn

from ppdet.core.workspace import register
from .. import layers as L
from ..backbones.hrnet import BasicBlock


@register
class HrHRNetHead(nn.Layer):
    __inject__ = ['loss']

    def __init__(self, num_joints, loss='HrHRNetLoss', swahr=False, width=32):
        """
        Head for HigherHRNet network

        Args:
            num_joints (int): number of keypoints
            loss (object): HrHRNetLoss instance
            swahr (bool): whether to use SWAHR (scale- and weight-adaptive
                heatmap regression)
            width (int): hrnet channel width
        """
        super(HrHRNetHead, self).__init__()
        self.loss = loss
        self.num_joints = num_joints
        num_featout1 = num_joints * 2
        num_featout2 = num_joints
        self.swahr = swahr
        # conv1 predicts heatmaps and tagmaps at 1/4 input resolution;
        # conv2 predicts refined heatmaps at 1/2 input resolution.
        self.conv1 = L.Conv2d(width, num_featout1, 1, 1, 0, bias=True)
        self.conv2 = L.Conv2d(width, num_featout2, 1, 1, 0, bias=True)
        # 2x deconv upsampling of the backbone feature concatenated with the
        # first-stage prediction, refined by four residual blocks.
        self.deconv = nn.Sequential(
            L.ConvTranspose2d(
                num_featout1 + width, width, 4, 2, 1, 0, bias=False),
            L.BatchNorm2d(width),
            L.ReLU())
        self.blocks = nn.Sequential(*(BasicBlock(
            num_channels=width,
            num_filters=width,
            has_se=False,
            freeze_norm=False,
            name='HrHRNetHead_{}'.format(i)) for i in range(4)))
        self.interpolate = L.Upsample(2, mode='bilinear')
        self.concat = L.Concat(dim=1)
        if swahr:
            # Per-joint scale maps for SWAHR: a 1x1 projection followed by a
            # 9x9 depthwise convolution (groups=num_joints).
            self.scalelayer0 = nn.Sequential(
                L.Conv2d(
                    width, num_joints, 1, 1, 0, bias=True),
                L.BatchNorm2d(num_joints),
                L.ReLU(),
                L.Conv2d(
                    num_joints,
                    num_joints,
                    9,
                    1,
                    4,
                    groups=num_joints,
                    bias=True))
            self.scalelayer1 = nn.Sequential(
                L.Conv2d(
                    width, num_joints, 1, 1, 0, bias=True),
                L.BatchNorm2d(num_joints),
                L.ReLU(),
                L.Conv2d(
                    num_joints,
                    num_joints,
                    9,
                    1,
                    4,
                    groups=num_joints,
                    bias=True))

    def forward(self, feats, targets=None):
        x1 = feats[0]
        xo1 = self.conv1(x1)
        x2 = self.blocks(self.deconv(self.concat((x1, xo1))))
        xo2 = self.conv2(x2)
        num_joints = self.num_joints
        if self.training:
            # The first-stage output carries both heatmaps and tagmaps; the
            # refined second stage predicts heatmaps only.
            heatmap1, tagmap = paddle.split(xo1, 2, axis=1)
            if self.swahr:
                so1 = self.scalelayer0(x1)
                so2 = self.scalelayer1(x2)
                hrhrnet_outputs = ([heatmap1, so1], [xo2, so2], tagmap)
                return self.loss(hrhrnet_outputs, targets)
            else:
                hrhrnet_outputs = (heatmap1, xo2, tagmap)
                return self.loss(hrhrnet_outputs, targets)

        # inference: averaged heatmap, upsampled tagmap
        upsampled = self.interpolate(xo1)
        avg = (upsampled[:, :num_joints] + xo2[:, :num_joints]) / 2
        return avg, upsampled[:, num_joints:]
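

# ----------------------------------------------------------------------------
# Usage sketch (not part of the upstream file): a minimal, hedged example of
# driving this head in inference mode on a dummy feature map. It assumes a
# working ppdet install; because of the relative imports above, it must be
# run as a module (e.g. `python -m ppdet.modeling.heads.keypoint_hrhrnet_head`,
# assuming the file sits at its usual location). The batch and spatial sizes
# are illustrative only. In eval mode the injected loss is never called, so
# leaving the default 'HrHRNetLoss' config string in place is harmless here.
if __name__ == '__main__':
    head = HrHRNetHead(num_joints=17, swahr=False, width=32)
    head.eval()

    # feats[0]: the highest-resolution HRNet branch, `width` channels at 1/4
    # of the input resolution (e.g. a 512x512 image -> 128x128 feature map).
    dummy_feat = paddle.randn([1, 32, 128, 128])
    heatmaps, tagmaps = head([dummy_feat])

    # Both outputs live at 1/2 input resolution: averaged heatmaps with one
    # channel per joint, and upsampled tagmaps with one embedding per joint.
    print(heatmaps.shape)  # [1, 17, 256, 256]
    print(tagmaps.shape)   # [1, 17, 256, 256]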