# compare.py (1.5 KB)

  """Check that saved model and master parameters agree across two ranks.

  Reads four checkpoint files from the current working directory
  ("rank0model.pth", "rank1model.pth", "rank0master.pth", "rank1master.pth"),
  moves everything onto CUDA device 0, and asserts that:

  * model parameters match between rank 0 and rank 1,
  * master parameters match between rank 0 and rank 1,
  * rank 0's model parameters match its master parameters cast to half.

  Raises AssertionError on the first mismatch; prints an OK line otherwise.
  """
  import torch


  def _to_gpu0(storage, loc):
      """Remap a deserialized storage onto CUDA device 0 (``loc`` is ignored)."""
      return storage.cuda(0)


  # Load every checkpoint onto the same device so the tensors can be compared
  # directly, regardless of which GPU each rank saved them from.
  model_params_rank0 = torch.load("rank0model.pth", map_location=_to_gpu0)
  model_params_rank1 = torch.load("rank1model.pth", map_location=_to_gpu0)
  master_params_rank0 = torch.load("rank0master.pth", map_location=_to_gpu0)
  master_params_rank1 = torch.load("rank1master.pth", map_location=_to_gpu0)

  # zip() stops at its shortest argument, so a rank that saved fewer parameters
  # would otherwise have its missing entries silently skipped and still report OK.
  # NOTE(review): assumes each checkpoint is a sized sequence (e.g. a list) of
  # tensors -- confirm against the script that writes the .pth files.
  param_counts = [
      len(params)
      for params in (
          model_params_rank0,
          model_params_rank1,
          master_params_rank0,
          master_params_rank1,
      )
  ]
  assert len(set(param_counts)) == 1, (
      "Parameter count mismatch (model0, model1, master0, master1): {}".format(param_counts)
  )

  for model_rank0, model_rank1, master_rank0, master_rank1 in zip(
      model_params_rank0,
      model_params_rank1,
      master_params_rank0,
      master_params_rank1,
  ):
      # Both ranks must hold the same values (default allclose tolerances).
      assert torch.allclose(model_rank0, model_rank1), "Model param mismatch"
      assert torch.allclose(master_rank0, master_rank1), "Master param mismatch"
      # Some debugging/investigation assistance code:
      # maxval, maxind = torch.max(((torch.abs(model_rank0).float())/torch.abs(master_rank0)).view(-1), 0)
      # offending_val_half = model_rank0.view(-1)[maxind.item()]
      # offending_val_float = master_rank0.view(-1)[maxind.item()]
      # print(maxval.item(), maxind.item(), offending_val_half.item(), offending_val_float.item(),
      #       offending_val_float.half().item())
      # rtol needs to be > 2^-11 because of denormals...
      # The master copy is cast to half before comparing; presumably the model
      # parameters are stored in half precision -- verify against the writer.
      assert torch.allclose(model_rank0, master_rank0.half(), rtol=.005), "Model-master mismatch"

  print("OK: Model and master params match across ranks.")