bdd100k2mot.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import glob
  15. import os
  16. import os.path as osp
  17. import cv2
  18. import random
  19. import numpy as np
  20. import argparse
  21. import tqdm
  22. import json
  23. def mkdir_if_missing(d):
  24. if not osp.exists(d):
  25. os.makedirs(d)
  26. def bdd2mot_tracking(img_dir, label_dir, save_img_dir, save_label_dir):
  27. label_jsons = os.listdir(label_dir)
  28. for label_json in tqdm(label_jsons):
  29. with open(os.path.join(label_dir, label_json)) as f:
  30. labels_json = json.load(f)
  31. for label_json in labels_json:
  32. img_name = label_json['name']
  33. video_name = label_json['videoName']
  34. labels = label_json['labels']
  35. txt_string = ""
  36. for label in labels:
  37. category = label['category']
  38. x1 = label['box2d']['x1']
  39. x2 = label['box2d']['x2']
  40. y1 = label['box2d']['y1']
  41. y2 = label['box2d']['y2']
  42. width = x2 - x1
  43. height = y2 - y1
  44. x_center = (x1 + x2) / 2. / args.width
  45. y_center = (y1 + y2) / 2. / args.height
  46. width /= args.width
  47. height /= args.height
  48. identity = int(label['id'])
  49. # [class] [identity] [x_center] [y_center] [width] [height]
  50. txt_string += "{} {} {} {} {} {}\n".format(
  51. attr_id_dict[category], identity, x_center, y_center,
  52. width, height)
  53. fn_label = os.path.join(save_label_dir, img_name[:-4] + '.txt')
  54. source_img = os.path.join(img_dir, video_name, img_name)
  55. target_img = os.path.join(save_img_dir, img_name)
  56. with open(fn_label, 'w') as f:
  57. f.write(txt_string)
  58. os.system('cp {} {}'.format(source_img, target_img))
  59. def transBbox(bbox):
  60. # bbox --> cx cy w h
  61. bbox = list(map(lambda x: float(x), bbox))
  62. bbox[0] = (bbox[0] - bbox[2] / 2) * 1280
  63. bbox[1] = (bbox[1] - bbox[3] / 2) * 720
  64. bbox[2] = bbox[2] * 1280
  65. bbox[3] = bbox[3] * 720
  66. bbox = list(map(lambda x: str(x), bbox))
  67. return bbox
  68. def genSingleImageMot(inputPath, classes=[]):
  69. labelPaths = glob.glob(inputPath + '/*.txt')
  70. labelPaths = sorted(labelPaths)
  71. allLines = []
  72. result = {}
  73. for labelPath in labelPaths:
  74. frame = str(int(labelPath.split('-')[-1].replace('.txt', '')))
  75. with open(labelPath, 'r') as labelPathFile:
  76. lines = labelPathFile.readlines()
  77. for line in lines:
  78. line = line.replace('\n', '')
  79. lineArray = line.split(' ')
  80. if len(classes) > 0:
  81. if lineArray[0] in classes:
  82. lineArray.append(frame)
  83. allLines.append(lineArray)
  84. else:
  85. lineArray.append(frame)
  86. allLines.append(lineArray)
  87. resultMap = {}
  88. for line in allLines:
  89. if line[1] not in resultMap.keys():
  90. resultMap[line[1]] = []
  91. resultMap[line[1]].append(line)
  92. mot_gt = []
  93. id_idx = 0
  94. for rid in resultMap.keys():
  95. id_idx += 1
  96. for id_line in resultMap[rid]:
  97. mot_line = []
  98. mot_line.append(id_line[-1])
  99. mot_line.append(str(id_idx))
  100. id_line_temp = transBbox(id_line[2:6])
  101. mot_line.extend(id_line_temp)
  102. mot_line.append('1') # origin class: id_line[0]
  103. mot_line.append('1') # permanent class => 1
  104. mot_line.append('1')
  105. mot_gt.append(mot_line)
  106. result = list(map(lambda line: str.join(',', line), mot_gt))
  107. resultStr = str.join('\n', result)
  108. return resultStr
  109. def writeGt(inputPath, outPath, classes=[]):
  110. singleImageResult = genSingleImageMot(inputPath, classes=classes)
  111. outPathFile = outPath + '/gt.txt'
  112. mkdir_if_missing(outPath)
  113. with open(outPathFile, 'w') as gtFile:
  114. gtFile.write(singleImageResult)
  115. def genSeqInfo(seqInfoPath):
  116. name = seqInfoPath.split('/')[-2]
  117. img1Path = osp.join(str.join('/', seqInfoPath.split('/')[0:-1]), 'img1')
  118. seqLength = len(glob.glob(img1Path + '/*.jpg'))
  119. seqInfoStr = f'''[Sequence]\nname={name}\nimDir=img1\nframeRate=30\nseqLength={seqLength}\nimWidth=1280\nimHeight=720\nimExt=.jpg'''
  120. with open(seqInfoPath, 'w') as seqFile:
  121. seqFile.write(seqInfoStr)
  122. def genMotGt(dataDir, classes=[]):
  123. seqLists = sorted(glob.glob(dataDir))
  124. for seqList in seqLists:
  125. inputPath = osp.join(seqList, 'img1')
  126. outputPath = seqList.replace('labels_with_ids', 'images')
  127. outputPath = osp.join(outputPath, 'gt')
  128. mkdir_if_missing(outputPath)
  129. print('processing...', outputPath)
  130. writeGt(inputPath, outputPath, classes=classes)
  131. seqList = seqList.replace('labels_with_ids', 'images')
  132. seqInfoPath = osp.join(seqList, 'seqinfo.ini')
  133. genSeqInfo(seqInfoPath)
  134. def updateSeqInfo(dataDir, phase):
  135. seqPath = osp.join(dataDir, 'labels_with_ids', phase)
  136. seqList = glob.glob(seqPath + '/*')
  137. for seqName in seqList:
  138. print('seqName=>', seqName)
  139. seqName_img1_dir = osp.join(seqName, 'img1')
  140. txtLength = glob.glob(seqName_img1_dir + '/*.txt')
  141. name = seqName.split('/')[-1].replace('.jpg', '').replace('.txt', '')
  142. seqLength = len(txtLength)
  143. seqInfoStr = f'''[Sequence]\nname={name}\nimDir=img1\nframeRate=30\nseqLength={seqLength}\nimWidth=1280\nimHeight=720\nimExt=.jpg'''
  144. seqInfoPath = seqName_img1_dir.replace('labels_with_ids', 'images')
  145. seqInfoPath = seqInfoPath.replace('/img1', '')
  146. seqInfoPath = seqInfoPath + '/seqinfo.ini'
  147. with open(seqInfoPath, 'w') as seqFile:
  148. seqFile.write(seqInfoStr)
  149. def VisualDataset(datasetPath, phase='train', seqName='', frameId=1):
  150. trainPath = osp.join(datasetPath, 'labels_with_ids', phase)
  151. seq1Paths = osp.join(trainPath, seqName)
  152. seq_img1_path = osp.join(seq1Paths, 'img1')
  153. label_with_idPath = osp.join(seq_img1_path, seqName + '-' + '%07d' %
  154. frameId) + '.txt'
  155. image_path = label_with_idPath.replace('labels_with_ids', 'images').replace(
  156. '.txt', '.jpg')
  157. seqInfoPath = str.join('/', image_path.split('/')[:-2])
  158. seqInfoPath = seqInfoPath + '/seqinfo.ini'
  159. seq_info = open(seqInfoPath).read()
  160. width = int(seq_info[seq_info.find('imWidth=') + 8:seq_info.find(
  161. '\nimHeight')])
  162. height = int(seq_info[seq_info.find('imHeight=') + 9:seq_info.find(
  163. '\nimExt')])
  164. with open(label_with_idPath, 'r') as label:
  165. allLines = label.readlines()
  166. images = cv2.imread(image_path)
  167. print('image_path => ', image_path)
  168. for line in allLines:
  169. line = line.split(' ')
  170. line = list(map(lambda x: float(x), line))
  171. c1, c2, w, h = line[2:6]
  172. x1 = c1 - w / 2
  173. x2 = c2 - h / 2
  174. x3 = c1 + w / 2
  175. x4 = c2 + h / 2
  176. cv2.rectangle(
  177. images, (int(x1 * width), int(x2 * height)),
  178. (int(x3 * width), int(x4 * height)), (255, 0, 0),
  179. thickness=2)
  180. cv2.imwrite('test.jpg', images)
  181. def VisualGt(dataPath, phase='train'):
  182. seqList = sorted(glob.glob(osp.join(dataPath, 'images', phase) + '/*'))
  183. seqIndex = random.randint(0, len(seqList) - 1)
  184. seqPath = seqList[seqIndex]
  185. gt_path = osp.join(seqPath, 'gt', 'gt.txt')
  186. img_list_path = sorted(glob.glob(osp.join(seqPath, 'img1', '*.jpg')))
  187. imgIndex = random.randint(0, len(img_list_path))
  188. img_Path = img_list_path[imgIndex]
  189. frame_value = img_Path.split('/')[-1].replace('.jpg', '')
  190. frame_value = frame_value.split('-')[-1]
  191. frame_value = int(frame_value)
  192. seqNameStr = img_Path.split('/')[-1].replace('.jpg', '').replace('img', '')
  193. frame_value = int(seqNameStr.split('-')[-1])
  194. print('frame_value => ', frame_value)
  195. gt_value = np.loadtxt(gt_path, dtype=float, delimiter=',')
  196. gt_value = gt_value[gt_value[:, 0] == frame_value]
  197. get_list = gt_value.tolist()
  198. img = cv2.imread(img_Path)
  199. colors = [[255, 0, 0], [255, 255, 0], [255, 0, 255], [0, 255, 0],
  200. [0, 255, 255], [0, 0, 255]]
  201. for seq, _id, pl, pt, w, h, _, bbox_class, _ in get_list:
  202. pl, pt, w, h = int(pl), int(pt), int(w), int(h)
  203. print('pl,pt,w,h => ', pl, pt, w, h)
  204. cv2.putText(img,
  205. str(bbox_class), (pl, pt), cv2.FONT_HERSHEY_PLAIN, 2,
  206. colors[int(bbox_class - 1)])
  207. cv2.rectangle(
  208. img, (pl, pt), (pl + w, pt + h),
  209. colors[int(bbox_class - 1)],
  210. thickness=2)
  211. cv2.imwrite('testGt.jpg', img)
  212. print(seqPath, frame_value)
  213. return seqPath.split('/')[-1], frame_value
  214. def gen_image_list(dataPath, datType):
  215. inputPath = f'{dataPath}/labels_with_ids/{datType}'
  216. pathList = sorted(glob.glob(inputPath + '/*'))
  217. print(pathList)
  218. allImageList = []
  219. for pathSingle in pathList:
  220. imgList = sorted(glob.glob(osp.join(pathSingle, 'img1', '*.txt')))
  221. for imgPath in imgList:
  222. imgPath = imgPath.replace('labels_with_ids', 'images').replace(
  223. '.txt', '.jpg')
  224. allImageList.append(imgPath)
  225. with open(f'{dataPath}.{datType}', 'w') as image_list_file:
  226. allImageListStr = str.join('\n', allImageList)
  227. image_list_file.write(allImageListStr)
  228. def formatOrigin(datapath, phase):
  229. label_with_idPath = osp.join(datapath, 'labels_with_ids', phase)
  230. print(label_with_idPath)
  231. for txtList in sorted(glob.glob(label_with_idPath + '/*.txt')):
  232. print(txtList)
  233. seqName = txtList.split('/')[-1]
  234. seqName = str.join('-', seqName.split('-')[0:-1]).replace('.txt', '')
  235. seqPath = osp.join(label_with_idPath, seqName, 'img1')
  236. mkdir_if_missing(seqPath)
  237. os.system(f'mv {txtList} {seqPath}')
  238. def copyImg(fromRootPath, toRootPath, phase):
  239. fromPath = osp.join(fromRootPath, 'images', phase)
  240. toPathSeqPath = osp.join(toRootPath, 'labels_with_ids', phase)
  241. seqList = sorted(glob.glob(toPathSeqPath + '/*'))
  242. for seqPath in seqList:
  243. seqName = seqPath.split('/')[-1]
  244. imgTxtList = sorted(glob.glob(osp.join(seqPath, 'img1') + '/*.txt'))
  245. img_toPathSeqPath = osp.join(seqPath, 'img1')
  246. img_toPathSeqPath = img_toPathSeqPath.replace('labels_with_ids',
  247. 'images')
  248. mkdir_if_missing(img_toPathSeqPath)
  249. for imgTxt in imgTxtList:
  250. imgName = imgTxt.split('/')[-1].replace('.txt', '.jpg')
  251. imgfromPath = osp.join(fromPath, seqName, imgName)
  252. print(f'cp {imgfromPath} {img_toPathSeqPath}')
  253. os.system(f'cp {imgfromPath} {img_toPathSeqPath}')
  254. if __name__ == "__main__":
  255. parser = argparse.ArgumentParser(description='BDD100K to MOT format')
  256. parser.add_argument("--data_path", default='bdd100k')
  257. parser.add_argument("--phase", default='train')
  258. parser.add_argument("--classes", default='2,3,4,9,10')
  259. parser.add_argument("--img_dir", default="bdd100k/images/track/")
  260. parser.add_argument("--label_dir", default="bdd100k/labels/box_track_20/")
  261. parser.add_argument("--save_path", default="bdd100kmot_vehicle")
  262. parser.add_argument("--height", default=720)
  263. parser.add_argument("--width", default=1280)
  264. args = parser.parse_args()
  265. attr_dict = dict()
  266. attr_dict["categories"] = [{
  267. "supercategory": "none",
  268. "id": 0,
  269. "name": "pedestrian"
  270. }, {
  271. "supercategory": "none",
  272. "id": 1,
  273. "name": "rider"
  274. }, {
  275. "supercategory": "none",
  276. "id": 2,
  277. "name": "car"
  278. }, {
  279. "supercategory": "none",
  280. "id": 3,
  281. "name": "truck"
  282. }, {
  283. "supercategory": "none",
  284. "id": 4,
  285. "name": "bus"
  286. }, {
  287. "supercategory": "none",
  288. "id": 5,
  289. "name": "train"
  290. }, {
  291. "supercategory": "none",
  292. "id": 6,
  293. "name": "motorcycle"
  294. }, {
  295. "supercategory": "none",
  296. "id": 7,
  297. "name": "bicycle"
  298. }, {
  299. "supercategory": "none",
  300. "id": 8,
  301. "name": "other person"
  302. }, {
  303. "supercategory": "none",
  304. "id": 9,
  305. "name": "trailer"
  306. }, {
  307. "supercategory": "none",
  308. "id": 10,
  309. "name": "other vehicle"
  310. }]
  311. attr_id_dict = {i['name']: i['id'] for i in attr_dict['categories']}
  312. # create bdd100kmot_vehicle training set in MOT format
  313. print('Loading and converting training set...')
  314. train_img_dir = os.path.join(args.img_dir, 'train')
  315. train_label_dir = os.path.join(args.label_dir, 'train')
  316. save_img_dir = os.path.join(args.save_path, 'images', 'train')
  317. save_label_dir = os.path.join(args.save_path, 'labels_with_ids', 'train')
  318. if not os.path.exists(save_img_dir): os.makedirs(save_img_dir)
  319. if not os.path.exists(save_label_dir): os.makedirs(save_label_dir)
  320. bdd2mot_tracking(train_img_dir, train_label_dir, save_img_dir,
  321. save_label_dir)
  322. # create bdd100kmot_vehicle validation set in MOT format
  323. print('Loading and converting validation set...')
  324. val_img_dir = os.path.join(args.img_dir, 'val')
  325. val_label_dir = os.path.join(args.label_dir, 'val')
  326. save_img_dir = os.path.join(args.save_path, 'images', 'val')
  327. save_label_dir = os.path.join(args.save_path, 'labels_with_ids', 'val')
  328. if not os.path.exists(save_img_dir): os.makedirs(save_img_dir)
  329. if not os.path.exists(save_label_dir): os.makedirs(save_label_dir)
  330. bdd2mot_tracking(val_img_dir, val_label_dir, save_img_dir, save_label_dir)
  331. # gen gt file
  332. dataPath = args.data_path
  333. phase = args.phase
  334. classes = args.classes.split(',')
  335. formatOrigin(osp.join(dataPath, 'bdd100kmot_vehicle'), phase)
  336. dataDir = osp.join(
  337. osp.join(dataPath, 'bdd100kmot_vehicle'), 'labels_with_ids',
  338. phase) + '/*'
  339. genMotGt(dataDir, classes=classes)
  340. copyImg(dataPath, osp.join(dataPath, 'bdd100kmot_vehicle'), phase)
  341. updateSeqInfo(osp.join(dataPath, 'bdd100kmot_vehicle'), phase)
  342. gen_image_list(osp.join(dataPath, 'bdd100kmot_vehicle'), phase)
  343. os.system(f'rm -r {dataPath}/bdd100kmot_vehicle/images/' + phase + '/*.jpg')