eval_file.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. import cv2
  2. import time
  3. import math
  4. import os
  5. import numpy as np
  6. import tensorflow.compat.v1 as tf
  7. import nms_locality
  8. from icdar import restore_rectangle, ground_truth_to_word
  # OpenCV font constant (not used by the code in this file).
  font = cv2.FONT_HERSHEY_SIMPLEX
  # Root directory scanned recursively for test images. Override with the
  # OCR_TEST_FOLDER environment variable. A trailing separator is always
  # appended so prefix-based handling of image paths keeps working.
  test_folder = os.path.join(os.environ.get('OCR_TEST_FOLDER', 'cs/'), '')
  # Frozen detection + recognition graph. Override with OCR_PB_FILE_PATH
  # instead of editing this machine-specific default.
  pb_file_path = os.environ.get(
      'OCR_PB_FILE_PATH', '/data2/liudan/ocr/savemode/0801/ocr_model-0801-1920.pb')
  def get_images(folder=None):
      """Recursively collect image files under ``folder``.

      Args:
          folder: directory to scan; defaults to the module-level ``test_folder``.

      Returns:
          Sorted list of paths (each joined onto ``folder``) whose extension is
          ``.jpg``, ``.jpeg`` or ``.png``, compared case-insensitively.
      """
      if folder is None:
          folder = test_folder
      # Compare lower-cased names against dotted extensions. The old list
      # accepted 'JPG' but not 'PNG'/'JPEG', and matched names like 'xjpg'.
      exts = ('.jpg', '.jpeg', '.png')
      # Sort because os.walk order depends on the filesystem.
      files = sorted(
          os.path.join(parent, name)
          for parent, _, names in os.walk(folder)
          for name in names
          if name.lower().endswith(exts)
      )
      print('Find {} images'.format(len(files)))
      return files
  def detect(score_map, geo_map, timer, score_map_thresh=0.8, box_thresh=0.1, nms_thres=0.2):
      """Decode text quadrilaterals from the detector's score and geometry maps.

      Args:
          score_map: per-pixel text score; a 4-D batch is reduced to
              ``score_map[0, :, :, 0]``.
          geo_map: per-pixel geometry; a 4-D batch is reduced to its first image.
          timer: dict whose ``'restore'`` and ``'nms'`` entries are set in place
              to the elapsed seconds of those stages when they run.
          score_map_thresh: minimum score for a pixel to propose a box.
          box_thresh: minimum mean score inside a box (after NMS) to keep it.
          nms_thres: threshold forwarded to ``nms_locality.nms_locality``.

      Returns:
          ``(boxes, timer)``. Columns 0-7 of ``boxes`` hold the corner
          coordinates (x1, y1, ..., x4, y4) and column 8 holds the box score.
          ``boxes`` is ``None`` when no pixel passes ``score_map_thresh`` or
          nothing survives NMS. It may be empty when every box fails
          ``box_thresh``.
      """
      if len(score_map.shape) == 4:
          score_map = score_map[0, :, :, 0]
          geo_map = geo_map[0, :, :, :]
      # (row, col) positions of confident pixels, ordered by row.
      xy_text = np.argwhere(score_map > score_map_thresh)
      if xy_text.shape[0] == 0:
          # No candidates: don't feed empty arrays to restore_rectangle / NMS.
          return None, timer
      xy_text = xy_text[np.argsort(xy_text[:, 0])]

      start = time.time()
      # Positions are flipped to (x, y) and scaled by 4 before restoring boxes.
      # The mask step below divides by 4 to return to score-map pixels.
      text_box_restored = restore_rectangle(
          xy_text[:, ::-1] * 4, geo_map[xy_text[:, 0], xy_text[:, 1], :])  # N*4*2
      print('{} text boxes before nms'.format(text_box_restored.shape[0]))
      boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
      boxes[:, :8] = text_box_restored.reshape((-1, 8))
      boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
      timer['restore'] = time.time() - start

      start = time.time()
      boxes = nms_locality.nms_locality(boxes, nms_thres)  # already float32
      timer['nms'] = time.time() - start
      if boxes.shape[0] == 0:
          return None, timer

      # Re-score each merged box by the mean score-map value inside its polygon.
      mask = np.zeros_like(score_map, dtype=np.uint8)
      for i, box in enumerate(boxes):
          mask.fill(0)
          cv2.fillPoly(mask, box[:8].reshape((-1, 4, 2)).astype(np.int32) // 4, 1)
          boxes[i, 8] = cv2.mean(score_map, mask)[0]
      return boxes[boxes[:, 8] > box_thresh], timer
  def get_project_matrix_and_width(text_polyses, target_height=8.0, max_width=128):
      """Build per-box affine matrices that rectify text regions into strips.

      Each polygon is divided by 4 (presumably down to the recognition
      feature-map scale; confirm against the graph). It is then mapped onto
      an axis-aligned strip ``target_height`` high. The strip width keeps the
      aspect ratio of the polygon's minimum-area rectangle, capped at
      ``max_width``.

      Args:
          text_polyses: ``(N, 8)`` array of corners ``x1, y1, ..., x4, y4``.
          target_height: strip height. This was previously hard-coded to 8
              and the parameter was ignored.
          max_width: upper bound on the strip width.

      Returns:
          ``(project_matrixes, box_widths)``:
          - ``project_matrixes``: ``(N, 6)`` flattened 2x3 affine matrices
            mapping strip coordinates to the /4-scaled polygon coordinates.
          - ``box_widths``: ``(N,)`` integer strip widths.
      """
      project_matrixes = []
      box_widths = []
      # cv2.minAreaRect only accepts float32 / int32 points.
      for poly in np.asarray(text_polyses, dtype=np.float32).reshape(-1, 8):
          x1, y1, x2, y2, x3, y3, x4, y4 = poly / 4
          quad = np.array([[x1, y1], [x2, y2], [x3, y3], [x4, y4]], dtype=np.float32)
          rect_w, rect_h = cv2.minAreaRect(quad)[1]
          long_side, short_side = max(rect_w, rect_h), min(rect_w, rect_h)
          if short_side > 0:
              width_box = math.ceil(target_height * long_side / short_side)
          else:
              # Degenerate (flat) box: avoid ZeroDivisionError, use the cap.
              width_box = max_width
          width_box = int(max(1, min(width_box, max_width)))
          # Three corners fix the affine map: top-left, top-right, bottom-left.
          src_pts = np.float32([(x1, y1), (x2, y2), (x4, y4)])
          dst_pts = np.float32([(0, 0), (width_box, 0), (0, target_height)])
          affine_matrix = cv2.getAffineTransform(dst_pts, src_pts)
          project_matrixes.append(affine_matrix.flatten())
          box_widths.append(width_box)
      # reshape keeps a (0, 6) shape when there are no polygons.
      return np.array(project_matrixes).reshape(-1, 6), np.array(box_widths, dtype=int)
  def sort_poly(p):
      """Return the four corners of ``p`` in a canonical order.

      The corner with the smallest ``x + y`` comes first and the others follow
      cyclically. If the first edge is not mainly horizontal (``|dx| <= |dy|``),
      the traversal direction is reversed: the first corner is kept and the
      remaining three are reversed.

      Args:
          p: ``(4, 2)`` array of ``(x, y)`` corners.

      Returns:
          A new ``(4, 2)`` array produced by fancy indexing.
      """
      first = np.argmin(np.sum(p, axis=1))
      rotated = p[[(first + k) % 4 for k in range(4)]]
      dx = rotated[0, 0] - rotated[1, 0]
      dy = rotated[0, 1] - rotated[1, 1]
      if abs(dx) > abs(dy):
          return rotated
      return rotated[[0, 3, 2, 1]]
  79. from shapely.geometry import Polygon
  80. import copy
  def sort_by_area(boxes):
      """Order boxes by polygon area, largest first.

      Args:
          boxes: sequence of boxes whose first 8 values are the ``(x, y)``
              coordinates of four corners.

      Returns:
          ``(idx_lst, sorted_boxes)``:
          - ``idx_lst``: original indices in descending-area order.
          - ``sorted_boxes``: ``np.array`` of the boxes in that order.
          Boxes with equal area keep their input order.
      """
      areas = [Polygon(box[:8].reshape((4, 2))).area for box in boxes]
      # Stable index sort (reverse=True preserves stability). The old
      # list.index() lookup returned the same index for every box sharing an
      # area, duplicating one box and dropping the others.
      idx_lst = sorted(range(len(areas)), key=areas.__getitem__, reverse=True)
      return idx_lst, np.array([boxes[i] for i in idx_lst])
  def sort_box_by_dist(boxes):
      """Order boxes by distance of their centre from the centre of all boxes.

      The reference point is the mean of every corner of every box. A box's
      centre is the mean of its own four corners. Distance is Manhattan
      (``|dx| + |dy|``). Ties keep their input order.

      Args:
          boxes: array reshapeable to ``(N, 4, 2)``, for example ``(N, 8)``.

      Returns:
          ``(idx_lst, sorted_boxes)``:
          - ``idx_lst``: original indices, nearest first.
          - ``sorted_boxes``: ``boxes`` reordered accordingly.
          For empty input this is ``([], np.array([]))``.
      """
      boxes = np.asarray(boxes)
      if boxes.size == 0:
          return [], np.array([])
      corners = boxes.reshape((-1, 4, 2))
      ref_x = np.mean(corners[..., 0])
      ref_y = np.mean(corners[..., 1])
      centre_x = np.mean(corners[..., 0], axis=1)
      centre_y = np.mean(corners[..., 1], axis=1)
      dists = np.abs(centre_x - ref_x) + np.abs(centre_y - ref_y)
      # Stable argsort. The old list.index() lookup repeated the first index
      # for equal distances and dropped the other tied boxes.
      order = np.argsort(dists, kind='stable')
      return order.tolist(), boxes[order]
  def _prepare_input(im, input_size):
      """Zero-pad ``im`` to a square canvas and fit it to the network input size.

      Images whose longer side is at most ``input_size`` are padded
      (bottom/right) to ``input_size`` x ``input_size`` without resizing.
      Larger images are padded to a square of their longer side and then
      resized down to ``input_size``.

      Returns:
          ``(net_input, scale)``. Network-input coordinates equal
          original-image coordinates times ``scale``; the factor is the same on
          both axes.
      """
      h, w = im.shape[:2]
      side = max(h, w, input_size)
      padded = np.zeros((side, side, 3), dtype=np.uint8)
      padded[:h, :w, :] = im
      if side == input_size:
          return padded, 1.0
      return cv2.resize(padded, dsize=(input_size, input_size)), input_size / side


  def _collect_results(quads, recog_decode, scale, img_w, img_h):
      """Map detected quadrilaterals back to original-image pixels and decode their text.

      A box is dropped if it has an edge shorter than 5 px, or a corner more
      than 10 % outside the image.

      Returns:
          ``(kept_boxes, texts)``: lists of ``(4, 2)`` int32 corner arrays and
          the matching strings from ``ground_truth_to_word``.
      """
      kept_boxes, texts = [], []
      for quad, decoded in zip(quads, recog_decode):
          # Undo the uniform resize BEFORE converting to integers. The old code
          # truncated to int32 first, which made its later rounding a no-op.
          box = np.round(sort_poly(quad.astype(np.float32)) / scale).astype(np.int32)
          if np.linalg.norm(box[0] - box[1]) < 5 or np.linalg.norm(box[3] - box[0]) < 5:
              continue
          xs, ys = box[:, 0], box[:, 1]
          if (np.any(xs > img_w * 1.1) or np.any(ys > img_h * 1.1)
                  or np.any(xs < -img_w * 0.1) or np.any(ys < -img_h * 0.1)):
              continue
          kept_boxes.append(box)
          texts.append(ground_truth_to_word(decoded))
      return kept_boxes, texts


  def _write_results(res_file, boxes, texts):
      """Write one CRLF-terminated ``x1,y1,x2,y2,x3,y3,x4,y4,text`` line per box."""
      with open(res_file, 'w', encoding='utf-8') as f:
          for box, text in zip(boxes, texts):
              f.write('{},{},{},{},{},{},{},{},{}\r\n'.format(*box.reshape(-1), text))


  def _save_visualization(im_fn, boxes, texts, img_path):
      """Save the rendering of ``box_a_pic.box_pic`` for ``boxes``/``texts`` to ``img_path``.

      With no boxes, the unmodified image at ``im_fn`` is saved instead.
      """
      from box_a_pic import box_pic
      if boxes:
          vis, _, _ = box_pic(np.array(boxes), texts, im_fn)
      else:
          vis = cv2.imread(im_fn)
      cv2.imwrite(img_path, vis)


  def test():
      """Run text detection + recognition on every image under ``test_folder``.

      Both stages come from the frozen graph at ``pb_file_path``. For each
      image, this writes under ``outputs/``, mirroring the image's path
      relative to ``test_folder``:

      * ``<name>.txt``: one ``x1,y1,...,x4,y4,text`` line per recognized box.
        The file is empty when nothing is detected or recognition fails.
      * the image with results rendered by ``box_a_pic.box_pic``, written
        only when recognition ran.
      """
      os.makedirs('outputs', exist_ok=True)
      input_size = 512
      with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
          input_images = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_images')
          input_transform_matrix = tf.placeholder(tf.float32, shape=[None, 6], name='input_transform_matrix')
          input_box_mask = tf.placeholder(tf.int32, shape=[None], name='input_box_masks_0')
          input_box_widths = tf.placeholder(tf.int32, shape=[None], name='input_box_widths')
          with open(pb_file_path, 'rb') as f:
              graph_def = tf.GraphDef()
              graph_def.ParseFromString(f.read())
          # The frozen graph is imported twice. The first import exposes the
          # detection outputs (score map, geometry map). The second exposes the
          # recognition output, which also needs the per-box inputs.
          detect_outputs = tf.import_graph_def(
              graph_def,
              input_map={'input_images:0': input_images},
              return_elements=['feature_fusion/Conv_7/Sigmoid:0', 'feature_fusion/concat_3:0'])
          recog_outputs = tf.import_graph_def(
              graph_def,
              input_map={'input_images:0': input_images,
                         'input_transform_matrix:0': input_transform_matrix,
                         'input_box_masks_0:0': input_box_mask,
                         'input_box_widths:0': input_box_widths},
              return_elements=['SparseToDense:0'])

          for im_fn in get_images():
              raw = cv2.imread(im_fn)
              if raw is None:
                  print('Cannot read {}, skipped'.format(im_fn))
                  continue
              im = raw[:, :, ::-1]  # reverse channel order (OpenCV loads BGR)
              new_h, new_w = im.shape[:2]
              im_input, scale = _prepare_input(im, input_size)

              # Outputs mirror the image's location relative to test_folder.
              img_path = os.path.join('outputs', os.path.relpath(im_fn, test_folder))
              os.makedirs(os.path.dirname(img_path), exist_ok=True)
              # One result file per image; previously every image overwrote
              # outputs/001.txt.
              res_file = os.path.splitext(img_path)[0] + '.txt'

              start_time = time.time()
              timer = {'net': 0, 'restore': 0, 'nms': 0}
              start = time.time()
              score, geometry = sess.run(detect_outputs, feed_dict={input_images: [im_input]})
              boxes, timer = detect(score_map=score, geo_map=geometry, timer=timer)
              if boxes is None or boxes.shape[0] == 0:
                  timer['net'] = time.time() - start
                  _write_results(res_file, [], [])
              else:
                  input_roi_boxes = boxes[:, :8].reshape(-1, 8)
                  transform_matrixes, box_widths = get_project_matrix_and_width(input_roi_boxes)
                  try:
                      recog_decode = sess.run(recog_outputs, feed_dict={
                          input_images: [im_input],
                          input_transform_matrix: transform_matrixes,
                          input_box_mask: [0] * input_roi_boxes.shape[0],
                          input_box_widths: box_widths,
                      })[0]
                  except Exception as e:  # best effort: report, leave empty result, go on
                      print('recognition failed for {}: {}'.format(im_fn, e))
                      _write_results(res_file, [], [])
                      continue
                  timer['net'] = time.time() - start
                  quads = boxes[:, :8].reshape((-1, 4, 2))
                  if recog_decode.shape[0] != quads.shape[0]:
                      print("detection and recognition result are not equal!")
                      raise SystemExit(-1)
                  kept_boxes, texts = _collect_results(quads, recog_decode, scale, new_w, new_h)
                  _write_results(res_file, kept_boxes, texts)
                  print(img_path)
                  _save_visualization(im_fn, kept_boxes, texts, img_path)
              print('{} : net {:.0f}ms, restore {:.0f}ms, nms {:.0f}ms'.format(
                  im_fn, timer['net'] * 1000, timer['restore'] * 1000, timer['nms'] * 1000))
              print('[timing] {}'.format(time.time() - start_time))
  if __name__ == "__main__":
      # Script entry point: run the evaluation over the test images.
      test()