123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142 |
- import argparse
- import os
- import random
- import time
- from os.path import isfile, join, split
- import torch
- import torchvision
- import torch.backends.cudnn as cudnn
- import torch.nn as nn
- import torch.optim
- import numpy as np
- import tqdm
- import yaml
- import cv2
- from torch.optim import lr_scheduler
- from logger import Logger
- from dataloader import get_loader
- from model.network import Net
- from skimage.measure import label, regionprops
- from mask_component_utils import reverse_mapping, visulize_mapping, edge_align, get_boundary_point
- parser = argparse.ArgumentParser(description='PyTorch Semantic-Line Training')
- # arguments from command line
- parser.add_argument('--config', default="./config.yml", help="path to config file")
- parser.add_argument('--model', required=True, help='path to the pretrained model')
- parser.add_argument('--align', default=False, action='store_true')
- parser.add_argument('--tmp', default="", help='tmp')
- args = parser.parse_args()
- assert os.path.isfile(args.config)
- CONFIGS = yaml.load(open(args.config))
- # merge configs
- if args.tmp != "" and args.tmp != CONFIGS["MISC"]["TMP"]:
- CONFIGS["MISC"]["TMP"] = args.tmp
- os.makedirs(CONFIGS["MISC"]["TMP"], exist_ok=True)
- logger = Logger(os.path.join(CONFIGS["MISC"]["TMP"], "log.txt"))
- def main():
- logger.info(args)
- model = Net(numAngle=CONFIGS["MODEL"]["NUMANGLE"], numRho=CONFIGS["MODEL"]["NUMRHO"], backbone=CONFIGS["MODEL"]["BACKBONE"])
- model = model.cuda(device=CONFIGS["TRAIN"]["GPU_ID"])
- if args.model:
- if isfile(args.model):
- logger.info("=> loading pretrained model '{}'".format(args.model))
- checkpoint = torch.load(args.model)
- if 'state_dict' in checkpoint.keys():
- model.load_state_dict(checkpoint['state_dict'])
- else:
- model.load_state_dict(checkpoint)
- logger.info("=> loaded checkpoint '{}'"
- .format(args.model))
- else:
- logger.info("=> no pretrained model found at '{}'".format(args.model))
- # dataloader
- test_loader = get_loader(CONFIGS["DATA"]["TEST_DIR"], CONFIGS["DATA"]["TEST_LABEL_FILE"],
- batch_size=1, num_thread=CONFIGS["DATA"]["WORKERS"], test=True)
- logger.info("Data loading done.")
-
-
- logger.info("Start testing.")
- total_time = test(test_loader, model, args)
-
- logger.info("Test done! Total %d imgs at %.4f secs without image io, fps: %.3f" % (len(test_loader), total_time, len(test_loader) / total_time))
-
- def test(test_loader, model, args):
- # switch to evaluate mode
- model.eval()
- with torch.no_grad():
- bar = tqdm.tqdm(test_loader)
- iter_num = len(test_loader.dataset)
- ftime = 0
- ntime = 0
- for i, data in enumerate(bar):
- t = time.time()
- images, names, size = data
-
- images = images.cuda(device=CONFIGS["TRAIN"]["GPU_ID"])
- # size = (size[0].item(), size[1].item())
- key_points = model(images)
-
- key_points = torch.sigmoid(key_points)
- ftime += (time.time() - t)
- t = time.time()
- visualize_save_path = os.path.join(CONFIGS["MISC"]["TMP"], 'visualize_test')
- os.makedirs(visualize_save_path, exist_ok=True)
- binary_kmap = key_points.squeeze().cpu().numpy() > CONFIGS['MODEL']['THRESHOLD']
- kmap_label = label(binary_kmap, connectivity=1)
- props = regionprops(kmap_label)
- plist = []
- for prop in props:
- plist.append(prop.centroid)
- size = (size[0][0], size[0][1])
- b_points = reverse_mapping(plist, numAngle=CONFIGS["MODEL"]["NUMANGLE"], numRho=CONFIGS["MODEL"]["NUMRHO"], size=(400, 400))
- scale_w = size[1] / 400
- scale_h = size[0] / 400
- for i in range(len(b_points)):
- y1 = int(np.round(b_points[i][0] * scale_h))
- x1 = int(np.round(b_points[i][1] * scale_w))
- y2 = int(np.round(b_points[i][2] * scale_h))
- x2 = int(np.round(b_points[i][3] * scale_w))
- if x1 == x2:
- angle = -np.pi / 2
- else:
- angle = np.arctan((y1-y2) / (x1-x2))
- (x1, y1), (x2, y2) = get_boundary_point(y1, x1, angle, size[0], size[1])
- b_points[i] = (y1, x1, y2, x2)
- vis = visulize_mapping(b_points, size[::-1], names[0])
-
- cv2.imwrite(join(visualize_save_path, names[0].split('/')[-1]), vis)
- np_data = np.array(b_points)
- np.save(join(visualize_save_path, names[0].split('/')[-1].split('.')[0]), np_data)
- if CONFIGS["MODEL"]["EDGE_ALIGN"] and args.align:
- for i in range(len(b_points)):
- b_points[i] = edge_align(b_points[i], names[0], size, division=5)
- vis = visulize_mapping(b_points, size, names[0])
- cv2.imwrite(join(visualize_save_path, names[0].split('/')[-1].split('.')[0]+'_align.png'), vis)
- np_data = np.array(b_points)
- np.save(join(visualize_save_path, names[0].split('/')[-1].split('.')[0]+'_align'), np_data)
- ntime += (time.time() - t)
- print('forward time for total images: %.6f' % ftime)
- print('post-processing time for total images: %.6f' % ntime)
- return ftime + ntime
- if __name__ == '__main__':
- main()
|