# benchmark.py (3.3 KB)
  1. import argparse
  2. import os
  3. import random
  4. import time
  5. from os.path import isfile, join, split
  6. import numpy as np
  7. import tqdm
  8. import yaml
  9. import cv2
  10. import jittor as jt
  11. from jittor import nn
  12. from logger import Logger
  13. from dataloader import get_loader
  14. from model.network import Net
  15. from skimage.measure import label, regionprops
  16. from mask_component_utils import reverse_mapping, visulize_mapping, get_boundary_point
  17. if jt.has_cuda:
  18. jt.flags.use_cuda = 1
  19. parser = argparse.ArgumentParser(description='Jittor Semantic-Line Inference')
  20. parser.add_argument('--config', default="./config.yml", help="path to config file")
  21. parser.add_argument('--model', required=True, help='path to the pretrained model')
  22. parser.add_argument('--tmp', default="", help='tmp')
  23. parser.add_argument('--dump', default=False)
  24. args = parser.parse_args()
  25. assert os.path.isfile(args.config)
  26. CONFIGS = yaml.load(open(args.config))
  27. # merge configs
  28. if args.tmp != "" and args.tmp != CONFIGS["MISC"]["TMP"]:
  29. CONFIGS["MISC"]["TMP"] = args.tmp
  30. os.makedirs(CONFIGS["MISC"]["TMP"], exist_ok=True)
  31. logger = Logger(os.path.join(CONFIGS["MISC"]["TMP"], "log.txt"))
def main():
    """Build the network, optionally convert/strip a PyTorch checkpoint,
    set up the test data loader and (optionally) dump hooks, then run the
    throughput benchmark via test().
    """
    logger.info(args)
    model = Net(numAngle=CONFIGS["MODEL"]["NUMANGLE"], numRho=CONFIGS["MODEL"]["NUMRHO"], backbone=CONFIGS["MODEL"]["BACKBONE"])
    if args.model:
        if isfile(args.model):
            # torch is imported lazily: it is only needed to read the
            # PyTorch-format checkpoint and re-save a bare state_dict.
            import torch
            m = torch.load(args.model)
            if 'state_dict' in m.keys():
                m = m['state_dict']
            torch.save(m, '_temp_model.pth')
            del m
            logger.info("=> loading pretrained model '{}'".format(args.model))
            # NOTE(review): the actual weight load is commented out, so the
            # benchmark runs with randomly initialized weights — the converted
            # '_temp_model.pth' written above is never read here.
            #model.load('_temp_model.pth')
            logger.info("=> loaded checkpoint '{}'".format(args.model))
        else:
            logger.info("=> no pretrained model found at '{}'".format(args.model))
    # dataloader
    # Batch size is taken from the BS environment variable (default 1) so the
    # benchmark can be swept without editing the config file.
    test_loader = get_loader(CONFIGS["DATA"]["TEST_DIR"], CONFIGS["DATA"]["TEST_LABEL_FILE"],
                             batch_size=int(os.environ.get("BS","1")), num_thread=CONFIGS["DATA"]["WORKERS"], test=True)
    logger.info("Data loading done.")
    # Populated by the forward hooks below when --dump is set; keyed by
    # module name (weights) and module name + '_input'/'_output' (tensors).
    weights_nodes = {}
    data_nodes = {}
    def named_dump_func(name):
        # Returns a forward hook that closes over `name` and records the
        # weights and input/output activations of Conv2d modules only.
        def dump_func(self, inputs, outputs):
            input_name = name + '_input'
            output_name = name + '_output'
            if isinstance(self, nn.Conv2d):
                weights_nodes[name] = self.weight.numpy()
                data_nodes[input_name] = inputs[0].numpy()
                data_nodes[output_name] = outputs[0].numpy()
        return dump_func
    if args.dump:
        # NOTE(review): --dump has no type/action, so any non-empty string
        # passed on the command line is truthy here.
        logger.info('Add hooks to dump data.')
        for name, module in model.named_modules():
            print(name)
            module.register_forward_hook(named_dump_func(name))
    test(test_loader, model, args)
  69. @jt.no_grad()
  70. def test(test_loader, model, args):
  71. # switch to evaluate mode
  72. model.eval()
  73. for data in test_loader:
  74. images, names, size = data
  75. break
  76. jt.sync_all(True)
  77. # warmup
  78. for i in range(10):
  79. model(images).sync()
  80. jt.sync_all(True)
  81. # rerun
  82. t = time.time()
  83. for i in range(300):
  84. print(i, i/(time.time()-t))
  85. model(images).sync()
  86. jt.sync_all(True)
  87. t = time.time()-t
  88. print("BS:", images.shape[0], "FPS:", 300*images.shape[0]/t)
# Entry point: run the benchmark only when executed as a script,
# not when this module is imported.
if __name__ == '__main__':
    main()