attr_infer.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import yaml
  16. import glob
  17. from functools import reduce
  18. import cv2
  19. import numpy as np
  20. import math
  21. import paddle
  22. from paddle.inference import Config
  23. from paddle.inference import create_predictor
  24. import sys
  25. # add deploy path of PaddleDetection to sys.path
  26. parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
  27. sys.path.insert(0, parent_path)
  28. from benchmark_utils import PaddleInferBenchmark
  29. from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine
  30. from visualize import visualize_attr
  31. from utils import argsparser, Timer, get_current_memory_mb
  32. from infer import Detector, get_test_images, print_arguments, load_predictor
  33. from PIL import Image, ImageDraw, ImageFont
  34. class AttrDetector(Detector):
  35. """
  36. Args:
  37. model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
  38. device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
  39. run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
  40. batch_size (int): size of pre batch in inference
  41. trt_min_shape (int): min shape for dynamic shape in trt
  42. trt_max_shape (int): max shape for dynamic shape in trt
  43. trt_opt_shape (int): opt shape for dynamic shape in trt
  44. trt_calib_mode (bool): If the model is produced by TRT offline quantitative
  45. calibration, trt_calib_mode need to set True
  46. cpu_threads (int): cpu threads
  47. enable_mkldnn (bool): whether to open MKLDNN
  48. output_dir (str): The path of output
  49. threshold (float): The threshold of score for visualization
  50. """
  51. def __init__(
  52. self,
  53. model_dir,
  54. device='CPU',
  55. run_mode='paddle',
  56. batch_size=1,
  57. trt_min_shape=1,
  58. trt_max_shape=1280,
  59. trt_opt_shape=640,
  60. trt_calib_mode=False,
  61. cpu_threads=1,
  62. enable_mkldnn=False,
  63. output_dir='output',
  64. threshold=0.5, ):
  65. super(AttrDetector, self).__init__(
  66. model_dir=model_dir,
  67. device=device,
  68. run_mode=run_mode,
  69. batch_size=batch_size,
  70. trt_min_shape=trt_min_shape,
  71. trt_max_shape=trt_max_shape,
  72. trt_opt_shape=trt_opt_shape,
  73. trt_calib_mode=trt_calib_mode,
  74. cpu_threads=cpu_threads,
  75. enable_mkldnn=enable_mkldnn,
  76. output_dir=output_dir,
  77. threshold=threshold, )
  78. def get_label(self):
  79. return self.pred_config.labels
  80. def postprocess(self, inputs, result):
  81. # postprocess output of predictor
  82. im_results = result['output']
  83. labels = self.pred_config.labels
  84. age_list = ['AgeLess18', 'Age18-60', 'AgeOver60']
  85. direct_list = ['Front', 'Side', 'Back']
  86. bag_list = ['HandBag', 'ShoulderBag', 'Backpack']
  87. upper_list = ['UpperStride', 'UpperLogo', 'UpperPlaid', 'UpperSplice']
  88. lower_list = [
  89. 'LowerStripe', 'LowerPattern', 'LongCoat', 'Trousers', 'Shorts',
  90. 'Skirt&Dress'
  91. ]
  92. glasses_threshold = 0.3
  93. hold_threshold = 0.6
  94. batch_res = []
  95. for res in im_results:
  96. res = res.tolist()
  97. label_res = []
  98. # gender
  99. gender = 'Female' if res[22] > self.threshold else 'Male'
  100. label_res.append(gender)
  101. # age
  102. age = age_list[np.argmax(res[19:22])]
  103. label_res.append(age)
  104. # direction
  105. direction = direct_list[np.argmax(res[23:])]
  106. label_res.append(direction)
  107. # glasses
  108. glasses = 'Glasses: '
  109. if res[1] > glasses_threshold:
  110. glasses += 'True'
  111. else:
  112. glasses += 'False'
  113. label_res.append(glasses)
  114. # hat
  115. hat = 'Hat: '
  116. if res[0] > self.threshold:
  117. hat += 'True'
  118. else:
  119. hat += 'False'
  120. label_res.append(hat)
  121. # hold obj
  122. hold_obj = 'HoldObjectsInFront: '
  123. if res[18] > hold_threshold:
  124. hold_obj += 'True'
  125. else:
  126. hold_obj += 'False'
  127. label_res.append(hold_obj)
  128. # bag
  129. bag = bag_list[np.argmax(res[15:18])]
  130. bag_score = res[15 + np.argmax(res[15:18])]
  131. bag_label = bag if bag_score > self.threshold else 'No bag'
  132. label_res.append(bag_label)
  133. # upper
  134. upper_res = res[4:8]
  135. upper_label = 'Upper:'
  136. sleeve = 'LongSleeve' if res[3] > res[2] else 'ShortSleeve'
  137. upper_label += ' {}'.format(sleeve)
  138. for i, r in enumerate(upper_res):
  139. if r > self.threshold:
  140. upper_label += ' {}'.format(upper_list[i])
  141. label_res.append(upper_label)
  142. # lower
  143. lower_res = res[8:14]
  144. lower_label = 'Lower: '
  145. has_lower = False
  146. for i, l in enumerate(lower_res):
  147. if l > self.threshold:
  148. lower_label += ' {}'.format(lower_list[i])
  149. has_lower = True
  150. if not has_lower:
  151. lower_label += ' {}'.format(lower_list[np.argmax(lower_res)])
  152. label_res.append(lower_label)
  153. # shoe
  154. shoe = 'Boots' if res[14] > self.threshold else 'No boots'
  155. label_res.append(shoe)
  156. batch_res.append(label_res)
  157. result = {'output': batch_res}
  158. return result
  159. def predict(self, repeats=1):
  160. '''
  161. Args:
  162. repeats (int): repeats number for prediction
  163. Returns:
  164. result (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
  165. matix element:[class, score, x_min, y_min, x_max, y_max]
  166. MaskRCNN's result include 'masks': np.ndarray:
  167. shape: [N, im_h, im_w]
  168. '''
  169. # model prediction
  170. for i in range(repeats):
  171. self.predictor.run()
  172. output_names = self.predictor.get_output_names()
  173. output_tensor = self.predictor.get_output_handle(output_names[0])
  174. np_output = output_tensor.copy_to_cpu()
  175. result = dict(output=np_output)
  176. return result
  177. def predict_image(self,
  178. image_list,
  179. run_benchmark=False,
  180. repeats=1,
  181. visual=True):
  182. batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size)
  183. results = []
  184. for i in range(batch_loop_cnt):
  185. start_index = i * self.batch_size
  186. end_index = min((i + 1) * self.batch_size, len(image_list))
  187. batch_image_list = image_list[start_index:end_index]
  188. if run_benchmark:
  189. # preprocess
  190. inputs = self.preprocess(batch_image_list) # warmup
  191. self.det_times.preprocess_time_s.start()
  192. inputs = self.preprocess(batch_image_list)
  193. self.det_times.preprocess_time_s.end()
  194. # model prediction
  195. result = self.predict(repeats=repeats) # warmup
  196. self.det_times.inference_time_s.start()
  197. result = self.predict(repeats=repeats)
  198. self.det_times.inference_time_s.end(repeats=repeats)
  199. # postprocess
  200. result_warmup = self.postprocess(inputs, result) # warmup
  201. self.det_times.postprocess_time_s.start()
  202. result = self.postprocess(inputs, result)
  203. self.det_times.postprocess_time_s.end()
  204. self.det_times.img_num += len(batch_image_list)
  205. cm, gm, gu = get_current_memory_mb()
  206. self.cpu_mem += cm
  207. self.gpu_mem += gm
  208. self.gpu_util += gu
  209. else:
  210. # preprocess
  211. self.det_times.preprocess_time_s.start()
  212. inputs = self.preprocess(batch_image_list)
  213. self.det_times.preprocess_time_s.end()
  214. # model prediction
  215. self.det_times.inference_time_s.start()
  216. result = self.predict()
  217. self.det_times.inference_time_s.end()
  218. # postprocess
  219. self.det_times.postprocess_time_s.start()
  220. result = self.postprocess(inputs, result)
  221. self.det_times.postprocess_time_s.end()
  222. self.det_times.img_num += len(batch_image_list)
  223. if visual:
  224. visualize(
  225. batch_image_list, result, output_dir=self.output_dir)
  226. results.append(result)
  227. if visual:
  228. print('Test iter {}'.format(i))
  229. results = self.merge_batch_result(results)
  230. return results
  231. def merge_batch_result(self, batch_result):
  232. if len(batch_result) == 1:
  233. return batch_result[0]
  234. res_key = batch_result[0].keys()
  235. results = {k: [] for k in res_key}
  236. for res in batch_result:
  237. for k, v in res.items():
  238. results[k].extend(v)
  239. return results
  240. def visualize(image_list, batch_res, output_dir='output'):
  241. # visualize the predict result
  242. batch_res = batch_res['output']
  243. for image_file, res in zip(image_list, batch_res):
  244. im = visualize_attr(image_file, [res])
  245. if not os.path.exists(output_dir):
  246. os.makedirs(output_dir)
  247. img_name = os.path.split(image_file)[-1]
  248. out_path = os.path.join(output_dir, img_name)
  249. cv2.imwrite(out_path, im)
  250. print("save result to: " + out_path)
  251. def main():
  252. detector = AttrDetector(
  253. FLAGS.model_dir,
  254. device=FLAGS.device,
  255. run_mode=FLAGS.run_mode,
  256. batch_size=FLAGS.batch_size,
  257. trt_min_shape=FLAGS.trt_min_shape,
  258. trt_max_shape=FLAGS.trt_max_shape,
  259. trt_opt_shape=FLAGS.trt_opt_shape,
  260. trt_calib_mode=FLAGS.trt_calib_mode,
  261. cpu_threads=FLAGS.cpu_threads,
  262. enable_mkldnn=FLAGS.enable_mkldnn,
  263. threshold=FLAGS.threshold,
  264. output_dir=FLAGS.output_dir)
  265. # predict from image
  266. if FLAGS.image_dir is None and FLAGS.image_file is not None:
  267. assert FLAGS.batch_size == 1, "batch_size should be 1, when image_file is not None"
  268. img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
  269. detector.predict_image(img_list, FLAGS.run_benchmark, repeats=10)
  270. if not FLAGS.run_benchmark:
  271. detector.det_times.info(average=True)
  272. else:
  273. mems = {
  274. 'cpu_rss_mb': detector.cpu_mem / len(img_list),
  275. 'gpu_rss_mb': detector.gpu_mem / len(img_list),
  276. 'gpu_util': detector.gpu_util * 100 / len(img_list)
  277. }
  278. perf_info = detector.det_times.report(average=True)
  279. model_dir = FLAGS.model_dir
  280. mode = FLAGS.run_mode
  281. model_info = {
  282. 'model_name': model_dir.strip('/').split('/')[-1],
  283. 'precision': mode.split('_')[-1]
  284. }
  285. data_info = {
  286. 'batch_size': FLAGS.batch_size,
  287. 'shape': "dynamic_shape",
  288. 'data_num': perf_info['img_num']
  289. }
  290. det_log = PaddleInferBenchmark(detector.config, model_info, data_info,
  291. perf_info, mems)
  292. det_log('Attr')
  293. if __name__ == '__main__':
  294. paddle.enable_static()
  295. parser = argsparser()
  296. FLAGS = parser.parse_args()
  297. print_arguments(FLAGS)
  298. FLAGS.device = FLAGS.device.upper()
  299. assert FLAGS.device in ['CPU', 'GPU', 'XPU'
  300. ], "device should be CPU, GPU or XPU"
  301. assert not FLAGS.use_gpu, "use_gpu has been deprecated, please use --device"
  302. main()