utils_simmim.py

# --------------------------------------------------------
# SimMIM
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# Modified by Zhenda Xie
# --------------------------------------------------------
import os

import torch
import torch.distributed as dist
import numpy as np
from scipy import interpolate


def load_checkpoint(config, model, optimizer, lr_scheduler, scaler, logger):
    logger.info(f">>>>>>>>>> Resuming from {config.MODEL.RESUME} ..........")
    if config.MODEL.RESUME.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            config.MODEL.RESUME, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')

    # re-map keys due to name change (only for loading provided models)
    rpe_mlp_keys = [k for k in checkpoint['model'].keys() if "rpe_mlp" in k]
    for k in rpe_mlp_keys:
        checkpoint['model'][k.replace('rpe_mlp', 'cpb_mlp')] = checkpoint['model'].pop(k)

    msg = model.load_state_dict(checkpoint['model'], strict=False)
    logger.info(msg)

    max_accuracy = 0.0
    if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'scaler' in checkpoint and 'epoch' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        scaler.load_state_dict(checkpoint['scaler'])
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()
        logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
        if 'max_accuracy' in checkpoint:
            max_accuracy = checkpoint['max_accuracy']
        else:
            max_accuracy = 0.0

    del checkpoint
    torch.cuda.empty_cache()
    return max_accuracy


def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, scaler, logger):
    save_state = {'model': model.state_dict(),
                  'optimizer': optimizer.state_dict(),
                  'lr_scheduler': lr_scheduler.state_dict(),
                  'scaler': scaler.state_dict(),
                  'max_accuracy': max_accuracy,
                  'epoch': epoch,
                  'config': config}

    save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
    logger.info(f"{save_path} saving......")
    torch.save(save_state, save_path)
    logger.info(f"{save_path} saved !!!")


def get_grad_norm(parameters, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
    total_norm = total_norm ** (1. / norm_type)
    return total_norm
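
# A minimal usage sketch (hypothetical training-loop names `model` and `loss`; the real
# call sites live in the SimMIM training scripts, not in this file): the norm is computed
# from the gradients already populated by backward(), so call it after loss.backward().
#
#     loss.backward()
#     grad_norm = get_grad_norm(model.parameters())
#     logger.info(f"grad_norm = {grad_norm:.4f}")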


def auto_resume_helper(output_dir, logger):
    checkpoints = os.listdir(output_dir)
    checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
    logger.info(f"All checkpoints found in {output_dir}: {checkpoints}")
    if len(checkpoints) > 0:
        latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)
        logger.info(f"The latest checkpoint found: {latest_checkpoint}")
        resume_file = latest_checkpoint
    else:
        resume_file = None
    return resume_file


def reduce_tensor(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt
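
# A minimal usage sketch (assumes torch.distributed is already initialized and `loss` is a
# scalar tensor on the current rank): sum the value across all processes and divide by the
# world size, so every rank logs the same averaged metric.
#
#     loss_avg = reduce_tensor(loss)
#     logger.info(f"loss (averaged over ranks) = {loss_avg.item():.4f}")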


def load_pretrained(config, model, logger):
    logger.info(f">>>>>>>>>> Fine-tuned from {config.MODEL.PRETRAINED} ..........")
    checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu')
    checkpoint_model = checkpoint['model']

    if any('encoder.' in k for k in checkpoint_model.keys()):
        checkpoint_model = {k.replace('encoder.', ''): v for k, v in checkpoint_model.items() if k.startswith('encoder.')}
        logger.info('Detected pre-trained model, removing the [encoder.] prefix.')
    else:
        logger.info('Detected non-pre-trained model, passing it through unchanged.')

    if config.MODEL.TYPE in ['swin', 'swinv2']:
        logger.info(f">>>>>>>>>> Remapping pre-trained keys for SWIN ..........")
        checkpoint_model = remap_pretrained_keys_swin(model, checkpoint_model, logger)
    else:
        raise NotImplementedError

    msg = model.load_state_dict(checkpoint_model, strict=False)
    logger.info(msg)

    del checkpoint
    torch.cuda.empty_cache()
    logger.info(f">>>>>>>>>> loaded successfully '{config.MODEL.PRETRAINED}'")
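
# Expected input layout (an illustrative sketch; key names are examples, not an exhaustive
# list): a SimMIM pre-training checkpoint stores the backbone weights under an `encoder.`
# prefix, which load_pretrained() strips so the keys match the bare Swin backbone used for
# fine-tuning, e.g.
#
#     checkpoint['model'] = {
#         'encoder.patch_embed.proj.weight': ...,
#         'encoder.layers.0.blocks.0.attn.relative_position_bias_table': ...,
#         ...
#     }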


def remap_pretrained_keys_swin(model, checkpoint_model, logger):
    state_dict = model.state_dict()

    # Geometric interpolation when the pre-trained window size differs from the fine-tuned window size
    all_keys = list(checkpoint_model.keys())
    for key in all_keys:
        if "relative_position_bias_table" in key:
            relative_position_bias_table_pretrained = checkpoint_model[key]
            relative_position_bias_table_current = state_dict[key]
            L1, nH1 = relative_position_bias_table_pretrained.size()
            L2, nH2 = relative_position_bias_table_current.size()
            if nH1 != nH2:
                logger.info(f"Error in loading {key}, passing......")
            else:
                if L1 != L2:
                    logger.info(f"{key}: Interpolate relative_position_bias_table using geo.")
                    src_size = int(L1 ** 0.5)
                    dst_size = int(L2 ** 0.5)

                    def geometric_progression(a, r, n):
                        return a * (1.0 - r ** n) / (1.0 - r)

                    # Bisection search for the common ratio q whose geometric progression of
                    # source coordinates just covers the target half-size.
                    left, right = 1.01, 1.5
                    while right - left > 1e-6:
                        q = (left + right) / 2.0
                        gp = geometric_progression(1, q, src_size // 2)
                        if gp > dst_size // 2:
                            right = q
                        else:
                            left = q

                    # if q > 1.090307:
                    #     q = 1.090307

                    dis = []
                    cur = 1
                    for i in range(src_size // 2):
                        dis.append(cur)
                        cur += q ** (i + 1)

                    r_ids = [-_ for _ in reversed(dis)]

                    x = r_ids + [0] + dis
                    y = r_ids + [0] + dis

                    t = dst_size // 2.0
                    dx = np.arange(-t, t + 0.1, 1.0)
                    dy = np.arange(-t, t + 0.1, 1.0)

                    logger.info("Original positions = %s" % str(x))
                    logger.info("Target positions = %s" % str(dx))

                    all_rel_pos_bias = []

                    for i in range(nH1):
                        z = relative_position_bias_table_pretrained[:, i].view(src_size, src_size).float().numpy()
                        # NOTE: scipy.interpolate.interp2d was deprecated in SciPy 1.10 and removed in
                        # 1.14, so this code assumes an older SciPy release.
                        f_cubic = interpolate.interp2d(x, y, z, kind='cubic')
                        all_rel_pos_bias.append(torch.Tensor(f_cubic(dx, dy)).contiguous().view(-1, 1).to(
                            relative_position_bias_table_pretrained.device))

                    new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
                    checkpoint_model[key] = new_rel_pos_bias

    # delete relative_position_index since we always re-init it
    relative_position_index_keys = [k for k in checkpoint_model.keys() if "relative_position_index" in k]
    for k in relative_position_index_keys:
        del checkpoint_model[k]

    # delete relative_coords_table since we always re-init it
    relative_coords_table_keys = [k for k in checkpoint_model.keys() if "relative_coords_table" in k]
    for k in relative_coords_table_keys:
        del checkpoint_model[k]

    # re-map keys due to name change
    rpe_mlp_keys = [k for k in checkpoint_model.keys() if "rpe_mlp" in k]
    for k in rpe_mlp_keys:
        checkpoint_model[k.replace('rpe_mlp', 'cpb_mlp')] = checkpoint_model.pop(k)

    # delete attn_mask since we always re-init it
    attn_mask_keys = [k for k in checkpoint_model.keys() if "attn_mask" in k]
    for k in attn_mask_keys:
        del checkpoint_model[k]

    return checkpoint_model
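
# A minimal wiring sketch (hypothetical fragment of a fine-tuning entry point; the actual
# driver lives in the SimMIM training scripts, not in this file). It only illustrates how
# the helpers above are intended to be combined:
#
#     if config.TRAIN.AUTO_RESUME:
#         resume_file = auto_resume_helper(config.OUTPUT, logger)
#         if resume_file:
#             config.defrost()
#             config.MODEL.RESUME = resume_file
#             config.freeze()
#
#     if config.MODEL.RESUME:
#         max_accuracy = load_checkpoint(config, model, optimizer, lr_scheduler, scaler, logger)
#     elif config.MODEL.PRETRAINED:
#         load_pretrained(config, model, logger)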