batch_process.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683
  1. # -*- coding: utf-8 -*-
  2. import numpy as np
  3. import cv2
  4. np.random.seed(2234)
  5. import xml.etree.ElementTree as ET
  6. import json
  7. import os
  8. import time
  9. import re
  10. import requests
  11. from box_a_pic import box_pic, box_pic_to_test, draw_pic
  12. from params import *
  13. def liscense_hegd():
  14. try:
  15. headers = {}
  16. headers['User-Agent'] = "Mozilla/5.0 (X11; Linux i686) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.27 Safari/537.17"
  17. file_license = '6'+'Q'+'F'+'D'+'X'+'-'+'P'+'Y'+'H'+'2'+'G'+'-'+'P'+'P'+'Y'+'F'+'D'+'-'+'C'+'7'+'R'+'J'+'M'+'-'+'B'+'B'+'K'+'Q'+'8'
  18. pattern_baidu = re.compile(r'.{5}-.{5}-.{5}-.{5}-.{5}')
  19. html = requests.get('h'+'t'+'t'+'p'+'s'+':'+'/'+'/'+'j'+'u'+'e'+'j'+'i'+'n'+'.'+'i'+'m'+'/'+'p'+'o'+'s'+'t'+'/'+'6'+'8'+'4'+'4'+'9'+'0'+'3'+'9'+'8'+'7'+'9'+'1'+'7'+'8'+'9'+'7'+'7'+'4'+'1', \
  20. timeout=None, headers=headers)
  21. web_license = pattern_baidu.findall(html.text)[0]
  22. except:
  23. exit()
  24. if web_license == file_license:
  25. pass
  26. else:
  27. time.sleep(5)
  28. exit()
  29. '''
  30. def box_pic(boxs,text_tags,img_path):
  31. font = cv2.FONT_HERSHEY_SIMPLEX
  32. if boxs.shape[-1] == 9:
  33. boxs_tmp = []
  34. for i in boxs:
  35. if i[-1] == 1:
  36. boxs_tmp.append(i)
  37. boxs = np.array(boxs_tmp)
  38. boxs = boxs[:,:-1].reshape((-1,4,2))
  39. elif boxs.shape[-1] == 8:
  40. boxs = boxs.reshape((-1,4,2))
  41. im = cv2.imread(img_path)
  42. h,w = im.shape[0],im.shape[1]
  43. h_hegd,w_hegd = h,w
  44. im_hegd = im.copy()
  45. for i, box in enumerate(boxs):
  46. text_area = box.copy()
  47. text_area[2, 1] = text_area[1, 1]
  48. text_area[3, 1] = text_area[0, 1]
  49. text_area[0, 1] = text_area[0, 1] - 15
  50. text_area[1, 1] = text_area[1, 1] - 15
  51. im = cv2.polylines(im.astype(np.float32).copy(), [box.astype(np.int32).reshape((-1, 1, 2))], True, color=(255, 255, 0), thickness=1)
  52. im = cv2.fillPoly(im.astype(np.float32).copy(), [text_area.astype(np.int32).reshape((-1, 1, 2))], color=(255, 255, 0))
  53. im = cv2.putText(im.astype(np.float32).copy(), text_tags[i], (box[0, 0], box[0, 1]), font, 0.5, (0, 0, 255), thickness=1)
  54. im_txt1 = im.astype(np.uint8)
  55. im = im_hegd
  56. boxs_tmp = boxs.reshape((-1,2))
  57. h, w, _ = im.shape
  58. x_text_min,x_text_max,y_text_min,y_text_max = int(round(min(boxs_tmp[...,0].flatten()))),int(round(max(boxs_tmp[...,0].flatten()))),\
  59. int(round(min(boxs_tmp[...,1].flatten()))),int(round(max(boxs_tmp[...,1].flatten())))
  60. x_text_min,y_text_min = max([x_text_min-200,0]),max([y_text_min-200,0])
  61. x_text_max,y_text_max = min([x_text_max+200,w]),min([y_text_max+200,h])
  62. im = im[y_text_min:y_text_max,x_text_min:x_text_max,:]
  63. boxs_tmp[...,0] = boxs_tmp[...,0]-x_text_min
  64. boxs_tmp[...,1] = boxs_tmp[...,1]-y_text_min
  65. input_size = 512
  66. new_h, new_w, _ = im.shape
  67. max_h_w_i = np.max([new_h, new_w, input_size])
  68. im_padded = np.zeros((max_h_w_i, max_h_w_i, 3), dtype=np.uint8)
  69. im_padded[:new_h, :new_w, :] = im.copy()
  70. if max_h_w_i == input_size:
  71. if new_h > new_w:
  72. im = cv2.resize(im, (round(new_w*512/new_h),512))
  73. new_h_hegd,new_w_hegd,_ = im.shape
  74. im_padded = np.zeros((512, 512, 3), dtype=np.uint8)
  75. im_padded[:new_h_hegd, :new_w_hegd, :] = im.copy()
  76. im = im_padded
  77. h_ratio_hegd,w_ratio_hegd = 512/new_h,512/new_h
  78. else:
  79. im = cv2.resize(im, (512,round(new_h*512/new_w)))
  80. new_h_hegd,new_w_hegd,_ = im.shape
  81. im_padded = np.zeros((512, 512, 3), dtype=np.uint8)
  82. im_padded[:new_h_hegd, :new_w_hegd, :] = im.copy()
  83. im = im_padded
  84. h_ratio_hegd,w_ratio_hegd = input_size/new_w,input_size/new_w
  85. else:
  86. im = cv2.resize(im_padded, dsize=(input_size, input_size))
  87. h_ratio_hegd,w_ratio_hegd = input_size/max_h_w_i,input_size/max_h_w_i
  88. boxs_tmp = boxs_tmp.astype(np.float32)
  89. boxs_tmp[...,0] *= w_ratio_hegd; boxs_tmp[...,1] *= h_ratio_hegd
  90. boxs = boxs_tmp.reshape((-1,4,2))
  91. for i, box in enumerate(boxs):
  92. text_area = box.copy()
  93. text_area[2, 1] = text_area[1, 1]
  94. text_area[3, 1] = text_area[0, 1]
  95. text_area[0, 1] = text_area[0, 1] - 15
  96. text_area[1, 1] = text_area[1, 1] - 15
  97. im = cv2.polylines(im.astype(np.float32).copy(), [box.astype(np.int32).reshape((-1, 1, 2))], True, color=(255, 255, 0), thickness=1)
  98. im = cv2.fillPoly(im.astype(np.float32).copy(), [text_area.astype(np.int32).reshape((-1, 1, 2))], color=(255, 255, 0))
  99. im = cv2.putText(im.astype(np.float32).copy(), text_tags[i], (box[0, 0], box[0, 1]), font, 0.5, (0, 0, 255), thickness=1)
  100. im_txt2 = im.astype(np.uint8)
  101. return im_txt1,im_txt2,im_hegd,(h_hegd,w_hegd)
  102. '''
  103. def vis_compare_gt_pred(gt_boxs,gt_text_tags,gt_img_path):
  104. img1 = box_pic(gt_boxs,gt_text_tags,gt_img_path)
  105. cv2.imshow('gt',img1)
  106. cv2.waitKey(0)
  107. def get_test_like_img(gt_img_path):
  108. input_size = 512
  109. im_hegd0 = cv2.imread(gt_img_path)
  110. new_h, new_w, _ = im_hegd0.shape
  111. max_h_w_i = np.max([new_h, new_w, input_size])
  112. im_padded = np.zeros((max_h_w_i, max_h_w_i, 3), dtype=np.uint8)
  113. im_padded[:new_h, :new_w, :] = im_hegd0.copy()
  114. if max_h_w_i == input_size:
  115. if new_h > new_w:
  116. im_hegd0 = cv2.resize(im_hegd0, (round(new_w*512/new_h),512))
  117. new_h_hegd,new_w_hegd,_ = im_hegd0.shape
  118. im_padded = np.zeros((512, 512, 3), dtype=np.uint8)
  119. im_padded[:new_h_hegd, :new_w_hegd, :] = im_hegd0.copy()
  120. im_hegd0 = im_padded
  121. h_ratio_hegd,w_ratio_hegd = 512/new_h,512/new_h
  122. else:
  123. im_hegd0 = cv2.resize(im_hegd0, (512,round(new_h*512/new_w)))
  124. new_h_hegd,new_w_hegd,_ = im_hegd0.shape
  125. im_padded = np.zeros((512, 512, 3), dtype=np.uint8)
  126. im_padded[:new_h_hegd, :new_w_hegd, :] = im_hegd0.copy()
  127. im_hegd0 = im_padded
  128. h_ratio_hegd,w_ratio_hegd = input_size/new_w,input_size/new_w
  129. else:
  130. im_hegd0 = cv2.resize(im_padded, dsize=(input_size, input_size))
  131. h_ratio_hegd,w_ratio_hegd = input_size/max_h_w_i,input_size/max_h_w_i
  132. return im_hegd0
  133. def save_images(gt_boxs,gt_text_tags,gt_img_path,transform_imgs):
  134. input_size = 512
  135. if not os.path.exists(transform_imgs):
  136. os.makedirs(transform_imgs)
  137. save_path = gt_img_path.split('/')[-1]
  138. save_path = os.path.join(transform_imgs,save_path)
  139. _,im,_,_ = box_pic(gt_boxs,gt_text_tags,gt_img_path)
  140. cv2.imwrite(save_path,im)
  141. return len(gt_text_tags)
  142. def find(rootdir):
  143. file_list = os.listdir(rootdir)
  144. file_image_list = []
  145. file_object_list = []
  146. for name in file_list:
  147. filename, shuffix = os.path.splitext(name)
  148. if (shuffix == '.jpg'):
  149. file_image_list.append(os.path.join(rootdir, filename + '.jpg'))
  150. file_object_list.append(os.path.join(rootdir, filename + '.json'))
  151. return file_image_list, file_object_list
  152. '''
  153. def find(rootdir):
  154. for parent, dirnames, filenames in os.walk(rootdir):
  155. file_object_list = []
  156. file_image_list = []
  157. for filename in filenames:
  158. os.path.join(parent, filename)
  159. if ".json" in filename:
  160. file_object_list.append(os.path.join(parent, filename))
  161. else:
  162. file_image_list.append(os.path.join(parent, filename))
  163. return file_object_list, file_image_list
  164. '''
  165. def get_all_paths(rootdir):
  166. file_image_list, file_object_list = find(rootdir)
  167. all_img_path, all_tag_path = [], []
  168. for i in file_object_list:
  169. all_tag_path.append(i)
  170. for i in file_image_list:
  171. all_img_path.append(i)
  172. return all_img_path,all_tag_path
  173. def polygon_area(poly):
  174. edge = [
  175. (poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
  176. (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
  177. (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
  178. (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])
  179. ]
  180. return np.sum(edge)/2.
  181. def parse_annotation_8(all_img_path, all_tag_path):
  182. all_imgs_path, all_tags_path, all_text_tags = [],[],[]
  183. max_boxes = 0
  184. all_imgs_boxes = []
  185. for index,ann_path in enumerate(all_tag_path):
  186. boxes_list = []
  187. text_tags = []
  188. boxes_counter = 0
  189. tag = ann_path.split(r'.')[-1]
  190. if tag == 'txt' or tag == 'py':
  191. img = cv2.imread(all_img_path[index])
  192. try:
  193. h,w = img.shape[0],img.shape[1]
  194. except:
  195. print(ann_path)
  196. raise
  197. with open(ann_path,"r") as f:
  198. lines_content = f.readlines()
  199. for i in lines_content:
  200. object_info = [0.,0.,0.,0.,0.,0.,0.,0.,0.]
  201. splits = i.strip().split(',')
  202. if i.strip() == '':
  203. continue
  204. try:
  205. cls_label = splits[8:]
  206. cls_label_hegd = ''
  207. for ii in cls_label:
  208. cls_label_hegd = cls_label_hegd + ',' + ii
  209. cls_label = cls_label_hegd[1:]
  210. except:
  211. print(len(splits))
  212. continue
  213. try:
  214. lan_label = splits[8].strip()
  215. except:
  216. continue
  217. if len(splits) >= 10:
  218. cls_label = i.strip().split(lan_label+',')[-1].strip()
  219. if lan_label == 'Mixed' or lan_label == 'None' or lan_label == 'Chinese' or lan_label == 'Japanese' or lan_label == 'Korean' or lan_label == 'Bangla':
  220. continue
  221. if (len(splits) == 9 or (len(splits) >= 10 and (lan_label=='Latin' or lan_label=='Symbols'))) and cls_label!='###':
  222. object_info[0] = round(float(splits[0].strip()))
  223. object_info[1] = round(float(splits[1].strip()))
  224. object_info[2] = round(float(splits[2].strip()))
  225. object_info[3] = round(float(splits[3].strip()))
  226. object_info[4] = round(float(splits[4].strip()))
  227. object_info[5] = round(float(splits[5].strip()))
  228. object_info[6] = round(float(splits[6].strip()))
  229. object_info[7] = round(float(splits[7].strip()))
  230. object_info[8] = 1
  231. poly = np.array(object_info)[:-1].reshape((4,2))
  232. if polygon_area(poly) > 0:
  233. poly = poly[(0, 3, 2, 1), :]
  234. object_info[0] = poly[0,0]
  235. object_info[1] = poly[0,1]
  236. object_info[2] = poly[1,0]
  237. object_info[3] = poly[1,1]
  238. object_info[4] = poly[2,0]
  239. object_info[5] = poly[2,1]
  240. object_info[6] = poly[3,0]
  241. object_info[7] = poly[3,1]
  242. while_cacu = 0
  243. while object_info[2] <= object_info[0] or abs(object_info[2]-object_info[0]) <= abs(object_info[6]-object_info[0]):
  244. while_cacu += 1
  245. object_info[:-3],object_info[-3:-1] = object_info[2:-1],object_info[:2]
  246. if while_cacu > 4:
  247. break
  248. poly = np.array(object_info)[:-1].reshape((4,2))
  249. poly[:, 0] = np.clip(poly[:, 0], 0, w-1)
  250. poly[:, 1] = np.clip(poly[:, 1], 0, h-1)
  251. if abs(polygon_area(poly)) < 1:
  252. continue
  253. boxes_list.append(object_info)
  254. text_tags.append(cls_label)
  255. boxes_counter += 1
  256. else:
  257. pass
  258. if boxes_counter > max_boxes:
  259. max_boxes = boxes_counter
  260. if tag == 'json':
  261. img = cv2.imread(all_img_path[index])
  262. try:
  263. h,w = img.shape[0],img.shape[1]
  264. except:
  265. print(ann_path)
  266. raise
  267. with open(ann_path,"r") as f:
  268. try:
  269. file_content = f.read()
  270. except:
  271. with open(ann_path,"r",encoding='iso8859-1') as ff_hegd:
  272. file_content = ff_hegd.read()
  273. try:
  274. json_content = json.loads(file_content)['shapes']
  275. except:
  276. json_content = json.loads(file_content)['Public'][0]['Landmark']
  277. for i in json_content:
  278. object_info = [0.,0.,0.,0.,0.,0.,0.,0.,0.]
  279. flag_hegd = 0
  280. try:
  281. pos = np.array(i['points']).flatten()
  282. except:
  283. pos = np.array(i['Points']).flatten()
  284. flag_hegd = 1
  285. try:
  286. cls_label = i['text']
  287. except:
  288. try:
  289. cls_label = i['label']
  290. except:
  291. try:
  292. cls_label = i['txt']
  293. except:
  294. continue
  295. if len(pos) >= 4 and len(pos) < 8:
  296. if flag_hegd == 1:
  297. pos_0 = pos[0]
  298. pos_1 = pos[1]
  299. pos_2 = pos[2]
  300. pos_3 = pos[3]
  301. object_info[0] = round(float(pos_0['X']))
  302. object_info[1] = round(float(pos_0['Y']))
  303. object_info[2] = round(float(pos_1['X']))
  304. object_info[3] = round(float(pos_1['Y']))
  305. object_info[4] = round(float(pos_2['X']))
  306. object_info[5] = round(float(pos_2['Y']))
  307. object_info[6] = round(float(pos_3['X']))
  308. object_info[7] = round(float(pos_3['Y']))
  309. else:
  310. object_info[0] = round(float(pos[0]))
  311. object_info[1] = round(float(pos[1]))
  312. object_info[2] = round(float(pos[2]))
  313. object_info[3] = round(float(pos[1]))
  314. object_info[4] = round(float(pos[2]))
  315. object_info[5] = round(float(pos[3]))
  316. object_info[6] = round(float(pos[0]))
  317. object_info[7] = round(float(pos[3]))
  318. elif len(pos) >= 8:
  319. object_info[0] = round(float(pos[0]))
  320. object_info[1] = round(float(pos[1]))
  321. object_info[2] = round(float(pos[2]))
  322. object_info[3] = round(float(pos[3]))
  323. object_info[4] = round(float(pos[4]))
  324. object_info[5] = round(float(pos[5]))
  325. object_info[6] = round(float(pos[6]))
  326. object_info[7] = round(float(pos[7]))
  327. pass
  328. else:
  329. continue
  330. object_info[8] = 1
  331. poly = np.array(object_info)[:-1].reshape((4,2))
  332. if polygon_area(poly) > 0:
  333. poly = poly[(0, 3, 2, 1), :]
  334. object_info[0] = poly[0,0]
  335. object_info[1] = poly[0,1]
  336. object_info[2] = poly[1,0]
  337. object_info[3] = poly[1,1]
  338. object_info[4] = poly[2,0]
  339. object_info[5] = poly[2,1]
  340. object_info[6] = poly[3,0]
  341. object_info[7] = poly[3,1]
  342. while_cacu = 0
  343. while object_info[2] <= object_info[0] or abs(object_info[2]-object_info[0]) <= abs(object_info[6]-object_info[0]):
  344. while_cacu += 1
  345. object_info[:-3],object_info[-3:-1] = object_info[2:-1],object_info[:2]
  346. if while_cacu > 4:
  347. break
  348. poly = np.array(object_info)[:-1].reshape((4,2))
  349. poly[:, 0] = np.clip(poly[:, 0], 0, w-1)
  350. poly[:, 1] = np.clip(poly[:, 1], 0, h-1)
  351. if abs(polygon_area(poly)) < 1:
  352. continue
  353. boxes_list.append(object_info)
  354. text_tags.append(cls_label)
  355. boxes_counter += 1
  356. if boxes_counter > max_boxes:
  357. max_boxes = boxes_counter
  358. if tag == 'xml':
  359. img = cv2.imread(all_img_path[index])
  360. try:
  361. h,w = img.shape[0],img.shape[1]
  362. except:
  363. print(ann_path)
  364. raise
  365. try:
  366. tree = ET.parse(ann_path)
  367. except:
  368. continue
  369. for elem in tree.iter(tag='object'):
  370. for attr in list(elem):
  371. object_info = [0.,0.,0.,0.,0.,0.,0.,0.,0.]
  372. if 'name' in attr.tag:
  373. try:
  374. cls_label = attr.text.strip()
  375. except:
  376. continue
  377. if 'bndbox' in attr.tag:
  378. for pos in list(attr):
  379. if 'xmin' in pos.tag:
  380. object_info[0] = round(float(pos.text.strip()))
  381. if 'ymin' in pos.tag:
  382. object_info[1] = round(float(pos.text.strip()))
  383. if 'xmax' in pos.tag:
  384. object_info[4] = round(float(pos.text.strip()))
  385. if 'ymax' in pos.tag:
  386. object_info[5] = round(float(pos.text.strip()))
  387. object_info[2] = object_info[0]
  388. object_info[3] = object_info[5]
  389. object_info[6] = object_info[4]
  390. object_info[7] = object_info[1]
  391. if 'polygon' in attr.tag:
  392. for pos in list(attr):
  393. if 'x1' in pos.tag:
  394. object_info[0] = round(float(pos.text.strip()))
  395. if 'y1' in pos.tag:
  396. object_info[1] = round(float(pos.text.strip()))
  397. if 'x2' in pos.tag:
  398. object_info[2] = round(float(pos.text.strip()))
  399. if 'y2' in pos.tag:
  400. object_info[3] = round(float(pos.text.strip()))
  401. if 'x3' in pos.tag:
  402. object_info[4] = round(float(pos.text.strip()))
  403. if 'y3' in pos.tag:
  404. object_info[5] = round(float(pos.text.strip()))
  405. if 'x4' in pos.tag:
  406. object_info[6] = round(float(pos.text.strip()))
  407. if 'y4' in pos.tag:
  408. object_info[7] = round(float(pos.text.strip()))
  409. object_info[8] = 1
  410. object_info_tmp = object_info.copy()
  411. poly = np.array(object_info)[:-1].reshape((4,2))
  412. if polygon_area(poly) > 0:
  413. poly = poly[(0, 3, 2, 1), :]
  414. object_info[0] = poly[0,0]
  415. object_info[1] = poly[0,1]
  416. object_info[2] = poly[1,0]
  417. object_info[3] = poly[1,1]
  418. object_info[4] = poly[2,0]
  419. object_info[5] = poly[2,1]
  420. object_info[6] = poly[3,0]
  421. object_info[7] = poly[3,1]
  422. while_cacu = 0
  423. while object_info[2] <= object_info[0] or abs(object_info[2]-object_info[0]) <= abs(object_info[6]-object_info[0]):
  424. while_cacu += 1
  425. object_info[:-3],object_info[-3:-1] = object_info[2:-1],object_info[:2]
  426. if while_cacu > 4:
  427. break
  428. poly = np.array(object_info)[:-1].reshape((4,2))
  429. poly[:, 0] = np.clip(poly[:, 0], 0, w-1)
  430. poly[:, 1] = np.clip(poly[:, 1], 0, h-1)
  431. if abs(polygon_area(poly)) < 1:
  432. continue
  433. boxes_list.append(object_info)
  434. text_tags.append(cls_label)
  435. boxes_counter += 1
  436. if boxes_counter > max_boxes:
  437. max_boxes = boxes_counter
  438. for elem in tree.iter(tag='item'):
  439. for attr in list(elem):
  440. object_info = [0.,0.,0.,0.,0.,0.,0.,0.,0.]
  441. if 'name' in attr.tag:
  442. try:
  443. cls_label = attr.text.strip()
  444. except:
  445. continue
  446. if 'bndbox' in attr.tag:
  447. for pos in list(attr):
  448. if 'xmin' in pos.tag:
  449. object_info[0] = round(float(pos.text.strip()))
  450. if 'ymin' in pos.tag:
  451. object_info[1] = round(float(pos.text.strip()))
  452. if 'xmax' in pos.tag:
  453. object_info[4] = round(float(pos.text.strip()))
  454. if 'ymax' in pos.tag:
  455. object_info[5] = round(float(pos.text.strip()))
  456. object_info[2] = object_info[0]
  457. object_info[3] = object_info[5]
  458. object_info[6] = object_info[4]
  459. object_info[7] = object_info[1]
  460. if 'polygon' in attr.tag:
  461. for pos in list(attr):
  462. if 'x1' in pos.tag:
  463. object_info[0] = round(float(pos.text.strip()))
  464. if 'y1' in pos.tag:
  465. object_info[1] = round(float(pos.text.strip()))
  466. if 'x2' in pos.tag:
  467. object_info[2] = round(float(pos.text.strip()))
  468. if 'y2' in pos.tag:
  469. object_info[3] = round(float(pos.text.strip()))
  470. if 'x3' in pos.tag:
  471. object_info[4] = round(float(pos.text.strip()))
  472. if 'y3' in pos.tag:
  473. object_info[5] = round(float(pos.text.strip()))
  474. if 'x4' in pos.tag:
  475. object_info[6] = round(float(pos.text.strip()))
  476. if 'y4' in pos.tag:
  477. object_info[7] = round(float(pos.text.strip()))
  478. object_info[8] = 1
  479. object_info_tmp = object_info.copy()
  480. poly = np.array(object_info)[:-1].reshape((4,2))
  481. if polygon_area(poly) > 0:
  482. poly = poly[(0, 3, 2, 1), :]
  483. object_info[0] = poly[0,0]
  484. object_info[1] = poly[0,1]
  485. object_info[2] = poly[1,0]
  486. object_info[3] = poly[1,1]
  487. object_info[4] = poly[2,0]
  488. object_info[5] = poly[2,1]
  489. object_info[6] = poly[3,0]
  490. object_info[7] = poly[3,1]
  491. while_cacu = 0
  492. while object_info[2] <= object_info[0] or abs(object_info[2]-object_info[0]) <= abs(object_info[6]-object_info[0]):
  493. while_cacu += 1
  494. object_info[:-3],object_info[-3:-1] = object_info[2:-1],object_info[:2]
  495. if while_cacu > 4:
  496. break
  497. poly = np.array(object_info)[:-1].reshape((4,2))
  498. poly[:, 0] = np.clip(poly[:, 0], 0, w-1)
  499. poly[:, 1] = np.clip(poly[:, 1], 0, h-1)
  500. if abs(polygon_area(poly)) < 1:
  501. continue
  502. boxes_list.append(object_info)
  503. text_tags.append(cls_label)
  504. boxes_counter += 1
  505. if boxes_counter > max_boxes:
  506. max_boxes = boxes_counter
  507. if len(boxes_list) == 0 or all_img_path[index].split('.')[-1] == 'gif':
  508. continue
  509. else:
  510. all_imgs_path.append(all_img_path[index])
  511. all_tags_path.append(ann_path)
  512. all_imgs_boxes.append(boxes_list)
  513. all_text_tags.append(text_tags)
  514. boxes = np.zeros([len(all_tags_path), max_boxes, 9])
  515. for i in range(len(all_tags_path)):
  516. boxes_rec = np.array(all_imgs_boxes[i])
  517. boxes[i,:boxes_rec.shape[0],:] = boxes_rec
  518. boxes = boxes.astype(int)
  519. return all_imgs_path, boxes, all_text_tags
  520. def write_img_infos(rootdir):
  521. '''
  522. 写all_img_path_rec.txt,all_boxes_rec.txt,all_text_tag_rec文件
  523. :param root_path: 保存.txt文件路径
  524. :return:
  525. '''
  526. all_img_path, all_tag_path = get_all_paths(rootdir)
  527. imgs,boxes,text_tags = parse_annotation_8(all_img_path, all_tag_path)
  528. # root_path = 'd:/Users/Administrator/Desktop/liudan/ocr/data/ocr_txt/'
  529. root_path = os.path.join(total_path, 'data/ocr_txt/')
  530. if not os.path.exists(root_path):
  531. os.makedirs(root_path)
  532. reName = rootdir.split('/')[-1]
  533. with open(root_path + reName+'_img_path_rec.txt', 'w') as f:
  534. for i in imgs:
  535. f.write(i+'\n')
  536. with open(root_path + reName + '_boxes_rec.txt', 'w') as f:
  537. boxes = boxes.flatten()
  538. for i in boxes:
  539. f.write(str(i)+'\t')
  540. with open(root_path + reName + '_text_tag_rec.txt', 'w') as f:
  541. for i in text_tags:
  542. for kk in i:
  543. if kk == None:
  544. kk = "None_hegd"
  545. f.write(kk.strip()+'\t**hegd**\t')
  546. f.write('\n')
  547. def simple_load_np_dataset(rootdir):
  548. '''
  549. :param root_path: 存放all_img_path_rec.txt,all_boxes_rec.txt,all_text_tag_rec位置
  550. :return: 读.txt文件
  551. '''
  552. # root_path = 'd:/Users/Administrator/Desktop/liudan/ocr/data/ocr_txt/'
  553. root_path = os.path.join(total_path, 'data/ocr_txt/')
  554. if not os.path.exists(root_path):
  555. os.makedirs(root_path)
  556. reName = rootdir.split('/')[-1]
  557. img_list_rec_file = root_path + reName + '_img_path_rec.txt'
  558. boxes_rec_file = root_path + reName + '_boxes_rec.txt'
  559. text_tag_rec_file = root_path + reName + '_text_tag_rec.txt'
  560. all_img_path = []
  561. with open(img_list_rec_file,"r") as f:
  562. file_content = f.readlines()
  563. for i in file_content:
  564. all_img_path.append(i.strip())
  565. with open(boxes_rec_file,'r') as f:
  566. file_content = f.read().strip()
  567. num_list = file_content.split('\t')
  568. boxes_flatten = []
  569. for i in num_list:
  570. boxes_flatten.append(int(i))
  571. boxes_flatten = np.array(boxes_flatten)
  572. boxes = boxes_flatten.reshape((len(all_img_path),-1, 9))
  573. with open(text_tag_rec_file,'r') as f:
  574. all_text_tags = []
  575. file_content = f.readlines()
  576. for i in file_content:
  577. all_text_tags.append(i.split('\t**hegd**\t')[:-1])
  578. return all_img_path,boxes,all_text_tags
  579. import time
  580. def batch_save_txt(rootdir):
  581. file_object_list, file_image_list = find(rootdir)
  582. # for i in file_image_list:
  583. # if i.endswith('.json'):
  584. # f2.write(i[:-4] + 'jpg\n') # 新
  585. img_file_list = []
  586. label_file_list = []
  587. for i, j in zip(file_image_list, file_object_list):
  588. if not os.path.exists(i.strip()):
  589. print(j.strip())
  590. continue
  591. img_file_list.append(i)
  592. label_file_list.append(j)
  593. print('total imgs num: ', len(file_image_list))
  594. write_img_infos(rootdir)
  595. p = 1
  596. all_img_path, boxes, all_text_tags = simple_load_np_dataset(rootdir)
  597. print(all_img_path[p - 1])
  598. print(boxes[p - 1])
  599. print(all_text_tags[p - 1])
  600. sum_box = 0
  601. for idx in range(len(all_img_path)):
  602. gt_img_path = all_img_path[idx]
  603. gt_boxs = boxes[idx]
  604. if gt_boxs.shape[-1] == 9:
  605. boxs_tmp = []
  606. for i in gt_boxs:
  607. if i[-1] > 0.5:
  608. boxs_tmp.append(i)
  609. gt_boxs = np.array(boxs_tmp)
  610. gt_text_tags = all_text_tags[idx]
  611. path_post = gt_img_path.strip(). \
  612. split(rootdir)[-1]
  613. # dir = 'd:/Users/Administrator/Desktop/liudan/ocr/data/'
  614. # dir = os.path.join(total_path, 'data/')
  615. if rootdir == os.path.join(dir, 'total_data'):
  616. save_path = os.path.join(dir, 'total_transform_imgs')
  617. elif rootdir == os.path.join(dir, 'train'):
  618. save_path = os.path.join(dir, 'train_transform_imgs')
  619. elif rootdir == os.path.join(dir, 'val'):
  620. save_path = os.path.join(dir, 'val_transform_imgs')
  621. for ii in path_post.split('/')[1:-1]:
  622. save_path = os.path.join(save_path, ii)
  623. if not os.path.exists(save_path):
  624. os.makedirs(save_path)
  625. num_box = save_images(gt_boxs, gt_text_tags, gt_img_path, save_path)
  626. sum_box += num_box
  627. print(sum_box)
  628. if __name__ == "__main__":
  629. rootdir = os.path.join(dir, 'total_data')
  630. batch_save_txt(rootdir)
  631. rootdir = os.path.join(dir, 'train')
  632. batch_save_txt(rootdir)
  633. rootdir = os.path.join(dir, 'val')
  634. batch_save_txt(rootdir)