utils.py

# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import os

import torch
import torch.distributed as dist

# `inf` moved out of torch._six in recent PyTorch releases; support both locations
try:
    from torch._six import inf
except ImportError:
    from torch import inf


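# Resume training state from config.MODEL.RESUME: restore model weights (re-initialising the
# classification head for the 2-class task when the checkpoint carries a 1000-class head) and,
# unless running in EVAL_MODE, the optimizer, LR scheduler, epoch counter and AMP scaler.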
def load_checkpoint(config, model, optimizer, lr_scheduler, loss_scaler, logger):
    logger.info(f"==============> Resuming from {config.MODEL.RESUME}....................")
    if config.MODEL.RESUME.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            config.MODEL.RESUME, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
    # checkpoint['model'].pop('head.weight')
    # checkpoint['model'].pop('head.bias')
    for k, v in checkpoint['model'].items():
        print('=====', k, '======', v.shape)
    # checkpoints trained on ImageNet-1K carry a 1000-way head; replace it with a freshly
    # initialised 2-way head (assumes the model's head input dimension is 768)
    if checkpoint['model']['head.weight'].shape[0] == 1000:
        checkpoint['model']['head.weight'] = torch.nn.Parameter(torch.nn.init.xavier_uniform_(torch.empty(2, 768)))
        checkpoint['model']['head.bias'] = torch.nn.Parameter(torch.randn(2))
        print('===modify head weight and bias===')
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    logger.info(msg)
    max_accuracy = 0.0
    if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()
        if 'scaler' in checkpoint:
            loss_scaler.load_state_dict(checkpoint['scaler'])
        logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
        if 'max_accuracy' in checkpoint:
            max_accuracy = checkpoint['max_accuracy']

    del checkpoint
    torch.cuda.empty_cache()
    return max_accuracy


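# Load pretrained weights from config.MODEL.PRETRAINED for fine-tuning, adapting tensors whose
# shape depends on image resolution or class count (relative position bias tables, absolute
# position embeddings, classifier head).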
def load_pretrained(config, model, logger):
    logger.info(f"==============> Loading weight {config.MODEL.PRETRAINED} for fine-tuning......")
    checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu')
    state_dict = checkpoint['model']

    # delete relative_position_index since we always re-init it
    relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k]
    for k in relative_position_index_keys:
        del state_dict[k]

    # delete relative_coords_table since we always re-init it
    relative_coords_table_keys = [k for k in state_dict.keys() if "relative_coords_table" in k]
    for k in relative_coords_table_keys:
        del state_dict[k]

    # delete attn_mask since we always re-init it
    attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k]
    for k in attn_mask_keys:
        del state_dict[k]

    # bicubic interpolate relative_position_bias_table if it does not match
    relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
    for k in relative_position_bias_table_keys:
        relative_position_bias_table_pretrained = state_dict[k]
        relative_position_bias_table_current = model.state_dict()[k]
        L1, nH1 = relative_position_bias_table_pretrained.size()
        L2, nH2 = relative_position_bias_table_current.size()
        if nH1 != nH2:
            logger.warning(f"Error in loading {k}, passing......")
        else:
            if L1 != L2:
                # the table has (window span)^2 entries per head; resize it as a 2D grid
                S1 = int(L1 ** 0.5)
                S2 = int(L2 ** 0.5)
                relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
                    relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2),
                    mode='bicubic')
                state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)

    # bicubic interpolate absolute_pos_embed if it does not match
    absolute_pos_embed_keys = [k for k in state_dict.keys() if "absolute_pos_embed" in k]
    for k in absolute_pos_embed_keys:
        # dpe
        absolute_pos_embed_pretrained = state_dict[k]
        absolute_pos_embed_current = model.state_dict()[k]
        _, L1, C1 = absolute_pos_embed_pretrained.size()
        _, L2, C2 = absolute_pos_embed_current.size()
        if C1 != C2:
            logger.warning(f"Error in loading {k}, passing......")
        else:
            if L1 != L2:
                S1 = int(L1 ** 0.5)
                S2 = int(L2 ** 0.5)
                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)
                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)
                absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
                    absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')
                absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1)
                absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2)
                state_dict[k] = absolute_pos_embed_pretrained_resized

    # check classifier: if it does not match, re-init the classifier head to zero
    head_bias_pretrained = state_dict['head.bias']
    Nc1 = head_bias_pretrained.shape[0]
    Nc2 = model.head.bias.shape[0]
    if Nc1 != Nc2:
        if Nc1 == 21841 and Nc2 == 1000:
            logger.info("loading ImageNet-22K weight to ImageNet-1K ......")
            map22kto1k_path = 'data/map22kto1k.txt'
            with open(map22kto1k_path) as f:
                map22kto1k = f.readlines()
            map22kto1k = [int(id22k.strip()) for id22k in map22kto1k]
            state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :]
            state_dict['head.bias'] = state_dict['head.bias'][map22kto1k]
        else:
            torch.nn.init.constant_(model.head.bias, 0.)
            torch.nn.init.constant_(model.head.weight, 0.)
            del state_dict['head.weight']
            del state_dict['head.bias']
            logger.warning("Error in loading classifier head, re-init classifier head to 0")

    msg = model.load_state_dict(state_dict, strict=False)
    logger.warning(msg)
    logger.info(f"=> loaded successfully '{config.MODEL.PRETRAINED}'")

    del checkpoint
    torch.cuda.empty_cache()


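# Persist the full training state (model, optimizer, LR scheduler, AMP scaler, best accuracy,
# epoch and config) to config.OUTPUT/ckpt_epoch_{epoch}.pth so training can be resumed later.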
def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, loss_scaler, logger):
    save_state = {'model': model.state_dict(),
                  'optimizer': optimizer.state_dict(),
                  'lr_scheduler': lr_scheduler.state_dict(),
                  'max_accuracy': max_accuracy,
                  'scaler': loss_scaler.state_dict(),
                  'epoch': epoch,
                  'config': config}

    save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
    logger.info(f"{save_path} saving......")
    torch.save(save_state, save_path)
    logger.info(f"{save_path} saved !!!")


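# Compute the total gradient norm of a parameter list (the norm_type-norm of the per-parameter
# gradient norms), returned as a plain Python float.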
def get_grad_norm(parameters, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
    total_norm = total_norm ** (1. / norm_type)
    return total_norm


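# Return the most recently written .pth checkpoint in output_dir, or None if there is none,
# so training can auto-resume from the latest saved state.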
def auto_resume_helper(output_dir):
    checkpoints = os.listdir(output_dir)
    checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
    print(f"All checkpoints found in {output_dir}: {checkpoints}")
    if len(checkpoints) > 0:
        latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)
        print(f"The latest checkpoint found: {latest_checkpoint}")
        resume_file = latest_checkpoint
    else:
        resume_file = None
    return resume_file


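# Average a tensor across all distributed ranks (all-reduce sum divided by the world size).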
def reduce_tensor(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt


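# Gradient-norm helper for the AMP loss scaler below: returns the total gradient norm as a
# tensor, supporting norm_type=inf (max-norm).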
def ampscaler_get_grad_norm(parameters, norm_type: float = 2.0) -> torch.Tensor:
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    if norm_type == inf:
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(),
                                                        norm_type).to(device) for p in parameters]), norm_type)
    return total_norm


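# Wrapper around torch.cuda.amp.GradScaler that performs the scaled backward pass, optional
# gradient clipping and the optimizer step, and returns the gradient norm of the update.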
class NativeScalerWithGradNormCount:
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                self._scaler.unscale_(optimizer)
                norm = ampscaler_get_grad_norm(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)
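
# Minimal usage sketch for the scaler in a mixed-precision training step (illustrative only;
# `criterion`, `samples`, `targets` and `config.TRAIN.CLIP_GRAD` are assumed to come from the
# surrounding training loop):
#
#   loss_scaler = NativeScalerWithGradNormCount()
#   optimizer.zero_grad()
#   with torch.cuda.amp.autocast():
#       loss = criterion(model(samples), targets)
#   grad_norm = loss_scaler(loss, optimizer, clip_grad=config.TRAIN.CLIP_GRAD,
#                           parameters=model.parameters(), update_grad=True)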