# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import os
import torch
import yaml
from yacs.config import CfgNode as CN

# PyTorch major version (1.x or 2.x)
PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])

_C = CN()

# Base config files
_C.BASE = ['']
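# A config YAML may inherit from parent files via the BASE key; each entry is
# resolved relative to the child file and merged first by
# _update_config_from_file() below, so the child's own values win.
# A minimal sketch (file name in BASE is hypothetical):
#
#   BASE: ['swin_tiny_patch4_window7_224.yaml']
#   DATA:
#     BATCH_SIZE: 64
#   TRAIN:
#     BASE_LR: 1.0e-4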

# -----------------------------------------------------------------------------
# Data settings
# -----------------------------------------------------------------------------
_C.DATA = CN()
# Batch size for a single GPU, could be overwritten by command line argument
_C.DATA.BATCH_SIZE = 128
# Path to dataset, could be overwritten by command line argument
_C.DATA.DATA_PATH = '/data/fengyang/sunwin/code/swin_conda_env/Swin-Transformer/imagenet/'
# Dataset name
_C.DATA.DATASET = 'imagenet'
# Input image size
_C.DATA.IMG_SIZE = 224
# Interpolation to resize image (random, bilinear, bicubic)
_C.DATA.INTERPOLATION = 'bicubic'
# Use zipped dataset instead of folder dataset
# could be overwritten by command line argument
_C.DATA.ZIP_MODE = False
# Cache data in memory, could be overwritten by command line argument
_C.DATA.CACHE_MODE = 'part'
# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU
_C.DATA.PIN_MEMORY = True
# Number of data loading threads
_C.DATA.NUM_WORKERS = 8
# [SimMIM] Mask patch size for MaskGenerator
_C.DATA.MASK_PATCH_SIZE = 32
# [SimMIM] Mask ratio for MaskGenerator
_C.DATA.MASK_RATIO = 0.6

# -----------------------------------------------------------------------------
# Model settings
# -----------------------------------------------------------------------------
_C.MODEL = CN()
# Model type
_C.MODEL.TYPE = 'swin'
# Model name
_C.MODEL.NAME = 'swin_tiny_patch4_window7_224'
# Pretrained weight from checkpoint, could be imagenet22k pretrained weight
# could be overwritten by command line argument
_C.MODEL.PRETRAINED = ''
# Checkpoint to resume, could be overwritten by command line argument
_C.MODEL.RESUME = '/data/fengyang/sunwin/code/swin_conda_env/Swin-Transformer/swin_tiny_patch4_window7_224.pth'
# Number of classes, overwritten in data preparation
_C.MODEL.NUM_CLASSES = 2
# Dropout rate
_C.MODEL.DROP_RATE = 0.0
# Drop path rate
_C.MODEL.DROP_PATH_RATE = 0.1
# Label Smoothing
_C.MODEL.LABEL_SMOOTHING = 0.1

# Swin Transformer parameters
_C.MODEL.SWIN = CN()
_C.MODEL.SWIN.PATCH_SIZE = 4
_C.MODEL.SWIN.IN_CHANS = 3
_C.MODEL.SWIN.EMBED_DIM = 96
_C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWIN.WINDOW_SIZE = 7
_C.MODEL.SWIN.MLP_RATIO = 4.
_C.MODEL.SWIN.QKV_BIAS = True
_C.MODEL.SWIN.QK_SCALE = None
_C.MODEL.SWIN.APE = False
_C.MODEL.SWIN.PATCH_NORM = True

# Swin Transformer V2 parameters
_C.MODEL.SWINV2 = CN()
_C.MODEL.SWINV2.PATCH_SIZE = 4
_C.MODEL.SWINV2.IN_CHANS = 3
_C.MODEL.SWINV2.EMBED_DIM = 96
_C.MODEL.SWINV2.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWINV2.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWINV2.WINDOW_SIZE = 7
_C.MODEL.SWINV2.MLP_RATIO = 4.
_C.MODEL.SWINV2.QKV_BIAS = True
_C.MODEL.SWINV2.APE = False
_C.MODEL.SWINV2.PATCH_NORM = True
_C.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES = [0, 0, 0, 0]

# Swin Transformer MoE parameters
_C.MODEL.SWIN_MOE = CN()
_C.MODEL.SWIN_MOE.PATCH_SIZE = 4
_C.MODEL.SWIN_MOE.IN_CHANS = 3
_C.MODEL.SWIN_MOE.EMBED_DIM = 96
_C.MODEL.SWIN_MOE.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWIN_MOE.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWIN_MOE.WINDOW_SIZE = 7
_C.MODEL.SWIN_MOE.MLP_RATIO = 4.
_C.MODEL.SWIN_MOE.QKV_BIAS = True
_C.MODEL.SWIN_MOE.QK_SCALE = None
_C.MODEL.SWIN_MOE.APE = False
_C.MODEL.SWIN_MOE.PATCH_NORM = True
_C.MODEL.SWIN_MOE.MLP_FC2_BIAS = True
_C.MODEL.SWIN_MOE.INIT_STD = 0.02
_C.MODEL.SWIN_MOE.PRETRAINED_WINDOW_SIZES = [0, 0, 0, 0]
_C.MODEL.SWIN_MOE.MOE_BLOCKS = [[-1], [-1], [-1], [-1]]
_C.MODEL.SWIN_MOE.NUM_LOCAL_EXPERTS = 1
_C.MODEL.SWIN_MOE.TOP_VALUE = 1
_C.MODEL.SWIN_MOE.CAPACITY_FACTOR = 1.25
_C.MODEL.SWIN_MOE.COSINE_ROUTER = False
_C.MODEL.SWIN_MOE.NORMALIZE_GATE = False
_C.MODEL.SWIN_MOE.USE_BPR = True
_C.MODEL.SWIN_MOE.IS_GSHARD_LOSS = False
_C.MODEL.SWIN_MOE.GATE_NOISE = 1.0
_C.MODEL.SWIN_MOE.COSINE_ROUTER_DIM = 256
_C.MODEL.SWIN_MOE.COSINE_ROUTER_INIT_T = 0.5
_C.MODEL.SWIN_MOE.MOE_DROP = 0.0
_C.MODEL.SWIN_MOE.AUX_LOSS_WEIGHT = 0.01

# Swin MLP parameters
_C.MODEL.SWIN_MLP = CN()
_C.MODEL.SWIN_MLP.PATCH_SIZE = 4
_C.MODEL.SWIN_MLP.IN_CHANS = 3
_C.MODEL.SWIN_MLP.EMBED_DIM = 96
_C.MODEL.SWIN_MLP.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWIN_MLP.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWIN_MLP.WINDOW_SIZE = 7
_C.MODEL.SWIN_MLP.MLP_RATIO = 4.
_C.MODEL.SWIN_MLP.APE = False
_C.MODEL.SWIN_MLP.PATCH_NORM = True

# [SimMIM] Norm target during training
_C.MODEL.SIMMIM = CN()
_C.MODEL.SIMMIM.NORM_TARGET = CN()
_C.MODEL.SIMMIM.NORM_TARGET.ENABLE = False
_C.MODEL.SIMMIM.NORM_TARGET.PATCH_SIZE = 47

# -----------------------------------------------------------------------------
# Training settings
# -----------------------------------------------------------------------------
_C.TRAIN = CN()
_C.TRAIN.START_EPOCH = 0
_C.TRAIN.EPOCHS = 300
_C.TRAIN.WARMUP_EPOCHS = 20
_C.TRAIN.WEIGHT_DECAY = 0.05
_C.TRAIN.BASE_LR = 5e-4
_C.TRAIN.WARMUP_LR = 5e-7
_C.TRAIN.MIN_LR = 5e-6
# Clip gradient norm
_C.TRAIN.CLIP_GRAD = 5.0
# Auto resume from latest checkpoint
_C.TRAIN.AUTO_RESUME = True
# Gradient accumulation steps
# could be overwritten by command line argument
_C.TRAIN.ACCUMULATION_STEPS = 1
# Whether to use gradient checkpointing to save memory
# could be overwritten by command line argument
_C.TRAIN.USE_CHECKPOINT = False

# LR scheduler
_C.TRAIN.LR_SCHEDULER = CN()
_C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
# Epoch interval to decay LR, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
# LR decay rate, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
# warmup_prefix used in CosineLRScheduler
_C.TRAIN.LR_SCHEDULER.WARMUP_PREFIX = True
# [SimMIM] Gamma / Multi steps value, used in MultiStepLRScheduler
_C.TRAIN.LR_SCHEDULER.GAMMA = 0.1
_C.TRAIN.LR_SCHEDULER.MULTISTEPS = []

# Optimizer
_C.TRAIN.OPTIMIZER = CN()
_C.TRAIN.OPTIMIZER.NAME = 'adamw'
# Optimizer Epsilon
_C.TRAIN.OPTIMIZER.EPS = 1e-8
# Optimizer Betas
_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
# SGD momentum
_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9

# [SimMIM] Layer decay for fine-tuning
_C.TRAIN.LAYER_DECAY = 1.0

# MoE
_C.TRAIN.MOE = CN()
# Only save model on master device
_C.TRAIN.MOE.SAVE_MASTER = False

# -----------------------------------------------------------------------------
# Augmentation settings
# -----------------------------------------------------------------------------
_C.AUG = CN()
# Color jitter factor
_C.AUG.COLOR_JITTER = 0.4
# AutoAugment / RandAugment policy (timm syntax): 'v0', 'original',
# or a RandAugment spec such as the default below
_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
# Random erase prob
_C.AUG.REPROB = 0.25
# Random erase mode
_C.AUG.REMODE = 'pixel'
# Random erase count
_C.AUG.RECOUNT = 1
# Mixup alpha, mixup enabled if > 0
_C.AUG.MIXUP = 0.8
# Cutmix alpha, cutmix enabled if > 0
_C.AUG.CUTMIX = 1.0
# Cutmix min/max ratio, overrides alpha and enables cutmix if set
_C.AUG.CUTMIX_MINMAX = None
# Probability of performing mixup or cutmix when either/both is enabled
_C.AUG.MIXUP_PROB = 1.0
# Probability of switching to cutmix when both mixup and cutmix enabled
_C.AUG.MIXUP_SWITCH_PROB = 0.5
# How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
_C.AUG.MIXUP_MODE = 'batch'

# -----------------------------------------------------------------------------
# Testing settings
# -----------------------------------------------------------------------------
_C.TEST = CN()
# Whether to use center crop when testing
_C.TEST.CROP = True
# Whether to use SequentialSampler as validation sampler
_C.TEST.SEQUENTIAL = False
# Whether to shuffle data at test time
_C.TEST.SHUFFLE = False

# -----------------------------------------------------------------------------
# Misc
# -----------------------------------------------------------------------------
# [SimMIM] Whether to enable PyTorch amp, overwritten by command line argument
_C.ENABLE_AMP = False
# Enable PyTorch automatic mixed precision (amp)
_C.AMP_ENABLE = True
# [Deprecated] Mixed precision opt level of apex; if 'O0', no apex amp is used ('O0', 'O1', 'O2')
_C.AMP_OPT_LEVEL = ''
# Path to output folder, overwritten by command line argument
_C.OUTPUT = ''
# Tag of experiment, overwritten by command line argument
_C.TAG = 'default'
# Frequency to save checkpoint (in epochs)
_C.SAVE_FREQ = 1
# Frequency to log info (in iterations)
_C.PRINT_FREQ = 10
# Fixed random seed
_C.SEED = 0
# Perform evaluation only, overwritten by command line argument
_C.EVAL_MODE = False
# Test throughput only, overwritten by command line argument
_C.THROUGHPUT_MODE = False
# local rank for DistributedDataParallel, given by command line argument
_C.LOCAL_RANK = 0
# for acceleration
_C.FUSED_WINDOW_PROCESS = False
_C.FUSED_LAYERNORM = False
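
# -----------------------------------------------------------------------------
# Config loading helpers
# -----------------------------------------------------------------------------
# The defaults above are resolved in this order, with later stages winning:
#   1. parent files listed under BASE in the YAML, merged recursively
#   2. the YAML file passed via --cfg
#   3. 'KEY VALUE' pairs passed via --opts
#   4. specific command-line arguments handled in update_config()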
def _update_config_from_file(config, cfg_file):
    config.defrost()
    with open(cfg_file, 'r') as f:
        yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)

    # Merge any parent configs listed under BASE first (recursively),
    # so that values from the current file take precedence.
    for cfg in yaml_cfg.setdefault('BASE', ['']):
        if cfg:
            _update_config_from_file(
                config, os.path.join(os.path.dirname(cfg_file), cfg)
            )
    print('=> merge config from {}'.format(cfg_file))
    config.merge_from_file(cfg_file)
    config.freeze()

def update_config(config, args):
    _update_config_from_file(config, args.cfg)

    config.defrost()
    if args.opts:
        config.merge_from_list(args.opts)
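    # merge_from_list above expects alternating 'KEY VALUE' tokens, e.g.
    #   --opts TRAIN.BASE_LR 1e-4 DATA.BATCH_SIZE 64
    # (illustrative values; any key defined on _C is accepted)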

    def _check_args(name):
        # True only if the argument exists on `args` and is truthy
        return hasattr(args, name) and bool(getattr(args, name))

    # merge from specific arguments
    if _check_args('batch_size'):
        config.DATA.BATCH_SIZE = args.batch_size
    if _check_args('data_path'):
        config.DATA.DATA_PATH = args.data_path
    if _check_args('zip'):
        config.DATA.ZIP_MODE = True
    if _check_args('cache_mode'):
        config.DATA.CACHE_MODE = args.cache_mode
    if _check_args('pretrained'):
        config.MODEL.PRETRAINED = args.pretrained
    if _check_args('resume'):
        config.MODEL.RESUME = args.resume
    if _check_args('accumulation_steps'):
        config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
    if _check_args('use_checkpoint'):
        config.TRAIN.USE_CHECKPOINT = True
    if _check_args('amp_opt_level'):
        print("[warning] Apex amp has been deprecated, please use pytorch amp instead!")
        if args.amp_opt_level == 'O0':
            config.AMP_ENABLE = False
    if _check_args('disable_amp'):
        config.AMP_ENABLE = False
    if _check_args('output'):
        config.OUTPUT = args.output
    if _check_args('tag'):
        config.TAG = args.tag
    if _check_args('eval'):
        config.EVAL_MODE = True
    if _check_args('throughput'):
        config.THROUGHPUT_MODE = True
    # [SimMIM]
    if _check_args('enable_amp'):
        config.ENABLE_AMP = args.enable_amp
    # for acceleration
    if _check_args('fused_window_process'):
        config.FUSED_WINDOW_PROCESS = True
    if _check_args('fused_layernorm'):
        config.FUSED_LAYERNORM = True
    # Overwrite optimizer if not None, currently we use it for [fused_adam, fused_lamb]
    if _check_args('optim'):
        config.TRAIN.OPTIMIZER.NAME = args.optim

    # set local rank for distributed training
    if PYTORCH_MAJOR_VERSION == 1:
        config.LOCAL_RANK = args.local_rank
    else:
        config.LOCAL_RANK = int(os.environ['LOCAL_RANK'])
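        # Under torch 2.x the launcher is expected to set LOCAL_RANK, e.g.
        # (illustrative command; script and config names as in this repo):
        #   torchrun --nproc_per_node 8 main.py --cfg configs/swin/swin_tiny_patch4_window7_224.yaml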

    # output folder
    config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)

    config.freeze()

def get_config(args):
    """Get a yacs CfgNode object with default values."""
    # Return a clone so that the defaults will not be altered
    # This is for the "local variable" use pattern
    config = _C.clone()
    update_config(config, args)
    return config
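
# Minimal usage sketch: build an argparse namespace carrying the attributes
# update_config() reads, then resolve a full config. The flags below are an
# illustrative subset (the repo's main.py defines the authoritative parser).
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser('config smoke test')
    parser.add_argument('--cfg', type=str, required=True, metavar='FILE', help='path to config file')
    parser.add_argument('--batch-size', type=int, help='batch size for a single GPU')
    parser.add_argument('--output', type=str, default='output', help='root of output folder')
    parser.add_argument('--tag', type=str, default='default', help='tag of experiment')
    parser.add_argument('--local_rank', type=int, default=0,
                        help='local rank for DistributedDataParallel (torch 1.x launchers)')
    parser.add_argument('opts', nargs=argparse.REMAINDER,
                        help="extra 'KEY VALUE' pairs to merge, e.g. TRAIN.EPOCHS 30")
    args = parser.parse_args()

    # Under torch 2.x, update_config() reads LOCAL_RANK from the environment;
    # default it here so the sketch also runs outside torchrun.
    os.environ.setdefault('LOCAL_RANK', '0')

    cfg = get_config(args)
    print(cfg.dump())  # yacs renders the resolved config as YAML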