data_feed.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import base64
  16. import cv2
  17. import numpy as np
  18. from PIL import Image, ImageDraw
  19. import paddle.fluid as fluid
  20. def create_inputs(im, im_info):
  21. """generate input for different model type
  22. Args:
  23. im (np.ndarray): image (np.ndarray)
  24. im_info (dict): info of image
  25. Returns:
  26. inputs (dict): input of model
  27. """
  28. inputs = {}
  29. inputs['image'] = im
  30. origin_shape = list(im_info['origin_shape'])
  31. resize_shape = list(im_info['resize_shape'])
  32. pad_shape = list(im_info['pad_shape']) if im_info[
  33. 'pad_shape'] is not None else list(im_info['resize_shape'])
  34. scale_x, scale_y = im_info['scale']
  35. scale = scale_x
  36. im_info = np.array([resize_shape + [scale]]).astype('float32')
  37. inputs['im_info'] = im_info
  38. return inputs
  39. def visualize_box_mask(im,
  40. results,
  41. labels=None,
  42. mask_resolution=14,
  43. threshold=0.5):
  44. """
  45. Args:
  46. im (str/np.ndarray): path of image/np.ndarray read by cv2
  47. results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
  48. matix element:[class, score, x_min, y_min, x_max, y_max]
  49. MaskRCNN's results include 'masks': np.ndarray:
  50. shape:[N, class_num, mask_resolution, mask_resolution]
  51. labels (list): labels:['class1', ..., 'classn']
  52. mask_resolution (int): shape of a mask is:[mask_resolution, mask_resolution]
  53. threshold (float): Threshold of score.
  54. Returns:
  55. im (PIL.Image.Image): visualized image
  56. """
  57. if not labels:
  58. labels = [
  59. 'background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane',
  60. 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire', 'hydrant',
  61. 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
  62. 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
  63. 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
  64. 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
  65. 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
  66. 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
  67. 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
  68. 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
  69. 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
  70. 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
  71. 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
  72. 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
  73. ]
  74. if isinstance(im, str):
  75. im = Image.open(im).convert('RGB')
  76. else:
  77. im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
  78. im = Image.fromarray(im)
  79. if 'masks' in results and 'boxes' in results:
  80. im = draw_mask(
  81. im,
  82. results['boxes'],
  83. results['masks'],
  84. labels,
  85. resolution=mask_resolution)
  86. if 'boxes' in results:
  87. im = draw_box(im, results['boxes'], labels)
  88. if 'segm' in results:
  89. im = draw_segm(
  90. im,
  91. results['segm'],
  92. results['label'],
  93. results['score'],
  94. labels,
  95. threshold=threshold)
  96. return im
  97. def get_color_map_list(num_classes):
  98. """
  99. Args:
  100. num_classes (int): number of class
  101. Returns:
  102. color_map (list): RGB color list
  103. """
  104. color_map = num_classes * [0, 0, 0]
  105. for i in range(0, num_classes):
  106. j = 0
  107. lab = i
  108. while lab:
  109. color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
  110. color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
  111. color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
  112. j += 1
  113. lab >>= 3
  114. color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
  115. return color_map
  116. def expand_boxes(boxes, scale=0.0):
  117. """
  118. Args:
  119. boxes (np.ndarray): shape:[N,4], N:number of box,
  120. matix element:[x_min, y_min, x_max, y_max]
  121. scale (float): scale of boxes
  122. Returns:
  123. boxes_exp (np.ndarray): expanded boxes
  124. """
  125. w_half = (boxes[:, 2] - boxes[:, 0]) * .5
  126. h_half = (boxes[:, 3] - boxes[:, 1]) * .5
  127. x_c = (boxes[:, 2] + boxes[:, 0]) * .5
  128. y_c = (boxes[:, 3] + boxes[:, 1]) * .5
  129. w_half *= scale
  130. h_half *= scale
  131. boxes_exp = np.zeros(boxes.shape)
  132. boxes_exp[:, 0] = x_c - w_half
  133. boxes_exp[:, 2] = x_c + w_half
  134. boxes_exp[:, 1] = y_c - h_half
  135. boxes_exp[:, 3] = y_c + h_half
  136. return boxes_exp
  137. def draw_mask(im, np_boxes, np_masks, labels, resolution=14, threshold=0.5):
  138. """
  139. Args:
  140. im (PIL.Image.Image): PIL image
  141. np_boxes (np.ndarray): shape:[N,6], N: number of box,
  142. matix element:[class, score, x_min, y_min, x_max, y_max]
  143. np_masks (np.ndarray): shape:[N, class_num, resolution, resolution]
  144. labels (list): labels:['class1', ..., 'classn']
  145. resolution (int): shape of a mask is:[resolution, resolution]
  146. threshold (float): threshold of mask
  147. Returns:
  148. im (PIL.Image.Image): visualized image
  149. """
  150. color_list = get_color_map_list(len(labels))
  151. scale = (resolution + 2.0) / resolution
  152. im_w, im_h = im.size
  153. w_ratio = 0.4
  154. alpha = 0.7
  155. im = np.array(im).astype('float32')
  156. rects = np_boxes[:, 2:]
  157. expand_rects = expand_boxes(rects, scale)
  158. expand_rects = expand_rects.astype(np.int32)
  159. clsid_scores = np_boxes[:, 0:2]
  160. padded_mask = np.zeros((resolution + 2, resolution + 2), dtype=np.float32)
  161. clsid2color = {}
  162. for idx in range(len(np_boxes)):
  163. clsid, score = clsid_scores[idx].tolist()
  164. clsid = int(clsid)
  165. xmin, ymin, xmax, ymax = expand_rects[idx].tolist()
  166. w = xmax - xmin + 1
  167. h = ymax - ymin + 1
  168. w = np.maximum(w, 1)
  169. h = np.maximum(h, 1)
  170. padded_mask[1:-1, 1:-1] = np_masks[idx, int(clsid), :, :]
  171. resized_mask = cv2.resize(padded_mask, (w, h))
  172. resized_mask = np.array(resized_mask > threshold, dtype=np.uint8)
  173. x0 = min(max(xmin, 0), im_w)
  174. x1 = min(max(xmax + 1, 0), im_w)
  175. y0 = min(max(ymin, 0), im_h)
  176. y1 = min(max(ymax + 1, 0), im_h)
  177. im_mask = np.zeros((im_h, im_w), dtype=np.uint8)
  178. im_mask[y0:y1, x0:x1] = resized_mask[(y0 - ymin):(y1 - ymin), (
  179. x0 - xmin):(x1 - xmin)]
  180. if clsid not in clsid2color:
  181. clsid2color[clsid] = color_list[clsid]
  182. color_mask = clsid2color[clsid]
  183. for c in range(3):
  184. color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
  185. idx = np.nonzero(im_mask)
  186. color_mask = np.array(color_mask)
  187. im[idx[0], idx[1], :] *= 1.0 - alpha
  188. im[idx[0], idx[1], :] += alpha * color_mask
  189. return Image.fromarray(im.astype('uint8'))
  190. def draw_box(im, np_boxes, labels):
  191. """
  192. Args:
  193. im (PIL.Image.Image): PIL image
  194. np_boxes (np.ndarray): shape:[N,6], N: number of box,
  195. matix element:[class, score, x_min, y_min, x_max, y_max]
  196. labels (list): labels:['class1', ..., 'classn']
  197. Returns:
  198. im (PIL.Image.Image): visualized image
  199. """
  200. draw_thickness = min(im.size) // 320
  201. draw = ImageDraw.Draw(im)
  202. clsid2color = {}
  203. color_list = get_color_map_list(len(labels))
  204. for dt in np_boxes:
  205. clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
  206. xmin, ymin, xmax, ymax = bbox
  207. w = xmax - xmin
  208. h = ymax - ymin
  209. if clsid not in clsid2color:
  210. clsid2color[clsid] = color_list[clsid]
  211. color = tuple(clsid2color[clsid])
  212. # draw bbox
  213. draw.line(
  214. [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
  215. (xmin, ymin)],
  216. width=draw_thickness,
  217. fill=color)
  218. # draw label
  219. text = "{} {:.4f}".format(labels[clsid], score)
  220. tw, th = draw.textsize(text)
  221. draw.rectangle(
  222. [(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
  223. draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
  224. return im
  225. def draw_segm(im,
  226. np_segms,
  227. np_label,
  228. np_score,
  229. labels,
  230. threshold=0.5,
  231. alpha=0.7):
  232. """
  233. Draw segmentation on image
  234. """
  235. mask_color_id = 0
  236. w_ratio = .4
  237. color_list = get_color_map_list(len(labels))
  238. im = np.array(im).astype('float32')
  239. clsid2color = {}
  240. np_segms = np_segms.astype(np.uint8)
  241. index = np.where(np_label == 0)[0]
  242. index = np.where(np_score[index] > threshold)[0]
  243. person_segms = np_segms[index]
  244. person_mask = np.sum(person_segms, axis=0)
  245. person_mask[person_mask > 1] = 1
  246. person_mask = np.expand_dims(person_mask, axis=2)
  247. person_mask = np.repeat(person_mask, 3, axis=2)
  248. im = im * person_mask
  249. return Image.fromarray(im.astype('uint8'))
  250. def load_predictor(model_dir,
  251. run_mode='fluid',
  252. batch_size=1,
  253. use_gpu=False,
  254. min_subgraph_size=3):
  255. """set AnalysisConfig, generate AnalysisPredictor
  256. Args:
  257. model_dir (str): root path of __model__ and __params__
  258. use_gpu (bool): whether use gpu
  259. Returns:
  260. predictor (PaddlePredictor): AnalysisPredictor
  261. Raises:
  262. ValueError: predict by TensorRT need use_gpu == True.
  263. """
  264. if not use_gpu and not run_mode == 'fluid':
  265. raise ValueError(
  266. "Predict by TensorRT mode: {}, expect use_gpu==True, but use_gpu == {}"
  267. .format(run_mode, use_gpu))
  268. if run_mode == 'trt_int8':
  269. raise ValueError("TensorRT int8 mode is not supported now, "
  270. "please use trt_fp32 or trt_fp16 instead.")
  271. precision_map = {
  272. 'trt_int8': fluid.core.AnalysisConfig.Precision.Int8,
  273. 'trt_fp32': fluid.core.AnalysisConfig.Precision.Float32,
  274. 'trt_fp16': fluid.core.AnalysisConfig.Precision.Half
  275. }
  276. config = fluid.core.AnalysisConfig(
  277. os.path.join(model_dir, '__model__'),
  278. os.path.join(model_dir, '__params__'))
  279. if use_gpu:
  280. # initial GPU memory(M), device ID
  281. config.enable_use_gpu(100, 0)
  282. # optimize graph and fuse op
  283. config.switch_ir_optim(True)
  284. else:
  285. config.disable_gpu()
  286. if run_mode in precision_map.keys():
  287. config.enable_tensorrt_engine(
  288. workspace_size=1 << 10,
  289. max_batch_size=batch_size,
  290. min_subgraph_size=min_subgraph_size,
  291. precision_mode=precision_map[run_mode],
  292. use_static=False,
  293. use_calib_mode=False)
  294. # disable print log when predict
  295. config.disable_glog_info()
  296. # enable shared memory
  297. config.enable_memory_optim()
  298. # disable feed, fetch OP, needed by zero_copy_run
  299. config.switch_use_feed_fetch_ops(False)
  300. predictor = fluid.core.create_paddle_predictor(config)
  301. return predictor
  302. def cv2_to_base64(image):
  303. data = cv2.imencode('.jpg', image)[1]
  304. return base64.b64encode(data.tostring()).decode('utf8')
  305. def base64_to_cv2(b64str):
  306. data = base64.b64decode(b64str.encode('utf8'))
  307. data = np.fromstring(data, np.uint8)
  308. data = cv2.imdecode(data, cv2.IMREAD_COLOR)
  309. return data