build.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. # --------------------------------------------------------
  2. # Swin Transformer
  3. # Copyright (c) 2021 Microsoft
  4. # Licensed under The MIT License [see LICENSE for details]
  5. # Written by Ze Liu
  6. # --------------------------------------------------------
  7. from .swin_transformer import SwinTransformer
  8. from .swin_transformer_v2 import SwinTransformerV2
  9. from .swin_transformer_moe import SwinTransformerMoE
  10. from .swin_mlp import SwinMLP
  11. from .simmim import build_simmim
  12. def build_model(config, is_pretrain=False):
  13. model_type = config.MODEL.TYPE
  14. # accelerate layernorm
  15. if config.FUSED_LAYERNORM:
  16. try:
  17. import apex as amp
  18. layernorm = amp.normalization.FusedLayerNorm
  19. except:
  20. layernorm = None
  21. print("To use FusedLayerNorm, please install apex.")
  22. else:
  23. import torch.nn as nn
  24. layernorm = nn.LayerNorm
  25. if is_pretrain:
  26. model = build_simmim(config)
  27. return model
  28. if model_type == 'swin':
  29. model = SwinTransformer(img_size=config.DATA.IMG_SIZE,
  30. patch_size=config.MODEL.SWIN.PATCH_SIZE,
  31. in_chans=config.MODEL.SWIN.IN_CHANS,
  32. num_classes=config.MODEL.NUM_CLASSES,
  33. embed_dim=config.MODEL.SWIN.EMBED_DIM,
  34. depths=config.MODEL.SWIN.DEPTHS,
  35. num_heads=config.MODEL.SWIN.NUM_HEADS,
  36. window_size=config.MODEL.SWIN.WINDOW_SIZE,
  37. mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
  38. qkv_bias=config.MODEL.SWIN.QKV_BIAS,
  39. qk_scale=config.MODEL.SWIN.QK_SCALE,
  40. drop_rate=config.MODEL.DROP_RATE,
  41. drop_path_rate=config.MODEL.DROP_PATH_RATE,
  42. ape=config.MODEL.SWIN.APE,
  43. norm_layer=layernorm,
  44. patch_norm=config.MODEL.SWIN.PATCH_NORM,
  45. use_checkpoint=config.TRAIN.USE_CHECKPOINT,
  46. fused_window_process=config.FUSED_WINDOW_PROCESS)
  47. elif model_type == 'swinv2':
  48. model = SwinTransformerV2(img_size=config.DATA.IMG_SIZE,
  49. patch_size=config.MODEL.SWINV2.PATCH_SIZE,
  50. in_chans=config.MODEL.SWINV2.IN_CHANS,
  51. num_classes=config.MODEL.NUM_CLASSES,
  52. embed_dim=config.MODEL.SWINV2.EMBED_DIM,
  53. depths=config.MODEL.SWINV2.DEPTHS,
  54. num_heads=config.MODEL.SWINV2.NUM_HEADS,
  55. window_size=config.MODEL.SWINV2.WINDOW_SIZE,
  56. mlp_ratio=config.MODEL.SWINV2.MLP_RATIO,
  57. qkv_bias=config.MODEL.SWINV2.QKV_BIAS,
  58. drop_rate=config.MODEL.DROP_RATE,
  59. drop_path_rate=config.MODEL.DROP_PATH_RATE,
  60. ape=config.MODEL.SWINV2.APE,
  61. patch_norm=config.MODEL.SWINV2.PATCH_NORM,
  62. use_checkpoint=config.TRAIN.USE_CHECKPOINT,
  63. pretrained_window_sizes=config.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES)
  64. elif model_type == 'swin_moe':
  65. model = SwinTransformerMoE(img_size=config.DATA.IMG_SIZE,
  66. patch_size=config.MODEL.SWIN_MOE.PATCH_SIZE,
  67. in_chans=config.MODEL.SWIN_MOE.IN_CHANS,
  68. num_classes=config.MODEL.NUM_CLASSES,
  69. embed_dim=config.MODEL.SWIN_MOE.EMBED_DIM,
  70. depths=config.MODEL.SWIN_MOE.DEPTHS,
  71. num_heads=config.MODEL.SWIN_MOE.NUM_HEADS,
  72. window_size=config.MODEL.SWIN_MOE.WINDOW_SIZE,
  73. mlp_ratio=config.MODEL.SWIN_MOE.MLP_RATIO,
  74. qkv_bias=config.MODEL.SWIN_MOE.QKV_BIAS,
  75. qk_scale=config.MODEL.SWIN_MOE.QK_SCALE,
  76. drop_rate=config.MODEL.DROP_RATE,
  77. drop_path_rate=config.MODEL.DROP_PATH_RATE,
  78. ape=config.MODEL.SWIN_MOE.APE,
  79. patch_norm=config.MODEL.SWIN_MOE.PATCH_NORM,
  80. mlp_fc2_bias=config.MODEL.SWIN_MOE.MLP_FC2_BIAS,
  81. init_std=config.MODEL.SWIN_MOE.INIT_STD,
  82. use_checkpoint=config.TRAIN.USE_CHECKPOINT,
  83. pretrained_window_sizes=config.MODEL.SWIN_MOE.PRETRAINED_WINDOW_SIZES,
  84. moe_blocks=config.MODEL.SWIN_MOE.MOE_BLOCKS,
  85. num_local_experts=config.MODEL.SWIN_MOE.NUM_LOCAL_EXPERTS,
  86. top_value=config.MODEL.SWIN_MOE.TOP_VALUE,
  87. capacity_factor=config.MODEL.SWIN_MOE.CAPACITY_FACTOR,
  88. cosine_router=config.MODEL.SWIN_MOE.COSINE_ROUTER,
  89. normalize_gate=config.MODEL.SWIN_MOE.NORMALIZE_GATE,
  90. use_bpr=config.MODEL.SWIN_MOE.USE_BPR,
  91. is_gshard_loss=config.MODEL.SWIN_MOE.IS_GSHARD_LOSS,
  92. gate_noise=config.MODEL.SWIN_MOE.GATE_NOISE,
  93. cosine_router_dim=config.MODEL.SWIN_MOE.COSINE_ROUTER_DIM,
  94. cosine_router_init_t=config.MODEL.SWIN_MOE.COSINE_ROUTER_INIT_T,
  95. moe_drop=config.MODEL.SWIN_MOE.MOE_DROP,
  96. aux_loss_weight=config.MODEL.SWIN_MOE.AUX_LOSS_WEIGHT)
  97. elif model_type == 'swin_mlp':
  98. model = SwinMLP(img_size=config.DATA.IMG_SIZE,
  99. patch_size=config.MODEL.SWIN_MLP.PATCH_SIZE,
  100. in_chans=config.MODEL.SWIN_MLP.IN_CHANS,
  101. num_classes=config.MODEL.NUM_CLASSES,
  102. embed_dim=config.MODEL.SWIN_MLP.EMBED_DIM,
  103. depths=config.MODEL.SWIN_MLP.DEPTHS,
  104. num_heads=config.MODEL.SWIN_MLP.NUM_HEADS,
  105. window_size=config.MODEL.SWIN_MLP.WINDOW_SIZE,
  106. mlp_ratio=config.MODEL.SWIN_MLP.MLP_RATIO,
  107. drop_rate=config.MODEL.DROP_RATE,
  108. drop_path_rate=config.MODEL.DROP_PATH_RATE,
  109. ape=config.MODEL.SWIN_MLP.APE,
  110. patch_norm=config.MODEL.SWIN_MLP.PATCH_NORM,
  111. use_checkpoint=config.TRAIN.USE_CHECKPOINT)
  112. else:
  113. raise NotImplementedError(f"Unkown model: {model_type}")
  114. return model