123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158 |
- import argparse
- import os
- import time
- from os.path import isfile, join, split
- import numpy as np
- import tqdm
- import yaml
- import cv2
- import jittor as jt
- from jittor import nn
- from logger import Logger
- from dataloader import get_loader
- from model.network import Net
- from skimage.measure import label, regionprops
- from mask_component_utils import reverse_mapping, visulize_mapping, get_boundary_point
# Run on the GPU whenever Jittor reports CUDA support.
if jt.has_cuda:
    jt.flags.use_cuda = 1


def _str2bool(value):
    """Interpret a command-line value as a boolean.

    Explicit "false"-like spellings (and the empty string) map to False;
    every other value maps to True, so any value that enabled the flag
    before this helper existed still enables it.
    """
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() not in ('', '0', 'false', 'f', 'no', 'n', 'off')


parser = argparse.ArgumentParser(description='Jittor Semantic-Line Inference')
parser.add_argument('--config', default="../config.yml", help="path to config file")
parser.add_argument('--model', required=True, help='path to the pretrained model')
parser.add_argument('--tmp', default="", help='tmp')
# Without a type, "--dump False" produced the truthy string "False".
# nargs='?' additionally lets a bare "--dump" switch dumping on.
parser.add_argument('--dump', nargs='?', const=True, default=False, type=_str2bool,
                    help='dump layer inputs/outputs and conv weights, then exit')
args = parser.parse_args()

# Validate explicitly: an assert would be stripped when running with -O.
if not os.path.isfile(args.config):
    raise FileNotFoundError("config file not found: '{}'".format(args.config))

# safe_load avoids arbitrary object construction and works on PyYAML >= 6,
# where yaml.load() requires an explicit Loader. The context manager closes
# the file handle deterministically.
with open(args.config) as config_file:
    CONFIGS = yaml.safe_load(config_file)

# merge configs: a non-empty --tmp overrides the directory from the config file
if args.tmp:
    CONFIGS["MISC"]["TMP"] = args.tmp

os.makedirs(CONFIGS["MISC"]["TMP"], exist_ok=True)
logger = Logger(os.path.join(CONFIGS["MISC"]["TMP"], "log.txt"))
def _load_torch_checkpoint(model, checkpoint_path):
    """Load a PyTorch checkpoint file into ``model`` via a temporary .pth file.

    The checkpoint is read with torch, unwrapped from its ``'state_dict'``
    entry when present, re-saved and handed to ``model.load``. The re-saved
    copy lives in a unique temporary file that is always removed afterwards,
    so nothing is left behind in the working directory and concurrent runs
    cannot overwrite each other's file.
    """
    import tempfile

    import torch

    # map_location keeps the load independent of the device the checkpoint
    # was saved from (e.g. a GPU checkpoint on a CPU-only host).
    state = torch.load(checkpoint_path, map_location='cpu')
    if 'state_dict' in state:
        state = state['state_dict']

    # Keep the '.pth' suffix of the original temp file name.
    fd, tmp_path = tempfile.mkstemp(suffix='.pth')
    os.close(fd)
    try:
        torch.save(state, tmp_path)
        del state
        model.load(tmp_path)
    finally:
        os.remove(tmp_path)


def main():
    """Build the network, load the pretrained weights and run the test loop.

    Reads the module-level ``args`` and ``CONFIGS``. When ``args.dump`` is
    set, forward hooks record each module's first input/output element (and
    the weights of ``nn.Conv2d`` modules); the recordings are written to
    ``data_nodes.npy`` / ``weights_nodes.npy`` and the process exits.
    Otherwise the throughput returned by ``test`` is logged.
    """
    logger.info(args)

    model = Net(numAngle=CONFIGS["MODEL"]["NUMANGLE"],
                numRho=CONFIGS["MODEL"]["NUMRHO"],
                backbone=CONFIGS["MODEL"]["BACKBONE"])

    if args.model:
        if isfile(args.model):
            logger.info("=> loading pretrained model '{}'".format(args.model))
            _load_torch_checkpoint(model, args.model)
            logger.info("=> loaded checkpoint '{}'".format(args.model))
        else:
            logger.info("=> no pretrained model found at '{}'".format(args.model))

    # dataloader
    test_loader = get_loader(CONFIGS["DATA"]["TEST_DIR"], CONFIGS["DATA"]["TEST_LABEL_FILE"],
                             batch_size=1, num_thread=CONFIGS["DATA"]["WORKERS"], test=True)
    logger.info("Data loading done.")

    # Filled by the forward hooks below, keyed by module name.
    weights_nodes = {}
    data_nodes = {}

    def named_dump_func(name):
        """Return a forward hook that records data under ``name``."""
        def dump_func(self, inputs, outputs):
            input_name = name + '_input'
            output_name = name + '_output'
            if isinstance(self, nn.Conv2d):
                weights_nodes[name] = self.weight.numpy()
            data_nodes[input_name] = inputs[0].numpy()
            data_nodes[output_name] = outputs[0].numpy()
        return dump_func

    if args.dump:
        logger.info('Add hooks to dump data.')
        for name, module in model.named_modules():
            module.register_forward_hook(named_dump_func(name))

    logger.info("Start testing.")
    total_time = test(test_loader, model, args)

    if args.dump:
        # np.save stores each dict as a pickled object array.
        np.save('data_nodes.npy', data_nodes)
        np.save('weights_nodes.npy', weights_nodes)
        # Same effect as the interactive-only exit() helper, without relying
        # on the site module having injected it.
        raise SystemExit(0)

    num_images = len(test_loader)
    # Guard against a zero elapsed time (e.g. an empty test set).
    fps = num_images / total_time if total_time > 0 else 0.0
    logger.info("Test done! Total %d imgs at %.4f secs without image io, fps: %.3f" % (num_images, total_time, fps))
-
def _fps(num_items, seconds):
    """Return items per second, or 0.0 when no time was accumulated."""
    return num_items / seconds if seconds > 0 else 0.0


def test(test_loader, model, args):
    """Run inference over ``test_loader`` and save the detected lines.

    For every batch the model output is passed through a sigmoid and
    thresholded, connected regions are reduced to their centroids, and the
    centroids are mapped back to line end points via ``reverse_mapping``.
    The end points are rescaled to the original image size, extended with
    ``get_boundary_point``, and written both as a visualisation image and
    as a ``.npy`` array under ``<TMP>/visualize_test``.

    When ``args.dump`` is set, only a single forward pass is executed (so
    that forward hooks fire once) and 0 is returned.

    Returns:
        Accumulated wall-clock seconds per image, measured from just after
        the batch is fetched until its results are saved; 0 in dump mode.
    """
    # switch to evaluate mode
    model.eval()

    num_angle = CONFIGS["MODEL"]["NUMANGLE"]
    num_rho = CONFIGS["MODEL"]["NUMRHO"]

    # Loop-invariant: resolve and create the output directory once, outside
    # the timed region. Skipped in dump mode, which never writes results.
    visualize_save_path = os.path.join(CONFIGS["MISC"]["TMP"], 'visualize_test')
    if not args.dump:
        os.makedirs(visualize_save_path, exist_ok=True)

    ftime = 0  # forward pass only
    ttime = 0  # forward pass + post-processing
    ntime = 0  # forward pass + post-processing + saving results
    for images, names, size in tqdm.tqdm(test_loader):
        t = time.time()

        images = jt.array(images)
        key_points = model(images)
        if args.dump:
            break
        key_points = key_points.sigmoid()
        # NOTE(review): if Jittor evaluates lazily, the forward work may only
        # execute at .numpy() below and be under-counted here - confirm.
        ftime += (time.time() - t)

        binary_kmap = key_points.squeeze(0).squeeze(0).numpy() > CONFIGS['MODEL']['THRESHOLD']
        kmap_label = label(binary_kmap, connectivity=1)
        plist = [prop.centroid for prop in regionprops(kmap_label)]

        # presumably (height, width) of the original image - verify against the dataloader
        size = (size[0][0], size[0][1])
        b_points = reverse_mapping(plist, numAngle=num_angle, numRho=num_rho, size=(400, 400))
        # reverse_mapping is given a 400x400 frame; rescale to the image size.
        scale_w = size[1] / 400
        scale_h = size[0] / 400
        for idx, point in enumerate(b_points):
            y1 = int(np.round(point[0] * scale_h))
            x1 = int(np.round(point[1] * scale_w))
            y2 = int(np.round(point[2] * scale_h))
            x2 = int(np.round(point[3] * scale_w))
            if x1 == x2:
                # Avoid dividing by zero when both end points share the same x.
                angle = -np.pi / 2
            else:
                angle = np.arctan((y1 - y2) / (x1 - x2))
            (x1, y1), (x2, y2) = get_boundary_point(y1, x1, angle, size[0], size[1])
            b_points[idx] = (y1, x1, y2, x2)
        ttime += (time.time() - t)

        vis = visulize_mapping(b_points, size, names[0])

        file_name = names[0].split('/')[-1]
        cv2.imwrite(join(visualize_save_path, file_name), vis)
        np_data = np.array(b_points)
        # np.save appends the '.npy' extension to the extension-less name.
        np.save(join(visualize_save_path, file_name.split('.')[0]), np_data)
        ntime += (time.time() - t)

    if args.dump:
        return 0

    # _fps guards against ZeroDivisionError when nothing was timed
    # (e.g. an empty test set).
    num_images = len(test_loader)
    print('forward fps for total images: %.6f' % _fps(num_images, ftime))
    print('forward + post-processing fps for total images: %.6f' % _fps(num_images, ttime))
    print('total fps for total images: %.6f' % _fps(num_images, ntime))
    return ntime
# Script entry point: run inference only when executed directly, not on import.
if __name__ == "__main__":
    main()
|