3
0

# compare.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. import argparse
  2. import torch
  3. parser = argparse.ArgumentParser(description='Compare')
  4. parser.add_argument('--opt-level', type=str)
  5. parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
  6. parser.add_argument('--loss-scale', type=str, default=None)
  7. parser.add_argument('--fused-adam', action='store_true')
  8. parser.add_argument('--use_baseline', action='store_true')
  9. args = parser.parse_args()
  10. base_file = str(args.opt_level) + "_" +\
  11. str(args.loss_scale) + "_" +\
  12. str(args.keep_batchnorm_fp32) + "_" +\
  13. str(args.fused_adam)
  14. file_e = "True_" + base_file
  15. file_p = "False_" + base_file
  16. if args.use_baseline:
  17. file_b = "baselines/True_" + base_file
  18. dict_e = torch.load(file_e)
  19. dict_p = torch.load(file_p)
  20. if args.use_baseline:
  21. dict_b = torch.load(file_b)
  22. torch.set_printoptions(precision=10)
  23. print(file_e)
  24. print(file_p)
  25. if args.use_baseline:
  26. print(file_b)
  27. # ugly duplication here...
  28. if not args.use_baseline:
  29. for n, (i_e, i_p) in enumerate(zip(dict_e["Iteration"], dict_p["Iteration"])):
  30. assert i_e == i_p, "i_e = {}, i_p = {}".format(i_e, i_p)
  31. loss_e = dict_e["Loss"][n]
  32. loss_p = dict_p["Loss"][n]
  33. assert loss_e == loss_p, "Iteration {}, loss_e = {}, loss_p = {}".format(i_e, loss_e, loss_p)
  34. print("{:4} {:15.10f} {:15.10f} {:15.10f} {:15.10f}".format(
  35. i_e,
  36. loss_e,
  37. loss_p,
  38. dict_e["Speed"][n],
  39. dict_p["Speed"][n]))
  40. else:
  41. for n, (i_e, i_p) in enumerate(zip(dict_e["Iteration"], dict_p["Iteration"])):
  42. assert i_e == i_p, "i_e = {}, i_p = {}".format(i_e, i_p)
  43. loss_e = dict_e["Loss"][n]
  44. loss_p = dict_p["Loss"][n]
  45. loss_b = dict_b["Loss"][n]
  46. assert loss_e == loss_p, "Iteration {}, loss_e = {}, loss_p = {}".format(i_e, loss_e, loss_p)
  47. assert loss_e == loss_b, "Iteration {}, loss_e = {}, loss_b = {}".format(i_e, loss_e, loss_b)
  48. print("{:4} {:15.10f} {:15.10f} {:15.10f} {:15.10f} {:15.10f} {:15.10f}".format(
  49. i_e,
  50. loss_b,
  51. loss_e,
  52. loss_p,
  53. dict_b["Speed"][n],
  54. dict_e["Speed"][n],
  55. dict_p["Speed"][n]))