# forward.py

import argparse
import os
import time
from os.path import isfile, join, split

import numpy as np
import tqdm
import yaml
import cv2
import jittor as jt
from jittor import nn

from logger import Logger
from dataloader import get_loader
from model.network import Net
from skimage.measure import label, regionprops
from mask_component_utils import reverse_mapping, visulize_mapping, get_boundary_point
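
# Run inference on the GPU when CUDA is available; Jittor falls back to CPU otherwise.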
if jt.has_cuda:
    jt.flags.use_cuda = 1

parser = argparse.ArgumentParser(description='Jittor Semantic-Line Inference')
parser.add_argument('--config', default="../config.yml", help="path to config file")
parser.add_argument('--model', required=True, help='path to the pretrained model')
parser.add_argument('--tmp', default="", help='tmp directory for logs and outputs')
# store_true makes --dump a proper flag; the original default=False would have
# required an explicit value on the command line.
parser.add_argument('--dump', action='store_true', help='dump per-layer data and exit')
args = parser.parse_args()

assert os.path.isfile(args.config)
# yaml.load without an explicit Loader is deprecated and an error in PyYAML >= 6.
CONFIGS = yaml.safe_load(open(args.config))

# merge configs
if args.tmp != "" and args.tmp != CONFIGS["MISC"]["TMP"]:
    CONFIGS["MISC"]["TMP"] = args.tmp

os.makedirs(CONFIGS["MISC"]["TMP"], exist_ok=True)
logger = Logger(os.path.join(CONFIGS["MISC"]["TMP"], "log.txt"))


def main():
    logger.info(args)
    model = Net(numAngle=CONFIGS["MODEL"]["NUMANGLE"], numRho=CONFIGS["MODEL"]["NUMRHO"], backbone=CONFIGS["MODEL"]["BACKBONE"])
    if args.model:
        if isfile(args.model):
            logger.info("=> loading pretrained model '{}'".format(args.model))
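            # The checkpoint is stored in PyTorch format; re-save the bare state
            # dict so that Jittor's model.load() can read it.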
            import torch
            m = torch.load(args.model)
            if 'state_dict' in m.keys():
                m = m['state_dict']
            torch.save(m, '_temp_model.pth')
            del m
            model.load('_temp_model.pth')
            logger.info("=> loaded checkpoint '{}'".format(args.model))
        else:
            logger.info("=> no pretrained model found at '{}'".format(args.model))

    # dataloader
    test_loader = get_loader(CONFIGS["DATA"]["TEST_DIR"], CONFIGS["DATA"]["TEST_LABEL_FILE"],
                             batch_size=1, num_thread=CONFIGS["DATA"]["WORKERS"], test=True)
    logger.info("Data loading done.")
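
    # Optional debugging support: forward hooks record every module's input and
    # output (plus the weights of Conv2d layers) so they can be inspected offline.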
    weights_nodes = {}
    data_nodes = {}

    def named_dump_func(name):
        def dump_func(self, inputs, outputs):
            input_name = name + '_input'
            output_name = name + '_output'
            if isinstance(self, nn.Conv2d):
                weights_nodes[name] = self.weight.numpy()
            data_nodes[input_name] = inputs[0].numpy()
            data_nodes[output_name] = outputs[0].numpy()
        return dump_func

    if args.dump:
        logger.info('Add hooks to dump data.')
        for name, module in model.named_modules():
            module.register_forward_hook(named_dump_func(name))

    logger.info("Start testing.")
    total_time = test(test_loader, model, args)

    if args.dump:
        np.save('data_nodes.npy', data_nodes)
        np.save('weights_nodes.npy', weights_nodes)
        exit()

    logger.info("Test done! Total %d imgs at %.4f secs without image io, fps: %.3f" % (len(test_loader), total_time, len(test_loader) / total_time))


def test(test_loader, model, args):
    # switch to evaluate mode
    model.eval()

    bar = tqdm.tqdm(test_loader)
    ftime = 0
    ttime = 0
    ntime = 0
    for i, data in enumerate(bar):
        t = time.time()
        images, names, size = data
        images = jt.array(images)
        # size = (size[0].item(), size[1].item())
        key_points = model(images)
        if args.dump:
            break
        key_points = key_points.sigmoid()
        ftime += (time.time() - t)
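
        # Post-processing: threshold the predicted parameter map, then take the
        # centroid of each connected component as one detected line.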
        visualize_save_path = os.path.join(CONFIGS["MISC"]["TMP"], 'visualize_test')
        os.makedirs(visualize_save_path, exist_ok=True)

        binary_kmap = key_points.squeeze(0).squeeze(0).numpy() > CONFIGS['MODEL']['THRESHOLD']
        kmap_label = label(binary_kmap, connectivity=1)
        props = regionprops(kmap_label)
        plist = [prop.centroid for prop in props]

        size = (size[0][0], size[0][1])
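        # Map parameter-space centroids back to line endpoints in the network's
        # 400x400 input frame, then rescale them to the original image size.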
        b_points = reverse_mapping(plist, numAngle=CONFIGS["MODEL"]["NUMANGLE"], numRho=CONFIGS["MODEL"]["NUMRHO"], size=(400, 400))
        scale_w = size[1] / 400
        scale_h = size[0] / 400
        for k in range(len(b_points)):  # 'k' avoids shadowing the outer loop index
            y1 = int(np.round(b_points[k][0] * scale_h))
            x1 = int(np.round(b_points[k][1] * scale_w))
            y2 = int(np.round(b_points[k][2] * scale_h))
            x2 = int(np.round(b_points[k][3] * scale_w))
            if x1 == x2:
                angle = -np.pi / 2
            else:
                angle = np.arctan((y1 - y2) / (x1 - x2))
            # Extend the scaled segment so it spans the full image.
            (x1, y1), (x2, y2) = get_boundary_point(y1, x1, angle, size[0], size[1])
            b_points[k] = (y1, x1, y2, x2)
        ttime += (time.time() - t)

        vis = visulize_mapping(b_points, size, names[0])
        cv2.imwrite(join(visualize_save_path, names[0].split('/')[-1]), vis)
        np_data = np.array(b_points)
        np.save(join(visualize_save_path, names[0].split('/')[-1].split('.')[0]), np_data)
        ntime += (time.time() - t)

    if args.dump:
        return 0
    print('forward fps for total images: %.6f' % (len(test_loader) / ftime))
    print('forward + post-processing fps for total images: %.6f' % (len(test_loader) / ttime))
    print('total fps for total images: %.6f' % (len(test_loader) / ntime))
    return ntime


if __name__ == '__main__':
    main()
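
# Example invocations (paths are placeholders for your own config and checkpoint):
#   python forward.py --config ../config.yml --model ./checkpoint.pth --tmp ./results
#   python forward.py --config ../config.yml --model ./checkpoint.pth --dump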