1234567891011121314151617181920212223242526272829303132333435363738394041424344 |
- #!/usr/bin/env python3
- # -*- coding:utf-8 -*-
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- from loguru import logger
- import torch
- import os
- import shutil
- def load_ckpt(model, ckpt):
- model_state_dict = model.state_dict()
- load_dict = {}
- for key_model, v in model_state_dict.items():
- if key_model not in ckpt:
- logger.warning(
- "{} is not in the ckpt. Please double check and see if this is desired.".format(
- key_model
- )
- )
- continue
- v_ckpt = ckpt[key_model]
- if v.shape != v_ckpt.shape:
- logger.warning(
- "Shape of {} in checkpoint is {}, while shape of {} in model is {}.".format(
- key_model, v_ckpt.shape, key_model, v.shape
- )
- )
- continue
- load_dict[key_model] = v_ckpt
- model.load_state_dict(load_dict, strict=False)
- return model
- def save_checkpoint(state, is_best, save_dir, model_name=""):
- if not os.path.exists(save_dir):
- os.makedirs(save_dir)
- filename = os.path.join(save_dir, model_name + "_ckpt.pth.tar")
- torch.save(state, filename)
- if is_best:
- best_filename = os.path.join(save_dir, "best_ckpt.pth.tar")
- shutil.copyfile(filename, best_filename)
|