name_adapter.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. class NameAdapter(object):
  2. """Fix the backbones variable names for pretrained weight"""
  3. def __init__(self, model):
  4. super(NameAdapter, self).__init__()
  5. self.model = model
  6. @property
  7. def model_type(self):
  8. return getattr(self.model, '_model_type', '')
  9. @property
  10. def variant(self):
  11. return getattr(self.model, 'variant', '')
  12. def fix_conv_norm_name(self, name):
  13. if name == "conv1":
  14. bn_name = "bn_" + name
  15. else:
  16. bn_name = "bn" + name[3:]
  17. # the naming rule is same as pretrained weight
  18. if self.model_type == 'SEResNeXt':
  19. bn_name = name + "_bn"
  20. return bn_name
  21. def fix_shortcut_name(self, name):
  22. if self.model_type == 'SEResNeXt':
  23. name = 'conv' + name + '_prj'
  24. return name
  25. def fix_bottleneck_name(self, name):
  26. if self.model_type == 'SEResNeXt':
  27. conv_name1 = 'conv' + name + '_x1'
  28. conv_name2 = 'conv' + name + '_x2'
  29. conv_name3 = 'conv' + name + '_x3'
  30. shortcut_name = name
  31. else:
  32. conv_name1 = name + "_branch2a"
  33. conv_name2 = name + "_branch2b"
  34. conv_name3 = name + "_branch2c"
  35. shortcut_name = name + "_branch1"
  36. return conv_name1, conv_name2, conv_name3, shortcut_name
  37. def fix_basicblock_name(self, name):
  38. if self.model_type == 'SEResNeXt':
  39. conv_name1 = 'conv' + name + '_x1'
  40. conv_name2 = 'conv' + name + '_x2'
  41. shortcut_name = name
  42. else:
  43. conv_name1 = name + "_branch2a"
  44. conv_name2 = name + "_branch2b"
  45. shortcut_name = name + "_branch1"
  46. return conv_name1, conv_name2, shortcut_name
  47. def fix_layer_warp_name(self, stage_num, count, i):
  48. name = 'res' + str(stage_num)
  49. if count > 10 and stage_num == 4:
  50. if i == 0:
  51. conv_name = name + "a"
  52. else:
  53. conv_name = name + "b" + str(i)
  54. else:
  55. conv_name = name + chr(ord("a") + i)
  56. if self.model_type == 'SEResNeXt':
  57. conv_name = str(stage_num + 2) + '_' + str(i + 1)
  58. return conv_name
  59. def fix_c1_stage_name(self):
  60. return "res_conv1" if self.model_type == 'ResNeXt' else "conv1"