123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121 |
- # --------------------------------------------------------
- # Swin Transformer
- # Copyright (c) 2021 Microsoft
- # Licensed under The MIT License [see LICENSE for details]
- # Written by Ze Liu
- # --------------------------------------------------------
- from .swin_transformer import SwinTransformer
- from .swin_transformer_v2 import SwinTransformerV2
- from .swin_transformer_moe import SwinTransformerMoE
- from .swin_mlp import SwinMLP
- from .simmim import build_simmim
def _resolve_layernorm(config):
    """Return the layer-norm class selected by ``config.FUSED_LAYERNORM``.

    If ``config.FUSED_LAYERNORM`` is truthy and apex can be imported, apex's
    ``FusedLayerNorm`` is returned. If apex is unavailable, a hint is printed
    and ``torch.nn.LayerNorm`` is returned instead. Falling back keeps model
    construction working, rather than handing ``None`` to the model as the
    norm layer.
    """
    if config.FUSED_LAYERNORM:
        try:
            from apex.normalization import FusedLayerNorm
        except ImportError:
            # Catch only the missing-package case; a bare ``except`` would
            # also hide KeyboardInterrupt/SystemExit and unrelated bugs.
            print("To use FusedLayerNorm, please install apex. "
                  "Falling back to torch.nn.LayerNorm.")
        else:
            return FusedLayerNorm

    import torch.nn as nn
    return nn.LayerNorm


def build_model(config, is_pretrain=False):
    """Instantiate the model described by ``config``.

    Args:
        config: Experiment config node. ``config.MODEL.TYPE`` picks the
            architecture, and the matching sub-node supplies its
            hyper-parameters: ``config.MODEL.SWIN``, ``SWINV2``, ``SWIN_MOE``
            or ``SWIN_MLP``. Shared fields are also read:
            ``config.DATA.IMG_SIZE``, ``config.MODEL.NUM_CLASSES``,
            ``config.MODEL.DROP_RATE``, ``config.MODEL.DROP_PATH_RATE`` and
            ``config.TRAIN.USE_CHECKPOINT``.
        is_pretrain: If True, return ``build_simmim(config)`` directly and
            ignore ``config.MODEL.TYPE``.

    Returns:
        The constructed model instance.

    Raises:
        NotImplementedError: If ``config.MODEL.TYPE`` is not one of
            ``'swin'``, ``'swinv2'``, ``'swin_moe'`` or ``'swin_mlp'``.
    """
    # Pre-training builds its own model; nothing below is needed for it.
    if is_pretrain:
        return build_simmim(config)

    model_type = config.MODEL.TYPE

    # Only the 'swin' branch forwards this as ``norm_layer``; the other
    # architectures are built without an explicit norm layer argument.
    layernorm = _resolve_layernorm(config)

    if model_type == 'swin':
        model = SwinTransformer(img_size=config.DATA.IMG_SIZE,
                                patch_size=config.MODEL.SWIN.PATCH_SIZE,
                                in_chans=config.MODEL.SWIN.IN_CHANS,
                                num_classes=config.MODEL.NUM_CLASSES,
                                embed_dim=config.MODEL.SWIN.EMBED_DIM,
                                depths=config.MODEL.SWIN.DEPTHS,
                                num_heads=config.MODEL.SWIN.NUM_HEADS,
                                window_size=config.MODEL.SWIN.WINDOW_SIZE,
                                mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
                                qkv_bias=config.MODEL.SWIN.QKV_BIAS,
                                qk_scale=config.MODEL.SWIN.QK_SCALE,
                                drop_rate=config.MODEL.DROP_RATE,
                                drop_path_rate=config.MODEL.DROP_PATH_RATE,
                                ape=config.MODEL.SWIN.APE,
                                norm_layer=layernorm,
                                patch_norm=config.MODEL.SWIN.PATCH_NORM,
                                use_checkpoint=config.TRAIN.USE_CHECKPOINT,
                                fused_window_process=config.FUSED_WINDOW_PROCESS)
    elif model_type == 'swinv2':
        model = SwinTransformerV2(img_size=config.DATA.IMG_SIZE,
                                  patch_size=config.MODEL.SWINV2.PATCH_SIZE,
                                  in_chans=config.MODEL.SWINV2.IN_CHANS,
                                  num_classes=config.MODEL.NUM_CLASSES,
                                  embed_dim=config.MODEL.SWINV2.EMBED_DIM,
                                  depths=config.MODEL.SWINV2.DEPTHS,
                                  num_heads=config.MODEL.SWINV2.NUM_HEADS,
                                  window_size=config.MODEL.SWINV2.WINDOW_SIZE,
                                  mlp_ratio=config.MODEL.SWINV2.MLP_RATIO,
                                  qkv_bias=config.MODEL.SWINV2.QKV_BIAS,
                                  drop_rate=config.MODEL.DROP_RATE,
                                  drop_path_rate=config.MODEL.DROP_PATH_RATE,
                                  ape=config.MODEL.SWINV2.APE,
                                  patch_norm=config.MODEL.SWINV2.PATCH_NORM,
                                  use_checkpoint=config.TRAIN.USE_CHECKPOINT,
                                  pretrained_window_sizes=config.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES)
    elif model_type == 'swin_moe':
        model = SwinTransformerMoE(img_size=config.DATA.IMG_SIZE,
                                   patch_size=config.MODEL.SWIN_MOE.PATCH_SIZE,
                                   in_chans=config.MODEL.SWIN_MOE.IN_CHANS,
                                   num_classes=config.MODEL.NUM_CLASSES,
                                   embed_dim=config.MODEL.SWIN_MOE.EMBED_DIM,
                                   depths=config.MODEL.SWIN_MOE.DEPTHS,
                                   num_heads=config.MODEL.SWIN_MOE.NUM_HEADS,
                                   window_size=config.MODEL.SWIN_MOE.WINDOW_SIZE,
                                   mlp_ratio=config.MODEL.SWIN_MOE.MLP_RATIO,
                                   qkv_bias=config.MODEL.SWIN_MOE.QKV_BIAS,
                                   qk_scale=config.MODEL.SWIN_MOE.QK_SCALE,
                                   drop_rate=config.MODEL.DROP_RATE,
                                   drop_path_rate=config.MODEL.DROP_PATH_RATE,
                                   ape=config.MODEL.SWIN_MOE.APE,
                                   patch_norm=config.MODEL.SWIN_MOE.PATCH_NORM,
                                   mlp_fc2_bias=config.MODEL.SWIN_MOE.MLP_FC2_BIAS,
                                   init_std=config.MODEL.SWIN_MOE.INIT_STD,
                                   use_checkpoint=config.TRAIN.USE_CHECKPOINT,
                                   pretrained_window_sizes=config.MODEL.SWIN_MOE.PRETRAINED_WINDOW_SIZES,
                                   moe_blocks=config.MODEL.SWIN_MOE.MOE_BLOCKS,
                                   num_local_experts=config.MODEL.SWIN_MOE.NUM_LOCAL_EXPERTS,
                                   top_value=config.MODEL.SWIN_MOE.TOP_VALUE,
                                   capacity_factor=config.MODEL.SWIN_MOE.CAPACITY_FACTOR,
                                   cosine_router=config.MODEL.SWIN_MOE.COSINE_ROUTER,
                                   normalize_gate=config.MODEL.SWIN_MOE.NORMALIZE_GATE,
                                   use_bpr=config.MODEL.SWIN_MOE.USE_BPR,
                                   is_gshard_loss=config.MODEL.SWIN_MOE.IS_GSHARD_LOSS,
                                   gate_noise=config.MODEL.SWIN_MOE.GATE_NOISE,
                                   cosine_router_dim=config.MODEL.SWIN_MOE.COSINE_ROUTER_DIM,
                                   cosine_router_init_t=config.MODEL.SWIN_MOE.COSINE_ROUTER_INIT_T,
                                   moe_drop=config.MODEL.SWIN_MOE.MOE_DROP,
                                   aux_loss_weight=config.MODEL.SWIN_MOE.AUX_LOSS_WEIGHT)
    elif model_type == 'swin_mlp':
        model = SwinMLP(img_size=config.DATA.IMG_SIZE,
                        patch_size=config.MODEL.SWIN_MLP.PATCH_SIZE,
                        in_chans=config.MODEL.SWIN_MLP.IN_CHANS,
                        num_classes=config.MODEL.NUM_CLASSES,
                        embed_dim=config.MODEL.SWIN_MLP.EMBED_DIM,
                        depths=config.MODEL.SWIN_MLP.DEPTHS,
                        num_heads=config.MODEL.SWIN_MLP.NUM_HEADS,
                        window_size=config.MODEL.SWIN_MLP.WINDOW_SIZE,
                        mlp_ratio=config.MODEL.SWIN_MLP.MLP_RATIO,
                        drop_rate=config.MODEL.DROP_RATE,
                        drop_path_rate=config.MODEL.DROP_PATH_RATE,
                        ape=config.MODEL.SWIN_MLP.APE,
                        patch_norm=config.MODEL.SWIN_MLP.PATCH_NORM,
                        use_checkpoint=config.TRAIN.USE_CHECKPOINT)
    else:
        raise NotImplementedError(f"Unknown model: {model_type}")

    return model
|