# -*- coding: utf-8 -*-
"""Cache text-detection annotations and render annotated preview images.

For a dataset directory this module pairs every ``.jpg`` image with an
annotation path, parses the annotations (ICDAR-style ``.txt``/``.py``,
LabelMe-style ``.json`` and VOC-style ``.xml``) into quadrilaterals, caches
the result as three text files and saves a rendered copy of every image.

NOTE(review): ``total_path`` and ``dir`` are expected to come from
``from params import *`` -- confirm in params.py.  ``dir`` shadows the
built-in of the same name.
"""
import json
import os
import re
import sys
import time
import xml.etree.ElementTree as ET

import numpy as np
import cv2
import requests

# Seeded before the project imports below (as in the original ordering) so
# anything they do with numpy's global RNG at import time is unchanged.
np.random.seed(2234)

from box_a_pic import box_pic, box_pic_to_test, draw_pic
from params import *


def liscense_hegd():
    """Exit the process unless the remotely published licence key matches.

    Downloads a web page, takes the first ``XXXXX-XXXXX-XXXXX-XXXXX-XXXXX``
    token in it and compares it with the key embedded below.  Any failure to
    fetch or parse the page exits immediately; a mismatch exits after a
    5 second pause.

    NOTE(review): this is a network-dependent kill switch; nothing in this
    module calls it.
    """
    headers = {
        'User-Agent': "Mozilla/5.0 (X11; Linux i686) AppleWebKit/537.17 "
                      "(KHTML, like Gecko) Chrome/24.0.1312.27 Safari/537.17",
    }
    # Same values the original assembled character by character.
    file_license = '6QFDX-PYH2G-PPYFD-C7RJM-BBKQ8'
    license_url = 'https://juejin.im/post/6844903987917897741'
    pattern_baidu = re.compile(r'.{5}-.{5}-.{5}-.{5}-.{5}')
    try:
        # timeout=None keeps the original behaviour of waiting indefinitely.
        html = requests.get(license_url, timeout=None, headers=headers)
        web_license = pattern_baidu.findall(html.text)[0]
    except Exception:
        sys.exit()
    if web_license != file_license:
        time.sleep(5)
        sys.exit()


def vis_compare_gt_pred(gt_boxs, gt_text_tags, gt_img_path):
    """Show the ground-truth boxes of one image in a window and wait for a key.

    ``save_images`` unpacks ``box_pic``'s result as a 4-tuple and uses the
    second element; the same element is shown here so ``cv2.imshow`` gets an
    image rather than a tuple.  A non-tuple result is shown as-is.
    NOTE(review): the 4-tuple return shape is assumed from ``save_images`` --
    confirm against box_a_pic.box_pic.
    """
    rendered = box_pic(gt_boxs, gt_text_tags, gt_img_path)
    img1 = rendered[1] if isinstance(rendered, tuple) else rendered
    cv2.imshow('gt', img1)
    cv2.waitKey(0)


def _letterbox(im, input_size=512):
    """Return ``im`` fitted into an ``input_size`` x ``input_size`` uint8 canvas.

    If both sides are <= ``input_size`` the image is resized so that its
    longer side equals ``input_size`` and is then zero-padded at the
    bottom/right.  Otherwise it is first zero-padded to a square and the
    square is resized down to ``input_size``.
    """
    new_h, new_w = im.shape[0], im.shape[1]
    max_side = max(new_h, new_w, input_size)
    if max_side == input_size:
        if new_h > new_w:
            dsize = (round(new_w * input_size / new_h), input_size)
        else:
            dsize = (input_size, round(new_h * input_size / new_w))
        resized = cv2.resize(im, dsize)
        canvas = np.zeros((input_size, input_size, 3), dtype=np.uint8)
        canvas[:resized.shape[0], :resized.shape[1], :] = resized
        return canvas
    padded = np.zeros((max_side, max_side, 3), dtype=np.uint8)
    padded[:new_h, :new_w, :] = im
    return cv2.resize(padded, dsize=(input_size, input_size))


def get_test_like_img(gt_img_path, input_size=512):
    """Load an image and letterbox it to an ``input_size`` square canvas.

    Args:
        gt_img_path: path of the image to load with ``cv2.imread``.
        input_size: side length of the returned square image (default 512).

    Returns:
        A ``(input_size, input_size, 3)`` uint8 array.

    Raises:
        IOError: if OpenCV cannot read ``gt_img_path``.
    """
    im = cv2.imread(gt_img_path)
    if im is None:
        raise IOError('cannot read image: %s' % gt_img_path)
    return _letterbox(im, input_size)


def save_images(gt_boxs, gt_text_tags, gt_img_path, transform_imgs):
    """Render one image's annotations with ``box_pic`` and save it into ``transform_imgs``.

    The saved file keeps the image's basename; ``transform_imgs`` is created
    if needed.  Returns the number of text tags of the image.
    """
    os.makedirs(transform_imgs, exist_ok=True)
    save_path = os.path.join(transform_imgs, os.path.basename(gt_img_path))
    # box_pic returns four values; the second one is the image that is saved.
    _, im, _, _ = box_pic(gt_boxs, gt_text_tags, gt_img_path)
    if not cv2.imwrite(save_path, im):
        print('failed to write', save_path)
    return len(gt_text_tags)


def find(rootdir):
    """List the ``.jpg`` files directly inside ``rootdir`` and their annotation paths.

    Returns:
        ``(image_paths, annotation_paths)``: parallel lists; each annotation
        path is the image path with a ``.json`` extension, whether or not that
        file exists.  Entries are sorted by file name so cached outputs are
        reproducible.
    """
    file_image_list = []
    file_object_list = []
    for name in sorted(os.listdir(rootdir)):
        filename, suffix = os.path.splitext(name)
        if suffix == '.jpg':
            file_image_list.append(os.path.join(rootdir, filename + '.jpg'))
            file_object_list.append(os.path.join(rootdir, filename + '.json'))
    return file_image_list, file_object_list


def get_all_paths(rootdir):
    """Return ``(image_paths, annotation_paths)`` for ``rootdir`` (see ``find``)."""
    file_image_list, file_object_list = find(rootdir)
    return list(file_image_list), list(file_object_list)


def polygon_area(poly):
    """Return the signed area of the quadrilateral ``poly`` (four ``(x, y)`` points).

    Trapezoid form of the shoelace formula: reversing the vertex order
    negates the result.
    """
    edge = [
        (poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
        (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
        (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
        (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1]),
    ]
    return np.sum(edge) / 2.


def _read_text_file(path):
    """Return the text of an annotation file, tolerating a BOM and legacy encodings.

    Tries UTF-8 (a leading BOM is stripped), then the locale default encoding,
    then ISO-8859-1, which accepts any byte sequence.
    """
    for encoding in ('utf-8-sig', None):
        try:
            with open(path, 'r', encoding=encoding) as f:
                return f.read()
        except UnicodeDecodeError:
            pass
    with open(path, 'r', encoding='iso8859-1') as f:
        return f.read()


def _txt_objects(content):
    """Yield ``(coords, label)`` for the usable lines of an ICDAR-style text annotation.

    A line is either ``x1,y1,...,x4,y4,<text>`` (exactly 9 fields) or
    ``x1,y1,...,x4,y4,<language>,<text...>`` (10+ fields; the text may itself
    contain commas).  Dropped: lines with fewer than 9 fields, lines whose 9th
    field is one of the skipped languages, 10+-field lines whose language is
    not Latin/Symbols, and ``###`` labels.
    """
    skipped_languages = ('Mixed', 'None', 'Chinese', 'Japanese', 'Korean', 'Bangla')
    for line in content.split('\n'):
        line = line.strip()
        if not line:
            continue
        splits = line.split(',')
        if len(splits) < 9:
            continue
        lan_label = splits[8].strip()
        if len(splits) >= 10:
            # Everything after the language column is the transcription.
            cls_label = ','.join(splits[9:]).strip()
        else:
            cls_label = splits[8]
        if lan_label in skipped_languages:
            continue
        if len(splits) >= 10 and lan_label not in ('Latin', 'Symbols'):
            continue
        if cls_label == '###':
            continue
        coords = [round(float(value.strip())) for value in splits[:8]]
        yield coords, cls_label


def _json_objects(content):
    """Yield ``(coords, label)`` for the shapes of a LabelMe-style JSON annotation.

    Shapes come from ``shapes`` or, if absent, ``Public[0]['Landmark']``.
    The label is the value of the first key present among ``text``,
    ``label``, ``txt``; shapes with none of them are skipped.
      * ``Points`` (list of ``{'X': .., 'Y': ..}`` dicts) with >= 4 entries:
        the first four points.
      * ``points`` with >= 8 numbers: the first four points.
      * ``points`` with 4..7 numbers: the first two points are opposite
        corners of an axis-aligned rectangle.
    Shapes with fewer points are skipped.
    """
    data = json.loads(content)
    shapes = data['shapes'] if 'shapes' in data else data['Public'][0]['Landmark']
    for shape in shapes:
        if 'points' in shape:
            pos = np.array(shape['points']).flatten()
            xy_dicts = False
        else:
            pos = np.array(shape['Points']).flatten()
            xy_dicts = True
        for key in ('text', 'label', 'txt'):
            if key in shape:
                cls_label = shape[key]
                break
        else:
            continue
        if xy_dicts:
            if len(pos) < 4:
                continue
            # 8+ dict points used to fall into the numeric branch and crash.
            coords = [round(float(point[axis])) for point in pos[:4] for axis in ('X', 'Y')]
        elif len(pos) >= 8:
            coords = [round(float(value)) for value in pos[:8]]
        elif len(pos) >= 4:
            x0, y0, x1, y1 = (round(float(value)) for value in pos[:4])
            coords = [x0, y0, x1, y0, x1, y1, x0, y1]
        else:
            continue
        yield coords, cls_label


def _xml_coord(pos):
    """Return the rounded numeric text of an XML coordinate element."""
    return round(float(pos.text.strip()))


def _xml_objects(tree):
    """Yield ``(coords, label)`` for each ``bndbox``/``polygon`` in a VOC-style XML tree.

    ``<object>`` elements are scanned first, then ``<item>`` elements.  The
    label is the text of the element's first ``name`` child with text;
    elements without one are skipped (previously a stale label from an
    earlier object could be reused).  A ``bndbox`` becomes the quad
    (xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin); a ``polygon`` is
    read from its x1..y4 children.  Missing coordinates stay 0.
    """
    polygon_keys = ('x1', 'y1', 'x2', 'y2', 'x3', 'y3', 'x4', 'y4')
    for element_tag in ('object', 'item'):
        for elem in tree.iter(tag=element_tag):
            children = list(elem)
            cls_label = next((child.text.strip() for child in children
                              if 'name' in child.tag and child.text is not None), None)
            if cls_label is None:
                continue
            for attr in children:
                if 'bndbox' in attr.tag:
                    coords = [0.] * 8
                    for pos in attr:
                        if 'xmin' in pos.tag:
                            coords[0] = _xml_coord(pos)
                        if 'ymin' in pos.tag:
                            coords[1] = _xml_coord(pos)
                        if 'xmax' in pos.tag:
                            coords[4] = _xml_coord(pos)
                        if 'ymax' in pos.tag:
                            coords[5] = _xml_coord(pos)
                    coords[2] = coords[0]
                    coords[3] = coords[5]
                    coords[6] = coords[4]
                    coords[7] = coords[1]
                    yield coords, cls_label
                if 'polygon' in attr.tag:
                    coords = [0.] * 8
                    for pos in attr:
                        for idx, key in enumerate(polygon_keys):
                            if key in pos.tag:
                                coords[idx] = _xml_coord(pos)
                    yield coords, cls_label


def _normalize_quad(object_info, w, h):
    """Put a ``[x1, y1, ..., x4, y4, flag]`` box into canonical vertex order.

    Returns the reordered 9-value list, or ``None`` when the quad -- with a
    copy of its points clipped to the ``w`` x ``h`` image -- has an area below
    one pixel.
    NOTE(review): only the copy used for the area test is clipped; the
    returned coordinates are not.  Confirm downstream code expects that.
    """
    object_info = list(object_info)
    poly = np.array(object_info)[:-1].reshape((4, 2))
    if polygon_area(poly) > 0:
        # Reverse the winding so polygon_area of the stored quad is negative.
        poly = poly[(0, 3, 2, 1), :]
        object_info[:8] = poly.flatten().tolist()
    # Rotate the start vertex until x2 > x1 and |x2 - x1| > |x4 - x1|
    # (at most five rotations).
    rotations = 0
    while (object_info[2] <= object_info[0]
           or abs(object_info[2] - object_info[0]) <= abs(object_info[6] - object_info[0])):
        rotations += 1
        object_info[:8] = object_info[2:8] + object_info[:2]
        if rotations > 4:
            break
    clipped = np.array(object_info)[:-1].reshape((4, 2))
    clipped[:, 0] = np.clip(clipped[:, 0], 0, w - 1)
    clipped[:, 1] = np.clip(clipped[:, 1], 0, h - 1)
    if abs(polygon_area(clipped)) < 1:
        return None
    return object_info


def parse_annotation_8(all_img_path, all_tag_path):
    """Parse annotation files into a zero-padded array of 8-coordinate boxes.

    Args:
        all_img_path: image paths, parallel to ``all_tag_path``.
        all_tag_path: annotation paths; the extension picks the parser
            (``txt``/``py``, ``json`` or ``xml``).  Other extensions, missing
            annotation files, ``.gif`` images, unparsable XML and images with
            no usable box are skipped.

    Returns:
        ``(img_paths, boxes, text_tags)``: ``boxes`` is an int array of shape
        ``(len(img_paths), max_boxes, 9)`` whose rows are
        ``[x1, y1, x2, y2, x3, y3, x4, y4, 1]`` (unused rows are all zero);
        ``text_tags[i]`` lists image ``i``'s labels in box order.

    Raises:
        IOError: if an image with an existing annotation cannot be read.
    """
    all_imgs_path, all_imgs_boxes, all_text_tags = [], [], []
    for img_path, ann_path in zip(all_img_path, all_tag_path):
        tag = os.path.splitext(ann_path)[1][1:]
        if tag not in ('txt', 'py', 'json', 'xml'):
            continue
        # cv2.imread cannot decode GIFs, so skip them before reading.
        if img_path.split('.')[-1] == 'gif':
            continue
        if not os.path.exists(ann_path):
            # find() pairs every image with a .json path even if none exists.
            print('missing annotation, skipped:', ann_path)
            continue
        img = cv2.imread(img_path)
        if img is None:
            print(ann_path)
            raise IOError('cannot read image: %s' % img_path)
        h, w = img.shape[0], img.shape[1]

        if tag == 'xml':
            try:
                tree = ET.parse(ann_path)
            except (ET.ParseError, OSError):
                continue
            raw_objects = _xml_objects(tree)
        elif tag == 'json':
            raw_objects = _json_objects(_read_text_file(ann_path))
        else:
            raw_objects = _txt_objects(_read_text_file(ann_path))

        boxes_list, text_tags = [], []
        for coords, cls_label in raw_objects:
            object_info = _normalize_quad(list(coords) + [1], w, h)
            if object_info is None:
                continue
            boxes_list.append(object_info)
            text_tags.append(cls_label)
        if not boxes_list:
            continue
        all_imgs_path.append(img_path)
        all_imgs_boxes.append(boxes_list)
        all_text_tags.append(text_tags)

    max_boxes = max((len(b) for b in all_imgs_boxes), default=0)
    boxes = np.zeros([len(all_imgs_path), max_boxes, 9])
    for i, boxes_rec in enumerate(all_imgs_boxes):
        boxes[i, :len(boxes_rec), :] = np.array(boxes_rec)
    return all_imgs_path, boxes.astype(int), all_text_tags


def _rec_file_paths(rootdir):
    """Return the (image-path, boxes, text-tag) cache file paths for ``rootdir``.

    The files live in ``<total_path>/data/ocr_txt/`` (created if missing) and
    are prefixed with the last path component of ``rootdir``.
    """
    root_path = os.path.join(total_path, 'data/ocr_txt/')
    os.makedirs(root_path, exist_ok=True)
    re_name = os.path.basename(os.path.normpath(rootdir))
    return (os.path.join(root_path, re_name + '_img_path_rec.txt'),
            os.path.join(root_path, re_name + '_boxes_rec.txt'),
            os.path.join(root_path, re_name + '_text_tag_rec.txt'))


def write_img_infos(rootdir):
    """Parse the annotations under ``rootdir`` and cache them as three text files.

    Written to ``<total_path>/data/ocr_txt/`` with ``<name>`` = last path
    component of ``rootdir``:
      * ``<name>_img_path_rec.txt``: one image path per line;
      * ``<name>_boxes_rec.txt``: the flattened int box array, each value
        followed by a TAB;
      * ``<name>_text_tag_rec.txt``: one line per image; every tag is followed
        by TAB + ``**hegd**`` + TAB.  ``None`` tags are written as
        ``None_hegd``; line breaks inside a tag become spaces so rows stay
        aligned with the image list.
    """
    all_img_path, all_tag_path = get_all_paths(rootdir)
    imgs, boxes, text_tags = parse_annotation_8(all_img_path, all_tag_path)
    img_file, boxes_file, tag_file = _rec_file_paths(rootdir)
    with open(img_file, 'w') as f:
        f.write(''.join(path + '\n' for path in imgs))
    with open(boxes_file, 'w') as f:
        f.write(''.join(str(value) + '\t' for value in boxes.flatten()))
    with open(tag_file, 'w') as f:
        for tags in text_tags:
            for kk in tags:
                if kk is None:
                    kk = "None_hegd"
                kk = kk.replace('\r', ' ').replace('\n', ' ')
                f.write(kk.strip() + '\t**hegd**\t')
            f.write('\n')


def simple_load_np_dataset(rootdir):
    """Read back the cache files written by ``write_img_infos`` for ``rootdir``.

    Returns:
        ``(img_paths, boxes, text_tags)`` where ``boxes`` has shape
        ``(len(img_paths), max_boxes, 9)`` (``(0, 0, 9)`` for an empty cache).
    """
    img_file, boxes_file, tag_file = _rec_file_paths(rootdir)
    with open(img_file, 'r') as f:
        all_img_path = [line.strip() for line in f]
    with open(boxes_file, 'r') as f:
        tokens = f.read().split()
    boxes_flatten = np.array([int(token) for token in tokens], dtype=int)
    if all_img_path:
        boxes = boxes_flatten.reshape((len(all_img_path), -1, 9))
    else:
        boxes = boxes_flatten.reshape((0, 0, 9))
    with open(tag_file, 'r') as f:
        all_text_tags = [line.split('\t**hegd**\t')[:-1] for line in f]
    return all_img_path, boxes, all_text_tags


def _transform_dir_for(rootdir):
    """Return the directory that rendered images for ``rootdir`` are saved under."""
    known = {
        os.path.join(dir, 'total_data'): os.path.join(dir, 'total_transform_imgs'),
        os.path.join(dir, 'train'): os.path.join(dir, 'train_transform_imgs'),
        os.path.join(dir, 'val'): os.path.join(dir, 'val_transform_imgs'),
    }
    if rootdir in known:
        return known[rootdir]
    # Other datasets previously hit an UnboundLocalError; use a sibling dir.
    return os.path.normpath(rootdir) + '_transform_imgs'


def batch_save_txt(rootdir):
    """Cache the annotations of ``rootdir`` and save a rendered copy of every image.

    Prints images whose ``.json`` annotation is missing, the image count, the
    first cached record and finally the total number of saved text tags.
    """
    file_image_list, file_object_list = find(rootdir)
    for img_path, label_path in zip(file_image_list, file_object_list):
        if not os.path.exists(label_path.strip()):
            print(img_path.strip())
    print('total imgs num: ', len(file_image_list))

    write_img_infos(rootdir)
    all_img_path, boxes, all_text_tags = simple_load_np_dataset(rootdir)
    if all_img_path:
        print(all_img_path[0])
        print(boxes[0])
        print(all_text_tags[0])

    save_root = _transform_dir_for(rootdir)
    sum_box = 0
    for gt_img_path, gt_boxs, gt_text_tags in zip(all_img_path, boxes, all_text_tags):
        # Padding rows have flag 0; keep only real boxes.
        gt_boxs = gt_boxs[gt_boxs[:, -1] > 0.5]
        # Mirror the image's sub-directory (relative to rootdir) under save_root.
        rel_dir = os.path.dirname(os.path.relpath(gt_img_path.strip(), rootdir))
        save_path = os.path.join(save_root, rel_dir) if rel_dir else save_root
        os.makedirs(save_path, exist_ok=True)
        sum_box += save_images(gt_boxs, gt_text_tags, gt_img_path, save_path)
    print(sum_box)


if __name__ == "__main__":
    for subset in ('total_data', 'train', 'val'):
        batch_save_txt(os.path.join(dir, subset))