name_adapter.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. class NameAdapter(object):
  15. """Fix the backbones variable names for pretrained weight"""
  16. def __init__(self, model):
  17. super(NameAdapter, self).__init__()
  18. self.model = model
  19. @property
  20. def model_type(self):
  21. return getattr(self.model, '_model_type', '')
  22. @property
  23. def variant(self):
  24. return getattr(self.model, 'variant', '')
  25. def fix_conv_norm_name(self, name):
  26. if name == "conv1":
  27. bn_name = "bn_" + name
  28. else:
  29. bn_name = "bn" + name[3:]
  30. # the naming rule is same as pretrained weight
  31. if self.model_type == 'SEResNeXt':
  32. bn_name = name + "_bn"
  33. return bn_name
  34. def fix_shortcut_name(self, name):
  35. if self.model_type == 'SEResNeXt':
  36. name = 'conv' + name + '_prj'
  37. return name
  38. def fix_bottleneck_name(self, name):
  39. if self.model_type == 'SEResNeXt':
  40. conv_name1 = 'conv' + name + '_x1'
  41. conv_name2 = 'conv' + name + '_x2'
  42. conv_name3 = 'conv' + name + '_x3'
  43. shortcut_name = name
  44. else:
  45. conv_name1 = name + "_branch2a"
  46. conv_name2 = name + "_branch2b"
  47. conv_name3 = name + "_branch2c"
  48. shortcut_name = name + "_branch1"
  49. return conv_name1, conv_name2, conv_name3, shortcut_name
  50. def fix_layer_warp_name(self, stage_num, count, i):
  51. name = 'res' + str(stage_num)
  52. if count > 10 and stage_num == 4:
  53. if i == 0:
  54. conv_name = name + "a"
  55. else:
  56. conv_name = name + "b" + str(i)
  57. else:
  58. conv_name = name + chr(ord("a") + i)
  59. if self.model_type == 'SEResNeXt':
  60. conv_name = str(stage_num + 2) + '_' + str(i + 1)
  61. return conv_name
  62. def fix_c1_stage_name(self):
  63. return "res_conv1" if self.model_type == 'ResNeXt' else "conv1"