visualizer.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from __future__ import unicode_literals
  18. import numpy as np
  19. from PIL import Image, ImageDraw
  20. from scipy import ndimage
  21. import cv2
  22. from .colormap import colormap
# Names exported by ``from <module> import *``: only the top-level entry
# point is listed; the draw_* helpers defined below are not.
__all__ = ['visualize_results']
  24. def visualize_results(image,
  25. im_id,
  26. catid2name,
  27. threshold=0.5,
  28. bbox_results=None,
  29. mask_results=None,
  30. segm_results=None,
  31. lmk_results=None):
  32. """
  33. Visualize bbox and mask results
  34. """
  35. if mask_results:
  36. image = draw_mask(image, im_id, mask_results, threshold)
  37. if bbox_results:
  38. image = draw_bbox(image, im_id, catid2name, bbox_results, threshold)
  39. if lmk_results:
  40. image = draw_lmk(image, im_id, lmk_results, threshold)
  41. if segm_results:
  42. image = draw_segm(image, im_id, catid2name, segm_results, threshold)
  43. return image
  44. def draw_mask(image, im_id, segms, threshold, alpha=0.7):
  45. """
  46. Draw mask on image
  47. """
  48. mask_color_id = 0
  49. w_ratio = .4
  50. color_list = colormap(rgb=True)
  51. img_array = np.array(image).astype('float32')
  52. for dt in np.array(segms):
  53. if im_id != dt['image_id']:
  54. continue
  55. segm, score = dt['segmentation'], dt['score']
  56. if score < threshold:
  57. continue
  58. import pycocotools.mask as mask_util
  59. mask = mask_util.decode(segm) * 255
  60. color_mask = color_list[mask_color_id % len(color_list), 0:3]
  61. mask_color_id += 1
  62. for c in range(3):
  63. color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
  64. idx = np.nonzero(mask)
  65. img_array[idx[0], idx[1], :] *= 1.0 - alpha
  66. img_array[idx[0], idx[1], :] += alpha * color_mask
  67. return Image.fromarray(img_array.astype('uint8'))
  68. def draw_segm(image,
  69. im_id,
  70. catid2name,
  71. segms,
  72. threshold,
  73. alpha=0.7,
  74. draw_box=True):
  75. """
  76. Draw segmentation on image
  77. """
  78. mask_color_id = 0
  79. w_ratio = .4
  80. color_list = colormap(rgb=True)
  81. img_array = np.array(image).astype('float32')
  82. for dt in np.array(segms):
  83. if im_id != dt['image_id']:
  84. continue
  85. segm, score, catid = dt['segmentation'], dt['score'], dt['category_id']
  86. if score < threshold:
  87. continue
  88. import pycocotools.mask as mask_util
  89. mask = mask_util.decode(segm) * 255
  90. color_mask = color_list[mask_color_id % len(color_list), 0:3]
  91. mask_color_id += 1
  92. for c in range(3):
  93. color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
  94. idx = np.nonzero(mask)
  95. img_array[idx[0], idx[1], :] *= 1.0 - alpha
  96. img_array[idx[0], idx[1], :] += alpha * color_mask
  97. if not draw_box:
  98. center_y, center_x = ndimage.measurements.center_of_mass(mask)
  99. label_text = "{}".format(catid2name[catid])
  100. vis_pos = (max(int(center_x) - 10, 0), int(center_y))
  101. cv2.putText(img_array, label_text, vis_pos,
  102. cv2.FONT_HERSHEY_COMPLEX, 0.3, (255, 255, 255))
  103. else:
  104. mask = mask_util.decode(segm) * 255
  105. sum_x = np.sum(mask, axis=0)
  106. x = np.where(sum_x > 0.5)[0]
  107. sum_y = np.sum(mask, axis=1)
  108. y = np.where(sum_y > 0.5)[0]
  109. x0, x1, y0, y1 = x[0], x[-1], y[0], y[-1]
  110. cv2.rectangle(img_array, (x0, y0), (x1, y1),
  111. tuple(color_mask.astype('int32').tolist()), 1)
  112. bbox_text = '%s %.2f' % (catid2name[catid], score)
  113. t_size = cv2.getTextSize(bbox_text, 0, 0.3, thickness=1)[0]
  114. cv2.rectangle(img_array, (x0, y0), (x0 + t_size[0],
  115. y0 - t_size[1] - 3),
  116. tuple(color_mask.astype('int32').tolist()), -1)
  117. cv2.putText(
  118. img_array,
  119. bbox_text, (x0, y0 - 2),
  120. cv2.FONT_HERSHEY_SIMPLEX,
  121. 0.3, (0, 0, 0),
  122. 1,
  123. lineType=cv2.LINE_AA)
  124. return Image.fromarray(img_array.astype('uint8'))
  125. def draw_bbox(image, im_id, catid2name, bboxes, threshold):
  126. """
  127. Draw bbox on image
  128. """
  129. draw = ImageDraw.Draw(image)
  130. catid2color = {}
  131. color_list = colormap(rgb=True)[:40]
  132. for dt in np.array(bboxes):
  133. if im_id != dt['image_id']:
  134. continue
  135. catid, bbox, score = dt['category_id'], dt['bbox'], dt['score']
  136. if score < threshold:
  137. continue
  138. xmin, ymin, w, h = bbox
  139. xmax = xmin + w
  140. ymax = ymin + h
  141. if catid not in catid2color:
  142. idx = np.random.randint(len(color_list))
  143. catid2color[catid] = color_list[idx]
  144. color = tuple(catid2color[catid])
  145. # draw bbox
  146. draw.line(
  147. [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
  148. (xmin, ymin)],
  149. width=2,
  150. fill=color)
  151. # draw label
  152. text = "{} {:.2f}".format(catid2name[catid], score)
  153. tw, th = draw.textsize(text)
  154. draw.rectangle(
  155. [(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
  156. draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
  157. return image
  158. def draw_lmk(image, im_id, lmk_results, threshold):
  159. draw = ImageDraw.Draw(image)
  160. catid2color = {}
  161. color_list = colormap(rgb=True)[:40]
  162. for dt in np.array(lmk_results):
  163. lmk_decode, score = dt['landmark'], dt['score']
  164. if im_id != dt['image_id']:
  165. continue
  166. if score < threshold:
  167. continue
  168. for j in range(5):
  169. x1 = int(round(lmk_decode[2 * j]))
  170. y1 = int(round(lmk_decode[2 * j + 1]))
  171. draw.ellipse(
  172. (x1, y1, x1 + 5, y1 + 5), fill='green', outline='green')
  173. return image