preprocess.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from PIL import Image
  15. import cv2
  16. import numpy as np
  17. # Global dictionary
  18. RESIZE_SCALE_SET = {
  19. 'RCNN',
  20. 'RetinaNet',
  21. 'FCOS',
  22. 'SOLOv2',
  23. }
  24. def decode_image(im_file, im_info):
  25. """read rgb image
  26. Args:
  27. im_file (str/np.ndarray): path of image/ np.ndarray read by cv2
  28. im_info (dict): info of image
  29. Returns:
  30. im (np.ndarray): processed image (np.ndarray)
  31. im_info (dict): info of processed image
  32. """
  33. if isinstance(im_file, str):
  34. with open(im_file, 'rb') as f:
  35. im_read = f.read()
  36. data = np.frombuffer(im_read, dtype='uint8')
  37. im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
  38. im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
  39. im_info['origin_shape'] = im.shape[:2]
  40. im_info['resize_shape'] = im.shape[:2]
  41. else:
  42. im = im_file
  43. im_info['origin_shape'] = im.shape[:2]
  44. im_info['resize_shape'] = im.shape[:2]
  45. return im, im_info
  46. class Resize(object):
  47. """resize image by target_size and max_size
  48. Args:
  49. arch (str): model type
  50. target_size (int): the target size of image
  51. max_size (int): the max size of image
  52. use_cv2 (bool): whether us cv2
  53. image_shape (list): input shape of model
  54. interp (int): method of resize
  55. """
  56. def __init__(self,
  57. arch,
  58. target_size,
  59. max_size,
  60. use_cv2=True,
  61. image_shape=None,
  62. interp=cv2.INTER_LINEAR,
  63. resize_box=False):
  64. self.target_size = target_size
  65. self.max_size = max_size
  66. self.image_shape = image_shape
  67. self.arch = arch
  68. self.use_cv2 = use_cv2
  69. self.interp = interp
  70. def __call__(self, im, im_info):
  71. """
  72. Args:
  73. im (np.ndarray): image (np.ndarray)
  74. im_info (dict): info of image
  75. Returns:
  76. im (np.ndarray): processed image (np.ndarray)
  77. im_info (dict): info of processed image
  78. """
  79. im_channel = im.shape[2]
  80. im_scale_x, im_scale_y = self.generate_scale(im)
  81. im_info['resize_shape'] = [
  82. im_scale_x * float(im.shape[0]), im_scale_y * float(im.shape[1])
  83. ]
  84. if self.use_cv2:
  85. im = cv2.resize(
  86. im,
  87. None,
  88. None,
  89. fx=im_scale_x,
  90. fy=im_scale_y,
  91. interpolation=self.interp)
  92. else:
  93. resize_w = int(im_scale_x * float(im.shape[1]))
  94. resize_h = int(im_scale_y * float(im.shape[0]))
  95. if self.max_size != 0:
  96. raise TypeError(
  97. 'If you set max_size to cap the maximum size of image,'
  98. 'please set use_cv2 to True to resize the image.')
  99. im = im.astype('uint8')
  100. im = Image.fromarray(im)
  101. im = im.resize((int(resize_w), int(resize_h)), self.interp)
  102. im = np.array(im)
  103. # padding im when image_shape fixed by infer_cfg.yml
  104. if self.max_size != 0 and self.image_shape is not None:
  105. padding_im = np.zeros(
  106. (self.max_size, self.max_size, im_channel), dtype=np.float32)
  107. im_h, im_w = im.shape[:2]
  108. padding_im[:im_h, :im_w, :] = im
  109. im = padding_im
  110. im_info['scale'] = [im_scale_x, im_scale_y]
  111. return im, im_info
  112. def generate_scale(self, im):
  113. """
  114. Args:
  115. im (np.ndarray): image (np.ndarray)
  116. Returns:
  117. im_scale_x: the resize ratio of X
  118. im_scale_y: the resize ratio of Y
  119. """
  120. origin_shape = im.shape[:2]
  121. im_c = im.shape[2]
  122. if self.max_size != 0 and self.arch in RESIZE_SCALE_SET:
  123. im_size_min = np.min(origin_shape[0:2])
  124. im_size_max = np.max(origin_shape[0:2])
  125. im_scale = float(self.target_size) / float(im_size_min)
  126. if np.round(im_scale * im_size_max) > self.max_size:
  127. im_scale = float(self.max_size) / float(im_size_max)
  128. im_scale_x = im_scale
  129. im_scale_y = im_scale
  130. else:
  131. im_scale_x = float(self.target_size) / float(origin_shape[1])
  132. im_scale_y = float(self.target_size) / float(origin_shape[0])
  133. return im_scale_x, im_scale_y
  134. class Normalize(object):
  135. """normalize image
  136. Args:
  137. mean (list): im - mean
  138. std (list): im / std
  139. is_scale (bool): whether need im / 255
  140. is_channel_first (bool): if True: image shape is CHW, else: HWC
  141. """
  142. def __init__(self, mean, std, is_scale=True, is_channel_first=False):
  143. self.mean = mean
  144. self.std = std
  145. self.is_scale = is_scale
  146. self.is_channel_first = is_channel_first
  147. def __call__(self, im, im_info):
  148. """
  149. Args:
  150. im (np.ndarray): image (np.ndarray)
  151. im_info (dict): info of image
  152. Returns:
  153. im (np.ndarray): processed image (np.ndarray)
  154. im_info (dict): info of processed image
  155. """
  156. im = im.astype(np.float32, copy=False)
  157. if self.is_channel_first:
  158. mean = np.array(self.mean)[:, np.newaxis, np.newaxis]
  159. std = np.array(self.std)[:, np.newaxis, np.newaxis]
  160. else:
  161. mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
  162. std = np.array(self.std)[np.newaxis, np.newaxis, :]
  163. if self.is_scale:
  164. im = im / 255.0
  165. im -= mean
  166. im /= std
  167. return im, im_info
  168. class Permute(object):
  169. """permute image
  170. Args:
  171. to_bgr (bool): whether convert RGB to BGR
  172. channel_first (bool): whether convert HWC to CHW
  173. """
  174. def __init__(self, to_bgr=False, channel_first=True):
  175. self.to_bgr = to_bgr
  176. self.channel_first = channel_first
  177. def __call__(self, im, im_info):
  178. """
  179. Args:
  180. im (np.ndarray): image (np.ndarray)
  181. im_info (dict): info of image
  182. Returns:
  183. im (np.ndarray): processed image (np.ndarray)
  184. im_info (dict): info of processed image
  185. """
  186. if self.channel_first:
  187. im = im.transpose((2, 0, 1)).copy()
  188. if self.to_bgr:
  189. im = im[[2, 1, 0], :, :]
  190. return im, im_info
  191. class PadStride(object):
  192. """ padding image for model with FPN
  193. Args:
  194. stride (bool): model with FPN need image shape % stride == 0
  195. """
  196. def __init__(self, stride=0):
  197. self.coarsest_stride = stride
  198. def __call__(self, im, im_info):
  199. """
  200. Args:
  201. im (np.ndarray): image (np.ndarray)
  202. im_info (dict): info of image
  203. Returns:
  204. im (np.ndarray): processed image (np.ndarray)
  205. im_info (dict): info of processed image
  206. """
  207. coarsest_stride = self.coarsest_stride
  208. if coarsest_stride == 0:
  209. return im
  210. im_c, im_h, im_w = im.shape
  211. pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
  212. pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
  213. padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
  214. padding_im[:, :im_h, :im_w] = im
  215. im_info['pad_shape'] = padding_im.shape[1:]
  216. return padding_im, im_info
  217. def preprocess(im, preprocess_ops):
  218. # process image by preprocess_ops
  219. im_info = {
  220. 'scale': [1., 1.],
  221. 'origin_shape': None,
  222. 'resize_shape': None,
  223. 'pad_shape': None,
  224. }
  225. im, im_info = decode_image(im, im_info)
  226. for operator in preprocess_ops:
  227. im, im_info = operator(im, im_info)
  228. im = np.array((im, )).astype('float32')
  229. return im, im_info