meta_arch.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import paddle
import paddle.nn as nn
import typing

from ppdet.core.workspace import register
from ppdet.modeling.post_process import nms

__all__ = ['BaseArch']


@register
class BaseArch(nn.Layer):
    def __init__(self, data_format='NCHW'):
        super(BaseArch, self).__init__()
        self.data_format = data_format
        self.inputs = {}
        self.fuse_norm = False

    def load_meanstd(self, cfg_transform):
        # Fold the dataset's NormalizeImage transform (mean/std/scale) into
        # tensors that forward() can apply directly when fuse_norm is enabled.
        scale = 1.
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        for item in cfg_transform:
            if 'NormalizeImage' in item:
                mean = np.array(
                    item['NormalizeImage']['mean'], dtype=np.float32)
                std = np.array(item['NormalizeImage']['std'], dtype=np.float32)
                if item['NormalizeImage'].get('is_scale', True):
                    scale = 1. / 255.
                break
        if self.data_format == 'NHWC':
            self.scale = paddle.to_tensor(scale / std).reshape((1, 1, 1, 3))
            self.bias = paddle.to_tensor(-mean / std).reshape((1, 1, 1, 3))
        else:
            self.scale = paddle.to_tensor(scale / std).reshape((1, 3, 1, 1))
            self.bias = paddle.to_tensor(-mean / std).reshape((1, 3, 1, 1))

    def forward(self, inputs):
        if self.data_format == 'NHWC':
            # convert the incoming NCHW batch to channels-last
            image = inputs['image']
            inputs['image'] = paddle.transpose(image, [0, 2, 3, 1])

        if self.fuse_norm:
            image = inputs['image']
            self.inputs['image'] = image * self.scale + self.bias
            self.inputs['im_shape'] = inputs['im_shape']
            self.inputs['scale_factor'] = inputs['scale_factor']
        else:
            self.inputs = inputs

        self.model_arch()

        if self.training:
            out = self.get_loss()
        else:
            inputs_list = []
            # multi-scale input
            if not isinstance(inputs, typing.Sequence):
                inputs_list.append(inputs)
            else:
                inputs_list.extend(inputs)

            outs = []
            for inp in inputs_list:
                if self.fuse_norm:
                    self.inputs['image'] = inp['image'] * self.scale + self.bias
                    self.inputs['im_shape'] = inp['im_shape']
                    self.inputs['scale_factor'] = inp['scale_factor']
                else:
                    self.inputs = inp
                outs.append(self.get_pred())

            # multi-scale test
            if len(outs) > 1:
                out = self.merge_multi_scale_predictions(outs)
            else:
                out = outs[0]
        return out

    def merge_multi_scale_predictions(self, outs):
        # default values for architectures not included in following list
        num_classes = 80
        nms_threshold = 0.5
        keep_top_k = 100

        if self.__class__.__name__ in ('CascadeRCNN', 'FasterRCNN',
                                       'MaskRCNN'):
            num_classes = self.bbox_head.num_classes
            keep_top_k = self.bbox_post_process.nms.keep_top_k
            nms_threshold = self.bbox_post_process.nms.nms_threshold
        else:
            raise Exception(
                "Multi scale test only supports CascadeRCNN, FasterRCNN and MaskRCNN for now"
            )

        # run per-class NMS over the boxes gathered from all scales, then keep
        # the top-k boxes ranked by score (column 1 of [cls, score, x1, y1, x2, y2])
        final_boxes = []
        all_scale_outs = paddle.concat([o['bbox'] for o in outs]).numpy()
        for c in range(num_classes):
            idxs = all_scale_outs[:, 0] == c
            if np.count_nonzero(idxs) == 0:
                continue
            r = nms(all_scale_outs[idxs, 1:], nms_threshold)
            final_boxes.append(
                np.concatenate([np.full((r.shape[0], 1), c), r], 1))
        out = np.concatenate(final_boxes)
        out = np.concatenate(sorted(
            out, key=lambda e: e[1])[-keep_top_k:]).reshape((-1, 6))
        out = {
            'bbox': paddle.to_tensor(out),
            'bbox_num': paddle.to_tensor(np.array([out.shape[0], ]))
        }
        return out

    def build_inputs(self, data, input_def):
        # map a flat data sequence onto a dict keyed by the input definition
        inputs = {}
        for i, k in enumerate(input_def):
            inputs[k] = data[i]
        return inputs

    def model_arch(self, ):
        pass

    def get_loss(self, ):
        raise NotImplementedError("Should implement get_loss method!")

    def get_pred(self, ):
        raise NotImplementedError("Should implement get_pred method!")
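# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal subclass showing the
# hooks BaseArch expects concrete architectures to implement, and how
# load_meanstd() folds a NormalizeImage transform into the model so that
# forward() can normalize raw-range images when fuse_norm is enabled.
# `ToyArch`, its toy backbone, and the sample batch below are illustrative
# assumptions, not ppdet APIs.
# ---------------------------------------------------------------------------
class ToyArch(BaseArch):
    def __init__(self, backbone, data_format='NCHW'):
        super(ToyArch, self).__init__(data_format=data_format)
        self.backbone = backbone  # any nn.Layer producing features

    def model_arch(self):
        # shared computation; BaseArch.forward() has already filled self.inputs
        self.feats = self.backbone(self.inputs['image'])

    def get_loss(self):
        # training branch of BaseArch.forward(); return a dict of losses
        return {'loss': self.feats.mean()}

    def get_pred(self):
        # eval branch; real architectures return e.g. {'bbox': ..., 'bbox_num': ...}
        return {'feats': self.feats}


if __name__ == '__main__':
    arch = ToyArch(backbone=nn.Conv2D(3, 8, 3, padding=1))
    # fuse the dataset normalization into the model
    arch.load_meanstd([{'NormalizeImage': {'mean': [0.485, 0.456, 0.406],
                                           'std': [0.229, 0.224, 0.225],
                                           'is_scale': True}}])
    arch.fuse_norm = True

    data = {
        'image': paddle.rand([1, 3, 32, 32]) * 255.,  # raw-range input batch
        'im_shape': paddle.to_tensor([[32., 32.]]),
        'scale_factor': paddle.to_tensor([[1., 1.]]),
    }
    # Layers start in training mode, so this hits get_loss(); call
    # arch.eval() first to exercise the get_pred() / multi-scale path.
    print(arch(data))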