# forward.py
  1. import argparse
  2. import os
  3. import random
  4. import time
  5. from os.path import isfile, join, split
  6. import torch
  7. import torchvision
  8. import torch.backends.cudnn as cudnn
  9. import torch.nn as nn
  10. import torch.optim
  11. import numpy as np
  12. import tqdm
  13. import yaml
  14. import cv2
  15. from torch.optim import lr_scheduler
  16. from logger import Logger
  17. from dataloader import get_loader
  18. from model.network import Net
  19. from skimage.measure import label, regionprops
  20. from mask_component_utils import reverse_mapping, visulize_mapping, edge_align, get_boundary_point
  21. parser = argparse.ArgumentParser(description='PyTorch Semantic-Line Training')
  22. # arguments from command line
  23. parser.add_argument('--config', default="./config.yml", help="path to config file")
  24. parser.add_argument('--model', required=True, help='path to the pretrained model')
  25. parser.add_argument('--align', default=False, action='store_true')
  26. parser.add_argument('--tmp', default="", help='tmp')
  27. args = parser.parse_args()
  28. assert os.path.isfile(args.config)
  29. CONFIGS = yaml.load(open(args.config))
  30. # merge configs
  31. if args.tmp != "" and args.tmp != CONFIGS["MISC"]["TMP"]:
  32. CONFIGS["MISC"]["TMP"] = args.tmp
  33. os.makedirs(CONFIGS["MISC"]["TMP"], exist_ok=True)
  34. logger = Logger(os.path.join(CONFIGS["MISC"]["TMP"], "log.txt"))
  35. def main():
  36. logger.info(args)
  37. model = Net(numAngle=CONFIGS["MODEL"]["NUMANGLE"], numRho=CONFIGS["MODEL"]["NUMRHO"], backbone=CONFIGS["MODEL"]["BACKBONE"])
  38. model = model.cuda(device=CONFIGS["TRAIN"]["GPU_ID"])
  39. if args.model:
  40. if isfile(args.model):
  41. logger.info("=> loading pretrained model '{}'".format(args.model))
  42. checkpoint = torch.load(args.model)
  43. if 'state_dict' in checkpoint.keys():
  44. model.load_state_dict(checkpoint['state_dict'])
  45. else:
  46. model.load_state_dict(checkpoint)
  47. logger.info("=> loaded checkpoint '{}'"
  48. .format(args.model))
  49. else:
  50. logger.info("=> no pretrained model found at '{}'".format(args.model))
  51. # dataloader
  52. test_loader = get_loader(CONFIGS["DATA"]["TEST_DIR"], CONFIGS["DATA"]["TEST_LABEL_FILE"],
  53. batch_size=1, num_thread=CONFIGS["DATA"]["WORKERS"], test=True)
  54. logger.info("Data loading done.")
  55. logger.info("Start testing.")
  56. total_time = test(test_loader, model, args)
  57. logger.info("Test done! Total %d imgs at %.4f secs without image io, fps: %.3f" % (len(test_loader), total_time, len(test_loader) / total_time))
  58. def test(test_loader, model, args):
  59. # switch to evaluate mode
  60. model.eval()
  61. with torch.no_grad():
  62. bar = tqdm.tqdm(test_loader)
  63. iter_num = len(test_loader.dataset)
  64. ftime = 0
  65. ntime = 0
  66. for i, data in enumerate(bar):
  67. t = time.time()
  68. images, names, size = data
  69. images = images.cuda(device=CONFIGS["TRAIN"]["GPU_ID"])
  70. # size = (size[0].item(), size[1].item())
  71. key_points = model(images)
  72. key_points = torch.sigmoid(key_points)
  73. ftime += (time.time() - t)
  74. t = time.time()
  75. visualize_save_path = os.path.join(CONFIGS["MISC"]["TMP"], 'visualize_test')
  76. os.makedirs(visualize_save_path, exist_ok=True)
  77. binary_kmap = key_points.squeeze().cpu().numpy() > CONFIGS['MODEL']['THRESHOLD']
  78. kmap_label = label(binary_kmap, connectivity=1)
  79. props = regionprops(kmap_label)
  80. plist = []
  81. for prop in props:
  82. plist.append(prop.centroid)
  83. size = (size[0][0], size[0][1])
  84. b_points = reverse_mapping(plist, numAngle=CONFIGS["MODEL"]["NUMANGLE"], numRho=CONFIGS["MODEL"]["NUMRHO"], size=(400, 400))
  85. scale_w = size[1] / 400
  86. scale_h = size[0] / 400
  87. for i in range(len(b_points)):
  88. y1 = int(np.round(b_points[i][0] * scale_h))
  89. x1 = int(np.round(b_points[i][1] * scale_w))
  90. y2 = int(np.round(b_points[i][2] * scale_h))
  91. x2 = int(np.round(b_points[i][3] * scale_w))
  92. if x1 == x2:
  93. angle = -np.pi / 2
  94. else:
  95. angle = np.arctan((y1-y2) / (x1-x2))
  96. (x1, y1), (x2, y2) = get_boundary_point(y1, x1, angle, size[0], size[1])
  97. b_points[i] = (y1, x1, y2, x2)
  98. vis = visulize_mapping(b_points, size[::-1], names[0])
  99. cv2.imwrite(join(visualize_save_path, names[0].split('/')[-1]), vis)
  100. np_data = np.array(b_points)
  101. np.save(join(visualize_save_path, names[0].split('/')[-1].split('.')[0]), np_data)
  102. if CONFIGS["MODEL"]["EDGE_ALIGN"] and args.align:
  103. for i in range(len(b_points)):
  104. b_points[i] = edge_align(b_points[i], names[0], size, division=5)
  105. vis = visulize_mapping(b_points, size, names[0])
  106. cv2.imwrite(join(visualize_save_path, names[0].split('/')[-1].split('.')[0]+'_align.png'), vis)
  107. np_data = np.array(b_points)
  108. np.save(join(visualize_save_path, names[0].split('/')[-1].split('.')[0]+'_align'), np_data)
  109. ntime += (time.time() - t)
  110. print('forward time for total images: %.6f' % ftime)
  111. print('post-processing time for total images: %.6f' % ntime)
  112. return ftime + ntime
  113. if __name__ == '__main__':
  114. main()