# prepare_data_NKL.py (5.3 KB)
import argparse
import os
import sys
from os.path import join, split, splitext, abspath, isfile

import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from skimage.measure import label, regionprops

# Put the parent and current directories at the front of sys.path before
# importing the project-local module below.
sys.path.insert(0, abspath(".."))
sys.path.insert(0, abspath("."))
from mask_component_utils import Line, LineAnnotation, line2hough
  13. parser = argparse.ArgumentParser(description="Prepare semantic line data format.")
  14. parser.add_argument('--root', type=str, required=True, help='the data root dir.')
  15. parser.add_argument('--label', type=str, required=True, help='the label root dir.')
  16. parser.add_argument('--save-dir', type=str, required=True, help='save-dir')
  17. parser.add_argument('--fixsize', type=int, default=None, help='fix resize of images and annotations')
  18. parser.add_argument('--numangle', type=int, default=100)
  19. parser.add_argument('--numrho', type=int, default=100)
  20. args = parser.parse_args()
  21. label_path = abspath(args.label)
  22. image_dir = abspath(args.root)
  23. save_dir = abspath(args.save_dir)
  24. os.makedirs(save_dir, exist_ok=True)
  25. def nearest8(x):
  26. return int(np.round(x/8)*8)
  27. def vis_anno(image, annotation):
  28. mask = annotation.oriental_mask()
  29. mask_sum = mask.sum(axis=0).astype(bool)
  30. image_cp = image.copy()
  31. image_cp[mask_sum, ...] = [0, 255, 0]
  32. mask = np.zeros((image.shape[0], image.shape[1]))
  33. mask[mask_sum] = 1
  34. return image_cp, mask
  35. def check(y, x, H, W):
  36. x = max(0, x)
  37. y = max(0, y)
  38. x = min(x, W-1)
  39. y = min(y, H-1)
  40. return y, x
  41. labels_files = [i for i in os.listdir(label_path) if i.endswith(".txt")]
  42. num_samples = len(labels_files)
  43. stastic = np.zeros(20)
  44. total_nums = 0
  45. angle_stastic = np.zeros(180)
  46. total_lines = 0
  47. for idx, label_file in enumerate(labels_files):
  48. filename, _ = splitext(label_file)
  49. print("Processing %s [%d/%d]..." % (filename, idx+1, len(labels_files)))
  50. if isfile(join(image_dir, filename+".jpg")):
  51. im = cv2.imread(join(image_dir, filename+".jpg"))
  52. H, W = im.shape[0], im.shape[1]
  53. scale_H, scale_W = args.fixsize / H, args.fixsize / W
  54. im = cv2.resize(im, (args.fixsize, args.fixsize))
  55. else:
  56. print("Warning: image %s doesnt exist!" % join(image_dir, filename+".jpg"))
  57. continue
  58. for argument in range(2):
  59. if argument == 0:
  60. lines = []
  61. with open(join(label_path, label_file)) as f:
  62. data = f.readlines()[0].split(' ')
  63. nums = int(data[0])
  64. stastic[nums] += 1
  65. total_nums += nums
  66. if int(nums) == 0:
  67. print("Warning: image has no semantic line : %s" % (filename))
  68. for i in range(nums):
  69. y1, x1 = check(int(data[i*4+2]), int(data[i*4+1]), H, W)
  70. y2, x2 = check(int(data[i*4+4]), int(data[i*4+3]), H, W)
  71. line = Line([y1, x1, y2, x2])
  72. angle = line.angle()
  73. angle_stastic[int((angle / np.pi + 0.5) * 180)] += 1
  74. total_lines += 1
  75. line.rescale(scale_H, scale_W)
  76. lines.append(line)
  77. annotation = LineAnnotation(size=[args.fixsize, args.fixsize], lines=lines)
  78. else:
  79. im = cv2.flip(im, 1)
  80. filename = filename + '_flip'
  81. lines = []
  82. with open(join(label_path, label_file)) as f:
  83. data = f.readlines()[0].split(' ')
  84. for i in range(int(data[0])):
  85. y1, x1 = check(int(data[i*4+2]), W-1-int(data[i*4+1]), H, W)
  86. y2, x2 = check(int(data[i*4+4]), W-1-int(data[i*4+3]), H, W)
  87. line = Line([y1, x1, y2, x2])
  88. line.rescale(scale_H, scale_W)
  89. lines.append(line)
  90. annotation = LineAnnotation(size=[args.fixsize, args.fixsize], lines=lines)
  91. # resize image and annotations
  92. if args.fixsize is not None:
  93. newH = nearest8(args.fixsize)
  94. newW = nearest8(args.fixsize)
  95. else:
  96. newH = nearest8(H)
  97. newW = nearest8(W)
  98. im = cv2.resize(im, (newW, newH))
  99. annotation.resize(size=[newH, newW])
  100. vis, mask = vis_anno(im, annotation)
  101. hough_space_label = np.zeros((args.numangle, args.numrho))
  102. for l in annotation.lines:
  103. theta, r = line2hough(l, numAngle=args.numangle, numRho=args.numrho, size=(newH, newW))
  104. hough_space_label[theta, r] += 1
  105. hough_space_label = cv2.GaussianBlur(hough_space_label, (5,5), 0)
  106. if hough_space_label.max() > 0:
  107. hough_space_label = hough_space_label / hough_space_label.max()
  108. gt_coords = []
  109. for l in annotation.lines:
  110. gt_coords.append(l.coord)
  111. gt_coords = np.array(gt_coords)
  112. data = dict({
  113. "hough_space_label8": hough_space_label,
  114. "coords": gt_coords
  115. })
  116. save_name = os.path.join(save_dir, filename)
  117. np.save(save_name, data)
  118. cv2.imwrite(save_name + '.jpg', im)
  119. # cv2.imwrite(save_name + '_p_label.jpg', hough_space_label*255)
  120. # cv2.imwrite(save_name + '_vis.jpg', vis)
  121. cv2.imwrite(save_name + '_mask.jpg', mask*255)
  122. #print(stastic)
  123. #print(angle_stastic.astype(np.int))