# prepare_data_JTLEE.py: prepare semantic line data format.
import argparse
import os
import sys
from os.path import join, split, splitext, abspath, isfile

import numpy as np
import cv2
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
from skimage.measure import label, regionprops

# Make mask_component_utils importable when run from this dir or its parent.
sys.path.insert(0, abspath(".."))
sys.path.insert(0, abspath("."))
from mask_component_utils import Line, LineAnnotation, line2hough
  13. parser = argparse.ArgumentParser(description="Prepare semantic line data format.")
  14. parser.add_argument('--root', type=str, required=True, help='the data root dir.')
  15. parser.add_argument('--label', type=str, required=True, help='the label root dir.')
  16. parser.add_argument('--num_directions', type=int, default=12, help='the division of semicircular angle')
  17. parser.add_argument('--list', type=str, required=True, help='list file')
  18. parser.add_argument('--save-dir', type=str, required=True, help='save-dir')
  19. parser.add_argument('--prefix', type=str, default="", help="Prefix in list file")
  20. parser.add_argument('--fixsize', type=int, default=None, help='fix resize of images and annotations')
  21. parser.add_argument('--numangle', type=int, default=80, required=True)
  22. parser.add_argument('--numrho', type=int, default=80, required=True)
  23. args = parser.parse_args()
  24. label_path = abspath(args.label)
  25. image_dir = abspath(args.root)
  26. save_dir = abspath(args.save_dir)
  27. os.makedirs(save_dir, exist_ok=True)
  28. def nearest8(x):
  29. return int(np.round(x/8)*8)
  30. def vis_anno(image, annotation):
  31. mask = annotation.oriental_mask()
  32. mask_sum = mask.sum(axis=0).astype(bool)
  33. image_cp = image.copy()
  34. image_cp[mask_sum, ...] = [0, 255, 0]
  35. mask = np.zeros((image.shape[0], image.shape[1]))
  36. mask[mask_sum] = 1
  37. return image_cp, mask
  38. labels_files = [i for i in os.listdir(label_path) if i.endswith(".txt")]
  39. num_samples = len(labels_files)
  40. filelist = open(args.list, "w")
  41. stastic = np.zeros(10)
  42. for idx, label_file in enumerate(labels_files):
  43. filename, _ = splitext(label_file)
  44. print("Processing %s [%d/%d]..." % (filename, idx+1, len(labels_files)))
  45. if isfile(join(image_dir, filename+".jpg")):
  46. im = cv2.imread(join(image_dir, filename+".jpg"))
  47. im = cv2.resize(im, (args.fixsize, args.fixsize))
  48. else:
  49. print("Warning: image %s doesnt exist!" % join(image_dir, filename+".jpg"))
  50. continue
  51. for argument in range(2):
  52. if argument == 0:
  53. H, W, _ = im.shape
  54. lines = []
  55. with open(join(label_path, label_file)) as f:
  56. data = f.readlines()
  57. nums = len(data)
  58. stastic[nums] += 1
  59. for line in data:
  60. data1 = line.strip().split(',')
  61. if len(data1) <= 4:
  62. continue
  63. data1 = [int(float(x)) for x in data1]
  64. if data1[1]==data1[3] and data1[0]==data1[2]:
  65. continue
  66. line = Line([data1[1], data1[0], data1[3], data1[2]])
  67. lines.append(line)
  68. annotation = LineAnnotation(size=[H, W], divisions=args.num_directions, lines=lines)
  69. else:
  70. im = cv2.flip(im, 1)
  71. filename = filename + '_flip'
  72. H, W, _ = im.shape
  73. lines = []
  74. with open(join(label_path, label_file)) as f:
  75. data = f.readlines()
  76. for line in data:
  77. data1 = line.strip().split(',')
  78. if len(data1) <= 4:
  79. continue
  80. data1 = [int(float(x)) for x in data1]
  81. if data1[1]==data1[3] and data1[0]==data1[2]:
  82. continue
  83. line = Line([data1[1], W-1-data1[0], data1[3], W-1-data1[2]])
  84. lines.append(line)
  85. annotation = LineAnnotation(size=[H, W], divisions=args.num_directions, lines=lines)
  86. # resize image and annotations
  87. if args.fixsize is not None:
  88. newH = nearest8(args.fixsize)
  89. newW = nearest8(args.fixsize)
  90. else:
  91. newH = nearest8(H)
  92. newW = nearest8(W)
  93. im = cv2.resize(im, (newW, newH))
  94. annotation.resize(size=[newH, newW])
  95. vis, mask = vis_anno(im, annotation)
  96. hough_space_label = np.zeros((args.numangle, args.numrho))
  97. for l in annotation.lines:
  98. theta, r = line2hough(l, numAngle=args.numangle, numRho=args.numrho, size=(newH, newW))
  99. hough_space_label[theta, r] += 1
  100. hough_space_label = cv2.GaussianBlur(hough_space_label, (5,5), 0)
  101. if hough_space_label.max() > 0:
  102. hough_space_label = hough_space_label / hough_space_label.max()
  103. gt_coords = []
  104. for l in annotation.lines:
  105. gt_coords.append(l.coord)
  106. gt_coords = np.array(gt_coords)
  107. data = dict({
  108. "hough_space_label8": hough_space_label,
  109. "coords": gt_coords
  110. })
  111. save_name = os.path.join(save_dir, filename)
  112. np.save(save_name, data)
  113. cv2.imwrite(save_name + '.jpg', im)
  114. cv2.imwrite(save_name + '_p_label.jpg', hough_space_label*255)
  115. cv2.imwrite(save_name + '_vis.jpg', vis)
  116. cv2.imwrite(save_name + '_mask.jpg', mask*255)
  117. filelist.close()
  118. # print(stastic)