utils.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from ..bbox_utils import bbox_overlaps

__all__ = [
    '_get_clones', 'bbox_overlaps', 'bbox_cxcywh_to_xyxy',
    'bbox_xyxy_to_cxcywh', 'sigmoid_focal_loss', 'inverse_sigmoid',
    'deformable_attention_core_func'
]

def _get_clones(module, N):
    return nn.LayerList([copy.deepcopy(module) for _ in range(N)])
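
# Usage sketch (illustrative, not part of the original file): _get_clones is the
# DETR-style helper for stacking N independent copies of a layer, e.g.
#   layer = nn.Linear(256, 256)
#   layers = _get_clones(layer, 6)   # LayerList of 6 deep-copied Linear layers
# Each clone carries its own parameters, so the copies do not share weights.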

def bbox_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return paddle.stack(b, axis=-1)


def bbox_xyxy_to_cxcywh(x):
    x0, y0, x1, y1 = x.unbind(-1)
    b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
    return paddle.stack(b, axis=-1)
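
# Illustrative sketch (not in the original file): the two converters are inverses
# of each other, so a (cx, cy, w, h) box survives a round trip unchanged, e.g.
#   box = paddle.to_tensor([[0.5, 0.5, 0.2, 0.4]])           # [cx, cy, w, h]
#   xyxy = bbox_cxcywh_to_xyxy(box)                          # [[0.4, 0.3, 0.6, 0.7]]
#   assert paddle.allclose(bbox_xyxy_to_cxcywh(xyxy), box)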

def sigmoid_focal_loss(logit, label, normalizer=1.0, alpha=0.25, gamma=2.0):
    prob = F.sigmoid(logit)
    ce_loss = F.binary_cross_entropy_with_logits(logit, label, reduction="none")
    p_t = prob * label + (1 - prob) * (1 - label)
    loss = ce_loss * ((1 - p_t)**gamma)
    if alpha >= 0:
        alpha_t = alpha * label + (1 - alpha) * (1 - label)
        loss = alpha_t * loss
    return loss.mean(1).sum() / normalizer
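
# Illustrative sketch (not in the original file): logit and label share the same
# shape, label is a one-hot (or soft) float target, and normalizer is typically
# the number of ground-truth boxes in DETR-style set-prediction losses, e.g.
#   logits = paddle.randn([2, 100, 80])     # [bs, num_queries, num_classes]
#   labels = paddle.zeros([2, 100, 80])     # one-hot float targets
#   loss = sigmoid_focal_loss(logits, labels, normalizer=8.0)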

def inverse_sigmoid(x, eps=1e-6):
    x = x.clip(min=0., max=1.)
    return paddle.log(x / (1 - x + eps) + eps)
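
# Illustrative sketch (not in the original file): inverse_sigmoid is the logit
# function with clipping and an eps for numerical safety, so away from the
# saturated ends it approximately undoes F.sigmoid, e.g.
#   x = paddle.to_tensor([0.1, 0.5, 0.9])
#   y = inverse_sigmoid(x)                        # ~[-2.197, 0.0, 2.197]
#   assert paddle.allclose(F.sigmoid(y), x, atol=1e-4)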

def deformable_attention_core_func(value, value_spatial_shapes,
                                   sampling_locations, attention_weights):
    """
    Args:
        value (Tensor): [bs, value_length, n_head, c]
        value_spatial_shapes (Tensor): [n_levels, 2]
        sampling_locations (Tensor): [bs, query_length, n_head, n_levels, n_points, 2]
        attention_weights (Tensor): [bs, query_length, n_head, n_levels, n_points]

    Returns:
        output (Tensor): [bs, query_length, n_head * c]
    """
    bs, Len_v, n_head, c = value.shape
    _, Len_q, n_head, n_levels, n_points, _ = sampling_locations.shape

    value_list = value.split(value_spatial_shapes.prod(1).tolist(), axis=1)
    sampling_grids = 2 * sampling_locations - 1
    sampling_value_list = []
    for level, (h, w) in enumerate(value_spatial_shapes.tolist()):
        # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
        value_l_ = value_list[level].flatten(2).transpose(
            [0, 2, 1]).reshape([bs * n_head, c, h, w])
        # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
        sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(
            [0, 2, 1, 3, 4]).flatten(0, 1)
        # N_*M_, D_, Lq_, P_
        sampling_value_l_ = F.grid_sample(
            value_l_,
            sampling_grid_l_,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=False)
        sampling_value_list.append(sampling_value_l_)
    # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_*M_, 1, Lq_, L_*P_)
    attention_weights = attention_weights.transpose([0, 2, 1, 3, 4]).reshape(
        [bs * n_head, 1, Len_q, n_levels * n_points])
    output = (paddle.stack(
        sampling_value_list, axis=-2).flatten(-2) *
              attention_weights).sum(-1).reshape([bs, n_head * c, Len_q])

    return output.transpose([0, 2, 1])
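
# Illustrative sketch (not in the original file): a shape check with dummy inputs.
# Shapes follow the docstring above; every name below is made up for the example.
#   bs, n_head, c, n_levels, n_points, Len_q = 2, 8, 32, 2, 4, 100
#   spatial_shapes = paddle.to_tensor([[32, 32], [16, 16]], dtype='int64')
#   Len_v = int(spatial_shapes.prod(1).sum())        # 32*32 + 16*16 = 1280
#   value = paddle.randn([bs, Len_v, n_head, c])
#   sampling_locations = paddle.rand([bs, Len_q, n_head, n_levels, n_points, 2])
#   attention_weights = F.softmax(
#       paddle.rand([bs, Len_q, n_head, n_levels, n_points]).flatten(-2),
#       -1).reshape([bs, Len_q, n_head, n_levels, n_points])
#   out = deformable_attention_core_func(
#       value, spatial_shapes, sampling_locations, attention_weights)
#   assert out.shape == [bs, Len_q, n_head * c]      # [2, 100, 256]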