infer.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import yaml
import glob
import json
from pathlib import Path
from functools import reduce
import cv2
import numpy as np
import math
import paddle
from paddle.inference import Config
from paddle.inference import create_predictor
import sys

# add deploy path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, '..'))
sys.path.insert(0, parent_path)

from benchmark_utils import PaddleInferBenchmark
from picodet_postprocess import PicoDetPostProcess
from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image
from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop
from visualize import visualize_box_mask
from dependence.PaddleDetection.deploy.python.utils import argsparser, Timer, get_current_memory_mb
# Global set of supported model architectures
SUPPORT_MODELS = {
    'YOLO',
    'RCNN',
    'SSD',
    'Face',
    'FCOS',
    'SOLOv2',
    'TTFNet',
    'S2ANet',
    'JDE',
    'FairMOT',
    'DeepSORT',
    'GFL',
    'PicoDet',
    'CenterNet',
    'TOOD',
    'RetinaNet',
    'StrongBaseline',
    'STGCN',
    'YOLOX',
}
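
# Note: PredictConfig.check_model below accepts any arch whose name contains
# one of the entries above as a substring (e.g. an arch of 'YOLOv3' matches
# 'YOLO').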


def bench_log(detector, img_list, model_info, batch_size=1, name=None):
    mems = {
        'cpu_rss_mb': detector.cpu_mem / len(img_list),
        'gpu_rss_mb': detector.gpu_mem / len(img_list),
        'gpu_util': detector.gpu_util * 100 / len(img_list)
    }
    perf_info = detector.det_times.report(average=True)
    data_info = {
        'batch_size': batch_size,
        'shape': "dynamic_shape",
        'data_num': perf_info['img_num']
    }
    log = PaddleInferBenchmark(detector.config, model_info, data_info,
                               perf_info, mems)
    log(name)


class Detector(object):
    """
    Args:
        pred_config (object): config of model, defined by `PredictConfig(model_dir)`
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running (paddle/trt_fp32/trt_fp16)
        batch_size (int): batch size for inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantization
            calibration, trt_calib_mode should be set to True
        cpu_threads (int): number of cpu threads
        enable_mkldnn (bool): whether to enable MKLDNN
        enable_mkldnn_bfloat16 (bool): whether to enable MKLDNN bfloat16
        output_dir (str): the path of output
        threshold (float): the score threshold for visualization
        delete_shuffle_pass (bool): whether to remove shuffle_channel_detect_pass in TensorRT.
            Used by action model.
    """

    def __init__(self,
                 model_dir,
                 device='CPU',
                 run_mode='paddle',
                 batch_size=1,
                 trt_min_shape=1,
                 trt_max_shape=1280,
                 trt_opt_shape=640,
                 trt_calib_mode=False,
                 cpu_threads=1,
                 enable_mkldnn=False,
                 enable_mkldnn_bfloat16=False,
                 output_dir='output',
                 threshold=0.5,
                 delete_shuffle_pass=False):
        self.pred_config = self.set_config(model_dir)
        self.predictor, self.config = load_predictor(
            model_dir,
            run_mode=run_mode,
            batch_size=batch_size,
            min_subgraph_size=self.pred_config.min_subgraph_size,
            device=device,
            use_dynamic_shape=self.pred_config.use_dynamic_shape,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
            delete_shuffle_pass=delete_shuffle_pass)
        self.det_times = Timer()
        self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0
        self.batch_size = batch_size
        self.output_dir = output_dir
        self.threshold = threshold

    def set_config(self, model_dir):
        return PredictConfig(model_dir)
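
    # Note: the preprocess pipeline below is built by resolving each op's
    # 'type' from infer_cfg.yml (e.g. Resize, NormalizeImage, Permute) with
    # eval(), so every op named in the config must be importable in this
    # module.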

    def preprocess(self, image_list):
        preprocess_ops = []
        for op_info in self.pred_config.preprocess_infos:
            new_op_info = op_info.copy()
            op_type = new_op_info.pop('type')
            preprocess_ops.append(eval(op_type)(**new_op_info))
        input_im_lst = []
        input_im_info_lst = []
        for im_path in image_list:
            im, im_info = preprocess(im_path, preprocess_ops)
            input_im_lst.append(im)
            input_im_info_lst.append(im_info)
        inputs = create_inputs(input_im_lst, input_im_info_lst)
        input_names = self.predictor.get_input_names()
        for i in range(len(input_names)):
            input_tensor = self.predictor.get_input_handle(input_names[i])
            input_tensor.copy_from_cpu(inputs[input_names[i]])
        return inputs

    def postprocess(self, inputs, result):
        # postprocess output of predictor
        np_boxes_num = result['boxes_num']
        if np_boxes_num[0] <= 0:
            print('[WARNING] No object detected.')
            result = {'boxes': np.zeros([0, 6]), 'boxes_num': [0]}
        result = {k: v for k, v in result.items() if v is not None}
        return result

    def filter_box(self, result, threshold):
        np_boxes_num = result['boxes_num']
        boxes = result['boxes']
        start_idx = 0
        filter_boxes = []
        filter_num = []
        for i in range(len(np_boxes_num)):
            boxes_num = np_boxes_num[i]
            boxes_i = boxes[start_idx:start_idx + boxes_num, :]
            idx = boxes_i[:, 1] > threshold
            filter_boxes_i = boxes_i[idx, :]
            filter_boxes.append(filter_boxes_i)
            filter_num.append(filter_boxes_i.shape[0])
            start_idx += boxes_num
        boxes = np.concatenate(filter_boxes)
        filter_num = np.array(filter_num)
        filter_res = {'boxes': boxes, 'boxes_num': filter_num}
        return filter_res

    def predict(self, repeats=1):
        '''
        Args:
            repeats (int): repeat number for prediction
        Returns:
            result (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
                matrix element:[class, score, x_min, y_min, x_max, y_max]
                MaskRCNN's result include 'masks': np.ndarray:
                shape: [N, im_h, im_w]
        '''
        # model prediction
        np_boxes, np_masks = None, None
        for i in range(repeats):
            self.predictor.run()
            output_names = self.predictor.get_output_names()
            boxes_tensor = self.predictor.get_output_handle(output_names[0])
            np_boxes = boxes_tensor.copy_to_cpu()
            boxes_num = self.predictor.get_output_handle(output_names[1])
            np_boxes_num = boxes_num.copy_to_cpu()
            if self.pred_config.mask:
                masks_tensor = self.predictor.get_output_handle(output_names[2])
                np_masks = masks_tensor.copy_to_cpu()
        result = dict(boxes=np_boxes, masks=np_masks, boxes_num=np_boxes_num)
        return result

    def merge_batch_result(self, batch_result):
        if len(batch_result) == 1:
            return batch_result[0]
        res_key = batch_result[0].keys()
        results = {k: [] for k in res_key}
        for res in batch_result:
            for k, v in res.items():
                results[k].append(v)
        for k, v in results.items():
            if k != 'masks':
                results[k] = np.concatenate(v)
        return results

    def get_timer(self):
        return self.det_times

    def predict_image(self,
                      image_list,
                      run_benchmark=False,
                      repeats=1,
                      visual=True,
                      save_file=None):
        batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size)
        results = []
        for i in range(batch_loop_cnt):
            start_index = i * self.batch_size
            end_index = min((i + 1) * self.batch_size, len(image_list))
            batch_image_list = image_list[start_index:end_index]
            if run_benchmark:
                # preprocess
                inputs = self.preprocess(batch_image_list)  # warmup
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()
                # model prediction
                result = self.predict(repeats=50)  # warmup
                self.det_times.inference_time_s.start()
                result = self.predict(repeats=repeats)
                self.det_times.inference_time_s.end(repeats=repeats)
                # postprocess
                result_warmup = self.postprocess(inputs, result)  # warmup
                self.det_times.postprocess_time_s.start()
                result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()
                self.det_times.img_num += len(batch_image_list)
                cm, gm, gu = get_current_memory_mb()
                self.cpu_mem += cm
                self.gpu_mem += gm
                self.gpu_util += gu
            else:
                # preprocess
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()
                # model prediction
                self.det_times.inference_time_s.start()
                result = self.predict()
                self.det_times.inference_time_s.end()
                # postprocess
                self.det_times.postprocess_time_s.start()
                result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()
                self.det_times.img_num += len(batch_image_list)
                if visual:
                    visualize(
                        batch_image_list,
                        result,
                        self.pred_config.labels,
                        output_dir=self.output_dir,
                        threshold=self.threshold)
            results.append(result)
            if visual:
                print('Test iter {}'.format(i))
        if save_file is not None:
            Path(self.output_dir).mkdir(exist_ok=True)
            self.format_coco_results(image_list, results, save_file=save_file)
        results = self.merge_batch_result(results)
        return results
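
    # Note: predict_video feeds each frame to predict_image as RGB
    # (frame[:, :, ::-1] reverses OpenCV's BGR channel order), while
    # visualize_box_mask draws on the original BGR frame.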

    def predict_video(self, video_file, camera_id):
        video_out_name = 'output.mp4'
        if camera_id != -1:
            capture = cv2.VideoCapture(camera_id)
        else:
            capture = cv2.VideoCapture(video_file)
            video_out_name = os.path.split(video_file)[-1]
        # Get Video info : resolution, fps, frame count
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(capture.get(cv2.CAP_PROP_FPS))
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        print("fps: %d, frame_count: %d" % (fps, frame_count))
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        out_path = os.path.join(self.output_dir, video_out_name)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
        index = 1
        while True:
            ret, frame = capture.read()
            if not ret:
                break
            print('detect frame: %d' % (index))
            index += 1
            results = self.predict_image([frame[:, :, ::-1]], visual=False)
            im = visualize_box_mask(
                frame,
                results,
                self.pred_config.labels,
                threshold=self.threshold)
            im = np.array(im)
            writer.write(im)
            if camera_id != -1:
                cv2.imshow('Mask Detection', im)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        writer.release()

    @staticmethod
    def format_coco_results(image_list, results, save_file=None):
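        # Each emitted record has the form (illustrative values):
        #   {'image_file': 'demo.jpg', 'bbox': [x, y, w, h],
        #    'score': 0.92, 'category_id': 0}
        # or, for instance segmentation, a 'segmentation' RLE field
        # in place of 'bbox'.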
        coco_results = []
        image_id = 0
        for result in results:
            start_idx = 0
            for box_num in result['boxes_num']:
                idx_slice = slice(start_idx, start_idx + box_num)
                start_idx += box_num
                image_file = image_list[image_id]
                image_id += 1
                if 'boxes' in result:
                    boxes = result['boxes'][idx_slice, :]
                    per_result = [
                        {
                            'image_file': image_file,
                            'bbox':
                            [box[2], box[3], box[4] - box[2],
                             box[5] - box[3]],  # xyxy -> xywh
                            'score': box[1],
                            'category_id': int(box[0]),
                        } for k, box in enumerate(boxes.tolist())
                    ]
                elif 'segm' in result:
                    import pycocotools.mask as mask_util
                    scores = result['score'][idx_slice].tolist()
                    category_ids = result['label'][idx_slice].tolist()
                    segms = result['segm'][idx_slice, :]
                    rles = [
                        mask_util.encode(
                            np.array(
                                mask[:, :, np.newaxis],
                                dtype=np.uint8,
                                order='F'))[0] for mask in segms
                    ]
                    for rle in rles:
                        rle['counts'] = rle['counts'].decode('utf-8')
                    per_result = [{
                        'image_file': image_file,
                        'segmentation': rle,
                        'score': scores[k],
                        'category_id': category_ids[k],
                    } for k, rle in enumerate(rles)]
                else:
                    raise RuntimeError(
                        "result must contain either 'boxes' or 'segm'")
                # per_result = [item for item in per_result if item['score'] > threshold]
                coco_results.extend(per_result)
        if save_file:
            with open(save_file, 'w') as f:
                json.dump(coco_results, f)
        return coco_results


class DetectorSOLOv2(Detector):
    """
    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running (paddle/trt_fp32/trt_fp16)
        batch_size (int): batch size for inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantization
            calibration, trt_calib_mode should be set to True
        cpu_threads (int): number of cpu threads
        enable_mkldnn (bool): whether to enable MKLDNN
        enable_mkldnn_bfloat16 (bool): whether to enable MKLDNN bfloat16
        output_dir (str): the path of output
        threshold (float): the score threshold for visualization
    """

    def __init__(
            self,
            model_dir,
            device='CPU',
            run_mode='paddle',
            batch_size=1,
            trt_min_shape=1,
            trt_max_shape=1280,
            trt_opt_shape=640,
            trt_calib_mode=False,
            cpu_threads=1,
            enable_mkldnn=False,
            enable_mkldnn_bfloat16=False,
            output_dir='./',
            threshold=0.5, ):
        super(DetectorSOLOv2, self).__init__(
            model_dir=model_dir,
            device=device,
            run_mode=run_mode,
            batch_size=batch_size,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
            output_dir=output_dir,
            threshold=threshold, )

    def predict(self, repeats=1):
        '''
        Args:
            repeats (int): repeat number for prediction
        Returns:
            result (dict): 'segm': np.ndarray, shape:[N, im_h, im_w]
                'cate_label': label of segm, shape:[N]
                'cate_score': confidence score of segm, shape:[N]
        '''
        np_label, np_score, np_segms = None, None, None
        for i in range(repeats):
            self.predictor.run()
            output_names = self.predictor.get_output_names()
            np_boxes_num = self.predictor.get_output_handle(
                output_names[0]).copy_to_cpu()
            np_label = self.predictor.get_output_handle(
                output_names[1]).copy_to_cpu()
            np_score = self.predictor.get_output_handle(
                output_names[2]).copy_to_cpu()
            np_segms = self.predictor.get_output_handle(
                output_names[3]).copy_to_cpu()
        result = dict(
            segm=np_segms,
            label=np_label,
            score=np_score,
            boxes_num=np_boxes_num)
        return result


class DetectorPicoDet(Detector):
    """
    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running (paddle/trt_fp32/trt_fp16)
        batch_size (int): batch size for inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantization
            calibration, trt_calib_mode should be set to True
        cpu_threads (int): number of cpu threads
        enable_mkldnn (bool): whether to enable MKLDNN
        enable_mkldnn_bfloat16 (bool): whether to enable MKLDNN bfloat16
    """

    def __init__(
            self,
            model_dir,
            device='CPU',
            run_mode='paddle',
            batch_size=1,
            trt_min_shape=1,
            trt_max_shape=1280,
            trt_opt_shape=640,
            trt_calib_mode=False,
            cpu_threads=1,
            enable_mkldnn=False,
            enable_mkldnn_bfloat16=False,
            output_dir='./',
            threshold=0.5, ):
        super(DetectorPicoDet, self).__init__(
            model_dir=model_dir,
            device=device,
            run_mode=run_mode,
            batch_size=batch_size,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
            output_dir=output_dir,
            threshold=threshold, )

    def postprocess(self, inputs, result):
        # postprocess output of predictor
        np_score_list = result['boxes']
        np_boxes_list = result['boxes_num']
        postprocessor = PicoDetPostProcess(
            inputs['image'].shape[2:],
            inputs['im_shape'],
            inputs['scale_factor'],
            strides=self.pred_config.fpn_stride,
            nms_threshold=self.pred_config.nms['nms_threshold'])
        np_boxes, np_boxes_num = postprocessor(np_score_list, np_boxes_list)
        result = dict(boxes=np_boxes, boxes_num=np_boxes_num)
        return result

    def predict(self, repeats=1):
        '''
        Args:
            repeats (int): repeat number for prediction
        Returns:
            result (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
                matrix element:[class, score, x_min, y_min, x_max, y_max]
        '''
        np_score_list, np_boxes_list = [], []
        for i in range(repeats):
            self.predictor.run()
            np_score_list.clear()
            np_boxes_list.clear()
            output_names = self.predictor.get_output_names()
            num_outs = int(len(output_names) / 2)
            for out_idx in range(num_outs):
                np_score_list.append(
                    self.predictor.get_output_handle(output_names[out_idx])
                    .copy_to_cpu())
                np_boxes_list.append(
                    self.predictor.get_output_handle(
                        output_names[out_idx + num_outs]).copy_to_cpu())
        result = dict(boxes=np_score_list, boxes_num=np_boxes_list)
        return result


def create_inputs(imgs, im_info):
    """generate input for different model type
    Args:
        imgs (list(numpy)): list of images (np.ndarray)
        im_info (list(dict)): list of image info
    Returns:
        inputs (dict): input of model
    """
    inputs = {}
    im_shape = []
    scale_factor = []
    if len(imgs) == 1:
        inputs['image'] = np.array((imgs[0], )).astype('float32')
        inputs['im_shape'] = np.array(
            (im_info[0]['im_shape'], )).astype('float32')
        inputs['scale_factor'] = np.array(
            (im_info[0]['scale_factor'], )).astype('float32')
        return inputs
    for e in im_info:
        im_shape.append(np.array((e['im_shape'], )).astype('float32'))
        scale_factor.append(np.array((e['scale_factor'], )).astype('float32'))
    inputs['im_shape'] = np.concatenate(im_shape, axis=0)
    inputs['scale_factor'] = np.concatenate(scale_factor, axis=0)
    imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs]
    max_shape_h = max([e[0] for e in imgs_shape])
    max_shape_w = max([e[1] for e in imgs_shape])
    padding_imgs = []
    for img in imgs:
        im_c, im_h, im_w = img.shape[:]
        padding_im = np.zeros(
            (im_c, max_shape_h, max_shape_w), dtype=np.float32)
        padding_im[:, :im_h, :im_w] = img
        padding_imgs.append(padding_im)
    inputs['image'] = np.stack(padding_imgs, axis=0)
    return inputs


class PredictConfig():
    """set config of preprocess, postprocess and visualize
    Args:
        model_dir (str): root path of infer_cfg.yml
    """

    def __init__(self, model_dir):
        # parsing Yaml config for Preprocess
        deploy_file = os.path.join(model_dir, 'infer_cfg.yml')
        with open(deploy_file) as f:
            yml_conf = yaml.safe_load(f)
        self.check_model(yml_conf)
        self.arch = yml_conf['arch']
        self.preprocess_infos = yml_conf['Preprocess']
        self.min_subgraph_size = yml_conf['min_subgraph_size']
        self.labels = yml_conf['label_list']
        self.mask = False
        self.use_dynamic_shape = yml_conf['use_dynamic_shape']
        if 'mask' in yml_conf:
            self.mask = yml_conf['mask']
        self.tracker = None
        if 'tracker' in yml_conf:
            self.tracker = yml_conf['tracker']
        if 'NMS' in yml_conf:
            self.nms = yml_conf['NMS']
        if 'fpn_stride' in yml_conf:
            self.fpn_stride = yml_conf['fpn_stride']
        if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
            print(
                'The RCNN export model is used for ONNX and it only supports batch_size = 1'
            )
        self.print_config()

    def check_model(self, yml_conf):
        """
        Raises:
            ValueError: loaded model not in supported model type
        """
        for support_model in SUPPORT_MODELS:
            if support_model in yml_conf['arch']:
                return True
        raise ValueError("Unsupported arch: {}, expect {}".format(
            yml_conf['arch'], SUPPORT_MODELS))

    def print_config(self):
        print('----------- Model Configuration -----------')
        print('%s: %s' % ('Model Arch', self.arch))
        print('%s: ' % ('Transform Order'))
        for op_info in self.preprocess_infos:
            print('--%s: %s' % ('transform op', op_info['type']))
        print('--------------------------------------------')


def load_predictor(model_dir,
                   run_mode='paddle',
                   batch_size=1,
                   device='CPU',
                   min_subgraph_size=3,
                   use_dynamic_shape=False,
                   trt_min_shape=1,
                   trt_max_shape=1280,
                   trt_opt_shape=640,
                   trt_calib_mode=False,
                   cpu_threads=1,
                   enable_mkldnn=False,
                   enable_mkldnn_bfloat16=False,
                   delete_shuffle_pass=False):
    """set AnalysisConfig, generate AnalysisPredictor
    Args:
        model_dir (str): root path of model.pdmodel and model.pdiparams
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running (paddle/trt_fp32/trt_fp16/trt_int8)
        use_dynamic_shape (bool): use dynamic shape or not
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantization
            calibration, trt_calib_mode should be set to True
        delete_shuffle_pass (bool): whether to remove shuffle_channel_detect_pass in TensorRT.
            Used by action model.
    Returns:
        predictor (PaddlePredictor): AnalysisPredictor
    Raises:
        ValueError: predict by TensorRT needs device == 'GPU'.
    """
    if device != 'GPU' and run_mode != 'paddle':
        raise ValueError(
            "Predict by TensorRT mode: {}, expect device == 'GPU', but device == {}"
            .format(run_mode, device))
    config = Config(
        os.path.join(model_dir, 'model.pdmodel'),
        os.path.join(model_dir, 'model.pdiparams'))
    if device == 'GPU':
        # initial GPU memory(M), device ID
        config.enable_use_gpu(200, 0)
        # optimize graph and fuse op
        config.switch_ir_optim(True)
    elif device == 'XPU':
        config.enable_lite_engine()
        config.enable_xpu(10 * 1024 * 1024)
    else:
        config.disable_gpu()
        config.set_cpu_math_library_num_threads(cpu_threads)
        if enable_mkldnn:
            try:
                # cache 10 different shapes for mkldnn to avoid memory leak
                config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()
                if enable_mkldnn_bfloat16:
                    config.enable_mkldnn_bfloat16()
            except Exception:
                print(
                    "The current environment does not support `mkldnn`, so MKLDNN is disabled."
                )

    precision_map = {
        'trt_int8': Config.Precision.Int8,
        'trt_fp32': Config.Precision.Float32,
        'trt_fp16': Config.Precision.Half
    }
    if run_mode in precision_map.keys():
        config.enable_tensorrt_engine(
            workspace_size=(1 << 25) * batch_size,
            max_batch_size=batch_size,
            min_subgraph_size=min_subgraph_size,
            precision_mode=precision_map[run_mode],
            use_static=False,
            use_calib_mode=trt_calib_mode)
        if use_dynamic_shape:
            min_input_shape = {
                'image': [batch_size, 3, trt_min_shape, trt_min_shape]
            }
            max_input_shape = {
                'image': [batch_size, 3, trt_max_shape, trt_max_shape]
            }
            opt_input_shape = {
                'image': [batch_size, 3, trt_opt_shape, trt_opt_shape]
            }
            config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
                                              opt_input_shape)
            print('trt set dynamic shape done!')

    # disable print log when predict
    config.disable_glog_info()
    # enable shared memory
    config.enable_memory_optim()
    # disable feed, fetch OP, needed by zero_copy_run
    config.switch_use_feed_fetch_ops(False)
    if delete_shuffle_pass:
        config.delete_pass("shuffle_channel_detect_pass")
    predictor = create_predictor(config)
    return predictor, config


def get_test_images(infer_dir, infer_img):
    """
    Get image path list in TEST mode
    """
    assert infer_img is not None or infer_dir is not None, \
        "--image_file or --image_dir should be set"
    assert infer_img is None or os.path.isfile(infer_img), \
        "{} is not a file".format(infer_img)
    assert infer_dir is None or os.path.isdir(infer_dir), \
        "{} is not a directory".format(infer_dir)

    # infer_img has a higher priority
    if infer_img and os.path.isfile(infer_img):
        return [infer_img]

    images = set()
    infer_dir = os.path.abspath(infer_dir)
    assert os.path.isdir(infer_dir), \
        "infer_dir {} is not a directory".format(infer_dir)
    exts = ['jpg', 'jpeg', 'png', 'bmp']
    exts += [ext.upper() for ext in exts]
    for ext in exts:
        images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
    images = list(images)

    assert len(images) > 0, "no image found in {}".format(infer_dir)
    print("Found {} inference images in total.".format(len(images)))
    return images


def visualize(image_list, result, labels, output_dir='output/', threshold=0.5):
    # visualize the predict result
    start_idx = 0
    for idx, image_file in enumerate(image_list):
        im_bboxes_num = result['boxes_num'][idx]
        im_results = {}
        if 'boxes' in result:
            im_results['boxes'] = result['boxes'][start_idx:start_idx +
                                                  im_bboxes_num, :]
        if 'masks' in result:
            im_results['masks'] = result['masks'][start_idx:start_idx +
                                                  im_bboxes_num, :]
        if 'segm' in result:
            im_results['segm'] = result['segm'][start_idx:start_idx +
                                                im_bboxes_num, :]
        if 'label' in result:
            im_results['label'] = result['label'][start_idx:start_idx +
                                                  im_bboxes_num]
        if 'score' in result:
            im_results['score'] = result['score'][start_idx:start_idx +
                                                  im_bboxes_num]
        start_idx += im_bboxes_num
        im = visualize_box_mask(
            image_file, im_results, labels, threshold=threshold)
        img_name = os.path.split(image_file)[-1]
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        out_path = os.path.join(output_dir, img_name)
        im.save(out_path, quality=95)
        print("save result to: " + out_path)


def print_arguments(args):
    print('----------- Running Arguments -----------')
    for arg, value in sorted(vars(args).items()):
        print('%s: %s' % (arg, value))
    print('------------------------------------------')


def main():
    deploy_file = os.path.join(FLAGS.model_dir, 'infer_cfg.yml')
    with open(deploy_file) as f:
        yml_conf = yaml.safe_load(f)
    arch = yml_conf['arch']
    detector_func = 'Detector'
    if arch == 'SOLOv2':
        detector_func = 'DetectorSOLOv2'
    elif arch == 'PicoDet':
        detector_func = 'DetectorPicoDet'
    detector = eval(detector_func)(
        FLAGS.model_dir,
        device=FLAGS.device,
        run_mode=FLAGS.run_mode,
        batch_size=FLAGS.batch_size,
        trt_min_shape=FLAGS.trt_min_shape,
        trt_max_shape=FLAGS.trt_max_shape,
        trt_opt_shape=FLAGS.trt_opt_shape,
        trt_calib_mode=FLAGS.trt_calib_mode,
        cpu_threads=FLAGS.cpu_threads,
        enable_mkldnn=FLAGS.enable_mkldnn,
        enable_mkldnn_bfloat16=FLAGS.enable_mkldnn_bfloat16,
        threshold=FLAGS.threshold,
        output_dir=FLAGS.output_dir)

    # predict from video file or camera video stream
    if FLAGS.video_file is not None or FLAGS.camera_id != -1:
        detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
    else:
        # predict from image
        if FLAGS.image_dir is None and FLAGS.image_file is not None:
            assert FLAGS.batch_size == 1, "batch_size should be 1 when image_file is not None"
        img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
        save_file = os.path.join(FLAGS.output_dir,
                                 'results.json') if FLAGS.save_results else None
        detector.predict_image(
            img_list, FLAGS.run_benchmark, repeats=100, save_file=save_file)
        if not FLAGS.run_benchmark:
            detector.det_times.info(average=True)
        else:
            mode = FLAGS.run_mode
            model_dir = FLAGS.model_dir
            model_info = {
                'model_name': model_dir.strip('/').split('/')[-1],
                'precision': mode.split('_')[-1]
            }
            bench_log(detector, img_list, model_info, name='DET')


if __name__ == '__main__':
    paddle.enable_static()
    parser = argsparser()
    FLAGS = parser.parse_args()
    print_arguments(FLAGS)
    FLAGS.device = FLAGS.device.upper()
    assert FLAGS.device in ['CPU', 'GPU', 'XPU'], \
        "device should be CPU, GPU or XPU"
    assert not FLAGS.use_gpu, "use_gpu has been deprecated, please use --device"
    assert not (not FLAGS.enable_mkldnn and FLAGS.enable_mkldnn_bfloat16), \
        'To enable mkldnn bfloat16, please turn on both enable_mkldnn and enable_mkldnn_bfloat16'

    main()
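
# Typical invocations (a sketch; paths are illustrative and the full flag set
# comes from dependence.PaddleDetection.deploy.python.utils.argsparser):
#   python infer.py --model_dir=output_inference/yolov3_darknet53_270e_coco \
#       --image_file=demo.jpg --device=GPU
#   python infer.py --model_dir=output_inference/picodet_s_320_coco \
#       --video_file=test.mp4 --device=GPU --run_mode=trt_fp16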