utils_moe.py

# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import os

import torch
import torch.distributed as dist

# Split a full model state dict into the MoE (expert) entries and the remaining
# shared (non-MoE) entries, keyed by the parameter names listed in `moe_keys`.
def split_moe_model_state_dict(moe_keys, model_state_dict):
    moe_model_state_dict = {}
    non_moe_model_state_dict = {}
    for k, v in model_state_dict.items():
        if k in moe_keys:
            moe_model_state_dict[k] = v
        else:
            non_moe_model_state_dict[k] = v
    return moe_model_state_dict, non_moe_model_state_dict

# Recombine the expert and shared halves into a single state dict.
def merge_moe_model_state_dict(moe_model_state_dict, non_moe_model_state_dict):
    model_state_dict = {}
    model_state_dict.update(moe_model_state_dict)
    model_state_dict.update(non_moe_model_state_dict)
    return model_state_dict

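# Usage sketch (illustrative only, not part of the original code; the parameter name
# below is hypothetical): split and merge are inverses once the MoE key set is known.
#
#   moe_keys = {'layers.0.mlp._moe_layer.experts.weight'}
#   sd = model.state_dict()
#   moe_sd, dense_sd = split_moe_model_state_dict(moe_keys, sd)
#   assert merge_moe_model_state_dict(moe_sd, dense_sd).keys() == sd.keys()
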
def load_checkpoint(config, model, optimizer, lr_scheduler, loss_scaler, logger):
    global_rank = dist.get_rank()
    logger.info(f"==============> Rank[{global_rank}] Resuming from {config.MODEL.RESUME}....................")
    if config.MODEL.RESUME.endswith('.pth'):
        if config.TRAIN.MOE.SAVE_MASTER:
            resume_path = config.MODEL.RESUME + '.global'
        else:
            resume_path = config.MODEL.RESUME + f'.rank{global_rank}'
        logger.info(f"===> Rank[{global_rank}] Re-formatting checkpoint name to {resume_path}......")
    else:
        resume_path = config.MODEL.RESUME
    checkpoint = torch.load(resume_path, map_location='cpu')
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    logger.info(msg)
    max_accuracy = 0.0
    if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()
        if 'scaler' in checkpoint:
            loss_scaler.load_state_dict(checkpoint['scaler'])
        logger.info(f"=> Rank[{global_rank}] loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
        if 'max_accuracy' in checkpoint:
            max_accuracy = checkpoint['max_accuracy']
    del checkpoint
    torch.cuda.empty_cache()
    return max_accuracy

def load_pretrained(config, model, logger):
    global_rank = dist.get_rank()
    logger.info(f"==============> Rank[{global_rank}] Loading weight {config.MODEL.PRETRAINED} for fine-tuning......")
    if config.MODEL.PRETRAINED.endswith('.pth'):
        if config.TRAIN.MOE.SAVE_MASTER:
            pretrained_path = config.MODEL.PRETRAINED + '.global'
        else:
            pretrained_path = config.MODEL.PRETRAINED + f'.rank{global_rank}'
        logger.info(f"===> Rank[{global_rank}] Re-formatting checkpoint name to {pretrained_path}......")
    else:
        pretrained_path = config.MODEL.PRETRAINED
    if pretrained_path.endswith(f'.rank{global_rank}'):
        checkpoint = torch.load(pretrained_path, map_location='cpu')
        if os.path.exists(pretrained_path.replace(f'.rank{global_rank}', '.master')):
            checkpoint_master = torch.load(pretrained_path.replace(f'.rank{global_rank}', '.master'),
                                           map_location='cpu')
            state_dict = merge_moe_model_state_dict(checkpoint['model'], checkpoint_master['model'])
        else:
            state_dict = checkpoint['model']
    elif pretrained_path.endswith('.pth.global'):
        checkpoint = torch.load(pretrained_path, map_location='cpu')
        state_dict = checkpoint['model']
    else:
        raise NotImplementedError(f"{config.MODEL.PRETRAINED} file error...")
    # delete relative_position_index since we always re-init it
    relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k]
    for k in relative_position_index_keys:
        del state_dict[k]
    # delete relative_coords_table since we always re-init it
    relative_coords_table_keys = [k for k in state_dict.keys() if "relative_coords_table" in k]
    for k in relative_coords_table_keys:
        del state_dict[k]
    # delete attn_mask since we always re-init it
    attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k]
    for k in attn_mask_keys:
        del state_dict[k]
    # bicubic interpolate relative_position_bias_table if it does not match
    relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
    for k in relative_position_bias_table_keys:
        relative_position_bias_table_pretrained = state_dict[k]
        relative_position_bias_table_current = model.state_dict()[k]
        L1, nH1 = relative_position_bias_table_pretrained.size()
        L2, nH2 = relative_position_bias_table_current.size()
        if nH1 != nH2:
            logger.warning(f"Error in loading {k}, passing......")
        else:
            if L1 != L2:
                # interpolate on the 2D grid of relative offsets
                S1 = int(L1 ** 0.5)
                S2 = int(L2 ** 0.5)
                relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
                    relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2),
                    mode='bicubic')
                state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)
    # bicubic interpolate absolute_pos_embed if it does not match
    absolute_pos_embed_keys = [k for k in state_dict.keys() if "absolute_pos_embed" in k]
    for k in absolute_pos_embed_keys:
        # dpe
        absolute_pos_embed_pretrained = state_dict[k]
        absolute_pos_embed_current = model.state_dict()[k]
        _, L1, C1 = absolute_pos_embed_pretrained.size()
        _, L2, C2 = absolute_pos_embed_current.size()
        if C1 != C2:
            logger.warning(f"Error in loading {k}, passing......")
        else:
            if L1 != L2:
                S1 = int(L1 ** 0.5)
                S2 = int(L2 ** 0.5)
                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)
                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)
                absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
                    absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')
                absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1)
                absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2)
                state_dict[k] = absolute_pos_embed_pretrained_resized
    # check classifier; if it does not match, re-init the classifier head to zero
    head_bias_pretrained = state_dict['head.bias']
    Nc1 = head_bias_pretrained.shape[0]
    Nc2 = model.head.bias.shape[0]
    if Nc1 != Nc2:
        if Nc1 == 21841 and Nc2 == 1000:
            logger.info("loading ImageNet-22K weight to ImageNet-1K ......")
            map22kto1k_path = 'data/map22kto1k.txt'
            with open(map22kto1k_path) as f:
                map22kto1k = f.readlines()
            map22kto1k = [int(id22k.strip()) for id22k in map22kto1k]
            state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :]
            state_dict['head.bias'] = state_dict['head.bias'][map22kto1k]
        else:
            torch.nn.init.constant_(model.head.bias, 0.)
            torch.nn.init.constant_(model.head.weight, 0.)
            del state_dict['head.weight']
            del state_dict['head.bias']
            logger.warning("Error in loading classifier head, re-init classifier head to 0")
    msg = model.load_state_dict(state_dict, strict=False)
    logger.warning(msg)
    logger.info(f"=> loaded successfully '{config.MODEL.PRETRAINED}'")
    del checkpoint
    torch.cuda.empty_cache()

def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, loss_scaler, logger,
                    zero_redundancy=False):
    global_rank = dist.get_rank()
    if zero_redundancy:
        if config.TRAIN.MOE.SAVE_MASTER:
            save_state = {'model': model.state_dict()}
            if global_rank == 0:
                save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth.global')
                logger.info(f"{save_path} saving......")
                torch.save(save_state, save_path)
                logger.info(f"{save_path} saved !!!")
        else:
            moe_model_state_dict, non_moe_model_state_dict = \
                split_moe_model_state_dict(model._ddp_params_and_buffers_to_ignore, model.state_dict())
            save_state = {'model': moe_model_state_dict}
            save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth.rank{global_rank}')
            logger.info(f"{save_path} saving......")
            torch.save(save_state, save_path)
            logger.info(f"{save_path} saved !!!")
            if global_rank == 0:
                save_state_master = {'model': non_moe_model_state_dict}
                save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth.master')
                logger.info(f"{save_path} saving......")
                torch.save(save_state_master, save_path)
                logger.info(f"{save_path} saved !!!")
    else:
        save_state = {'model': model.state_dict(),
                      'optimizer': optimizer.state_dict(),
                      'lr_scheduler': lr_scheduler.state_dict(),
                      'max_accuracy': max_accuracy,
                      'scaler': loss_scaler.state_dict(),
                      'epoch': epoch,
                      'config': config}
        if config.TRAIN.MOE.SAVE_MASTER:
            if global_rank == 0:
                save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth.global')
                logger.info(f"{save_path} saving......")
                torch.save(save_state, save_path)
                logger.info(f"{save_path} saved !!!")
        else:
            save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth.rank{global_rank}')
            logger.info(f"{save_path} saving......")
            torch.save(save_state, save_path)
            logger.info(f"{save_path} saved !!!")

def auto_resume_helper(output_dir, save_master=False):
    global_rank = dist.get_rank()
    checkpoints = os.listdir(output_dir)
    if not save_master:
        master_checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth.rank0')]
    else:
        master_checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth.global')]
    print(f"All master checkpoints found in {output_dir}: {master_checkpoints}")
    if len(master_checkpoints) > 0:
        latest_master_checkpoint = max([os.path.join(output_dir, d) for d in master_checkpoints],
                                       key=os.path.getmtime)
        latest_checkpoint = latest_master_checkpoint.replace('pth.rank0', f'pth.rank{global_rank}')
        print(f"The latest checkpoint found: {latest_checkpoint}")
        resume_file = latest_checkpoint
    else:
        resume_file = None
    return resume_file

def hook_scale_grad(scale, tensor):
    # Gradient hook: rescale the incoming gradient by 1/scale.
    return tensor / scale
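
# Usage sketch (assumption about the calling code, which is not shown in this file):
# expert parameters are excluded from DDP all-reduce (see
# `_ddp_params_and_buffers_to_ignore` above), so their gradients can be rescaled with
# this hook to stay consistent with the densely all-reduced parameters. A hypothetical
# registration:
#
#   import functools
#   for name, param in model.named_parameters():
#       if name in getattr(model, '_ddp_params_and_buffers_to_ignore', []):
#           param.register_hook(functools.partial(hook_scale_grad, scale))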