datasets.py 52 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169
  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Dataloaders and dataset utils
  4. """
  5. import glob
  6. import hashlib
  7. import json
  8. import math
  9. import os
  10. import random
  11. import shutil
  12. import time
  13. from itertools import repeat
  14. from multiprocessing.pool import Pool, ThreadPool
  15. from pathlib import Path
  16. from threading import Thread
  17. from urllib.parse import urlparse
  18. from zipfile import ZipFile
  19. import numpy as np
  20. import torch
  21. import torch.nn.functional as F
  22. import yaml
  23. from PIL import ExifTags, Image, ImageOps
  24. from torch.utils.data import DataLoader, Dataset, dataloader, distributed
  25. from tqdm.auto import tqdm
  26. from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
  27. from utils.general import (DATASETS_DIR, LOGGER, NUM_THREADS, check_dataset, check_requirements, check_yaml, clean_str,
  28. cv2, segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn)
  29. from utils.torch_utils import torch_distributed_zero_first
  # Module-wide settings
  HELP_URL = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
  # File suffixes (lower-case, without the dot) accepted as images
  IMG_FORMATS = ('bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp')
  # File suffixes (lower-case, without the dot) accepted as videos
  VID_FORMATS = ('asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv')
  BAR_FORMAT = '{l_bar}{bar:10}{r_bar}{bar:-10b}'  # tqdm bar format
  LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html

  # Find the EXIF tag id whose name is 'Orientation'; the loop variable is left
  # bound at module level so that exif_size() can use it as a lookup key.
  for orientation in ExifTags.TAGS:
      if ExifTags.TAGS[orientation] == 'Orientation':
          break
  def get_hash(paths):
      """Return a single MD5 hex digest for a list of paths (files or dirs).

      The digest is built from the summed on-disk size of every path that
      exists, followed by the concatenation of all path strings.
      """
      total_size = 0
      for entry in paths:
          if os.path.exists(entry):
              total_size += os.path.getsize(entry)
      digest = hashlib.md5(str(total_size).encode())  # start from the total size
      digest.update(''.join(paths).encode())  # then mix in the path names
      return digest.hexdigest()
  def exif_size(img):
      """Return the PIL image size as (width, height), corrected for EXIF rotation.

      Width and height are swapped when the EXIF orientation value is 6 or 8.
      Any failure while reading EXIF data (no EXIF, missing tag, ...) is
      ignored and the unmodified size is returned.
      """
      size = img.size  # (width, height)
      try:
          exif_tags = dict(img._getexif().items())
          if exif_tags[orientation] in (6, 8):  # 6: rotation 270, 8: rotation 90
              size = (size[1], size[0])
      except Exception:
          pass  # best effort: keep the raw size when EXIF cannot be read
      return size
  def exif_transpose(image):
      """
      Transpose a PIL image according to its EXIF Orientation tag, if it has one.
      Inplace version of https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py exif_transpose()

      :param image: The image to transpose.
      :return: The transposed image, or the input image when no transpose applies.
      """
      exif = image.getexif()
      orientation = exif.get(0x0112, 1)  # 0x0112 is the tag read here; absent tag -> 1
      if orientation > 1:
          transpose_by_orientation = {
              2: Image.FLIP_LEFT_RIGHT,
              3: Image.ROTATE_180,
              4: Image.FLIP_TOP_BOTTOM,
              5: Image.TRANSPOSE,
              6: Image.ROTATE_270,
              7: Image.TRANSVERSE,
              8: Image.ROTATE_90,
          }
          method = transpose_by_orientation.get(orientation)
          if method is not None:
              image = image.transpose(method)
              # Drop the tag and store the updated EXIF bytes on the transposed image
              del exif[0x0112]
              image.info["exif"] = exif.tobytes()
      return image
  def create_dataloader(path,
                        imgsz,
                        batch_size,
                        stride,
                        single_cls=False,
                        hyp=None,
                        augment=False,
                        cache=False,
                        pad=0.0,
                        rect=False,
                        rank=-1,
                        workers=8,
                        image_weights=False,
                        quad=False,
                        prefix='',
                        shuffle=False):
      """Build a LoadImagesAndLabels dataset and a loader over it.

      Args:
          path: Dataset location, forwarded to LoadImagesAndLabels.
          imgsz: Image size, forwarded to LoadImagesAndLabels.
          batch_size: Requested batch size; capped at the dataset length.
          stride: Model stride; converted to int before being forwarded.
          single_cls, hyp, augment, pad, rect, image_weights, prefix:
              Forwarded unchanged to LoadImagesAndLabels.
          cache: Forwarded as ``cache_images``.
          rank: -1 for non-distributed use; any other value enables a DistributedSampler.
          workers: Upper bound on the number of loader worker processes.
          quad: Use ``collate_fn4`` instead of ``collate_fn``.
          shuffle: Shuffle the data; forced to False when ``rect`` is set.

      Returns:
          Tuple ``(loader, dataset)``.
      """
      if rect and shuffle:
          LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False')
          shuffle = False
      with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
          dataset = LoadImagesAndLabels(
              path,
              imgsz,
              batch_size,
              augment=augment,  # augmentation
              hyp=hyp,  # hyperparameters
              rect=rect,  # rectangular batches
              cache_images=cache,
              single_cls=single_cls,
              stride=int(stride),
              pad=pad,
              image_weights=image_weights,
              prefix=prefix)

      batch_size = min(batch_size, len(dataset))
      nd = torch.cuda.device_count()  # number of CUDA devices
      cpus = os.cpu_count() or 1  # os.cpu_count() returns None when the count is undetermined
      nw = min(cpus // max(nd, 1), batch_size if batch_size > 1 else 0, workers)  # number of workers
      sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
      loader = DataLoader if image_weights else InfiniteDataLoader  # only DataLoader allows for attribute updates
      collate = LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn
      return loader(dataset,
                    batch_size=batch_size,
                    shuffle=shuffle and sampler is None,
                    num_workers=nw,
                    sampler=sampler,
                    pin_memory=True,
                    collate_fn=collate), dataset
  class InfiniteDataLoader(dataloader.DataLoader):
      """Dataloader that reuses workers.

      Takes the same arguments as the vanilla DataLoader. The batch sampler is
      wrapped in a _RepeatSampler and a single underlying iterator is created
      once and then consumed in epoch-sized slices.
      """

      def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)
          # object.__setattr__ bypasses any __setattr__ defined on the loader class itself
          object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
          self.iterator = super().__iter__()

      def __len__(self):
          # Length of the wrapped (original) batch sampler
          return len(self.batch_sampler.sampler)

      def __iter__(self):
          # Hand out exactly len(self) batches from the long-lived iterator
          remaining = len(self)
          while remaining > 0:
              yield next(self.iterator)
              remaining -= 1
  139. class _RepeatSampler:
  140. """ Sampler that repeats forever
  141. Args:
  142. sampler (Sampler)
  143. """
  144. def __init__(self, sampler):
  145. self.sampler = sampler
  146. def __iter__(self):
  147. while True:
  148. yield from iter(self.sampler)
  class LoadImages:
      """YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`.

      An iterator (used by detect.py) over image files followed by video
      frames. Each step returns
      ``(path, letterboxed CHW RGB image, original BGR image, capture, log string)``.
      """

      def __init__(self, path, img_size=640, stride=32, auto=True):
          source = str(Path(path).resolve())  # os-agnostic absolute path
          if '*' in source:
              files = sorted(glob.glob(source, recursive=True))  # glob pattern
          elif os.path.isdir(source):
              files = sorted(glob.glob(os.path.join(source, '*.*')))  # directory
          elif os.path.isfile(source):
              files = [source]  # single file
          else:
              raise Exception(f'ERROR: {source} does not exist')

          # Split the candidates into images and videos by file suffix
          images, videos = [], []
          for name in files:
              suffix = name.split('.')[-1].lower()
              if suffix in IMG_FORMATS:
                  images.append(name)
              if suffix in VID_FORMATS:
                  videos.append(name)
          n_images, n_videos = len(images), len(videos)

          self.img_size = img_size  # target inference size
          self.stride = stride
          self.auto = auto
          self.mode = 'image'
          self.files = images + videos  # images first, then videos
          self.nf = n_images + n_videos  # number of files
          self.video_flag = [False] * n_images + [True] * n_videos

          if any(videos):
              self.new_video(videos[0])  # open the first video up front
          else:
              self.cap = None
          assert self.nf > 0, f'No images or videos found in {source}. ' \
                              f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'

      def __iter__(self):
          self.count = 0
          return self

      def __next__(self):
          if self.count == self.nf:  # every file has been consumed
              raise StopIteration
          path = self.files[self.count]

          if not self.video_flag[self.count]:
              # Image file: one item per file
              self.count += 1
              img0 = cv2.imread(path)  # BGR
              assert img0 is not None, f'Image Not Found {path}'
              s = f'image {self.count}/{self.nf} {path}: '
          else:
              # Video file: one item per frame
              self.mode = 'video'
              ret_val, img0 = self.cap.read()
              while not ret_val:  # current video exhausted, move to the next one
                  self.count += 1
                  self.cap.release()
                  if self.count == self.nf:  # that was the last video
                      raise StopIteration
                  path = self.files[self.count]
                  self.new_video(path)
                  ret_val, img0 = self.cap.read()

              self.frame += 1
              s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '

          # Padded resize
          padded = letterbox(img0, self.img_size, stride=self.stride, auto=self.auto)[0]

          # HWC to CHW, BGR to RGB, then make the memory layout contiguous
          img = np.ascontiguousarray(padded.transpose((2, 0, 1))[::-1])

          return path, img, img0, self.cap, s

      def new_video(self, path):
          """Open ``path`` as the current video and reset the frame counter."""
          self.frame = 0  # frames read so far from this video
          self.cap = cv2.VideoCapture(path)
          self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))  # total frames reported by OpenCV

      def __len__(self):
          return self.nf  # number of files
  class LoadWebcam:  # for inference
      """YOLOv5 local webcam dataloader, i.e. `python detect.py --source 0`.

      Iterating returns ``('webcam.jpg', letterboxed CHW RGB frame, mirrored BGR frame, None, log string)``
      until the 'q' key is pressed in an OpenCV window.
      """

      def __init__(self, pipe='0', img_size=640, stride=32):
          """
          Args:
              pipe: Source as a string; a string of decimal digits is converted to an
                  integer before being passed to cv2.VideoCapture.
              img_size: Target size passed to letterbox.
              stride: Stride passed to letterbox.
          """
          self.img_size = img_size
          self.stride = stride
          # int() instead of eval(): same result for digit strings, no code execution on user input
          self.pipe = int(pipe) if pipe.isdecimal() else pipe
          self.cap = cv2.VideoCapture(self.pipe)  # video capture object
          self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3)  # set buffer size

      def __iter__(self):
          self.count = -1
          return self

      def __next__(self):
          self.count += 1
          if cv2.waitKey(1) == ord('q'):  # q to quit
              self.cap.release()
              cv2.destroyAllWindows()
              raise StopIteration

          # Read frame; validate before touching it so a failed read reports 'Camera Error'
          ret_val, img0 = self.cap.read()
          assert ret_val, f'Camera Error {self.pipe}'
          img0 = cv2.flip(img0, 1)  # flip left-right

          img_path = 'webcam.jpg'
          s = f'webcam {self.count}: '

          # Padded resize
          img = letterbox(img0, self.img_size, stride=self.stride)[0]

          # Convert
          img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
          img = np.ascontiguousarray(img)

          return img_path, img, img0, None, s

      def __len__(self):
          return 0
  class LoadStreams:
      """YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP streams`.

      Reads one or more streams, each in its own daemon thread, and on every
      iteration returns the most recent frame of every stream as one batch:
      ``(sources, BCHW RGB batch, list of original BGR frames, None, '')``.
      """

      def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True):
          """
          Args:
              sources: Path of a text file listing one source per line, or a single source string.
              img_size: Target size passed to letterbox.
              stride: Stride passed to letterbox.
              auto: Passed to letterbox as ``auto`` (only honoured when all streams share one shape).
          """
          self.mode = 'stream'
          self.img_size = img_size
          self.stride = stride

          if os.path.isfile(sources):
              with open(sources) as f:
                  # one stream per non-blank line
                  sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
          else:
              sources = [sources]

          n = len(sources)
          self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
          self.sources = [clean_str(x) for x in sources]  # clean source names for later
          self.auto = auto
          for i, s in enumerate(sources):  # index, source
              # Start thread to read frames from video stream
              st = f'{i + 1}/{n}: {s}... '
              if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'):  # if source is YouTube video
                  check_requirements(('pafy', 'youtube_dl==2020.12.2'))
                  import pafy
                  s = pafy.new(s).getbest(preftype="mp4").url  # YouTube URL
              # int() instead of eval(): same result for digit strings (i.e. '0' local webcam),
              # without executing user-supplied text
              s = int(s) if s.isdecimal() else s
              cap = cv2.VideoCapture(s)
              assert cap.isOpened(), f'{st}Failed to open {s}'
              w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))  # frame width
              h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))  # frame height
              fps = cap.get(cv2.CAP_PROP_FPS)  # warning: may return 0 or nan
              self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf')  # infinite stream fallback
              self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30  # 30 FPS fallback

              _, self.imgs[i] = cap.read()  # guarantee first frame
              # daemon=True: the reader thread ends when the main thread ends
              self.threads[i] = Thread(target=self.update, args=(i, cap, s), daemon=True)
              LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
              self.threads[i].start()
          LOGGER.info('')  # newline

          # check for common shapes: compare the letterboxed shape of every first frame
          s = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0].shape for x in self.imgs])
          self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
          if not self.rect:
              LOGGER.warning('WARNING: Stream shapes differ. For optimal performance supply similarly-shaped streams.')

      def update(self, i, cap, stream):
          """Read stream `i` frames in daemon thread, keeping only the latest frame in ``self.imgs[i]``."""
          n, f, read = 0, self.frames[i], 1  # frame number, frame count, inference every 'read' frame
          while cap.isOpened() and n < f:
              n += 1
              cap.grab()
              if n % read == 0:
                  success, im = cap.retrieve()
                  if success:
                      self.imgs[i] = im
                  else:
                      LOGGER.warning('WARNING: Video stream unresponsive, please check your IP camera connection.')
                      self.imgs[i] = np.zeros_like(self.imgs[i])  # blank frame of the same shape
                      cap.open(stream)  # re-open stream if signal was lost
              time.sleep(1 / self.fps[i])  # wait time

      def __iter__(self):
          self.count = -1
          return self

      def __next__(self):
          self.count += 1
          if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'):  # q to quit
              cv2.destroyAllWindows()
              raise StopIteration

          # Letterbox (resize and pad) a snapshot of the latest frames
          img0 = self.imgs.copy()
          img = [letterbox(x, self.img_size, stride=self.stride, auto=self.rect and self.auto)[0] for x in img0]

          # Stack
          img = np.stack(img, 0)

          # Convert
          img = img[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW
          img = np.ascontiguousarray(img)

          return self.sources, img, img0, None, ''

      def __len__(self):
          return len(self.sources)  # 1E12 frames = 32 streams at 30 FPS for 30 years
  def img2label_paths(img_paths):
      """Map image paths to label paths.

      In each path the last ``/images/`` directory component is replaced by
      ``/labels/`` (using the OS separator) and the file suffix by ``.txt``.
      Returns a list in the same order as ``img_paths``.
      """
      images_dir = f'{os.sep}images{os.sep}'
      labels_dir = f'{os.sep}labels{os.sep}'
      label_paths = []
      for img_path in img_paths:
          swapped = labels_dir.join(img_path.rsplit(images_dir, 1))
          stem = swapped.rsplit('.', 1)[0]
          label_paths.append(stem + '.txt')
      return label_paths
  329. # 自定义数据集
  330. class LoadImagesAndLabels(Dataset):
  331. # YOLOv5 train_loader/val_loader, loads images and labels for training and validation
  332. cache_version = 0.6 # dataset labels *.cache version
  def __init__(self,
               path,
               img_size=640,
               batch_size=16,
               augment=False,
               hyp=None,
               rect=False,
               image_weights=False,
               cache_images=False,
               single_cls=False,
               stride=32,
               pad=0.0,
               prefix=''):
      """Index the dataset images, load or build the label cache, and optionally cache images.

      Args:
          path: A directory, a text file listing image paths, or a list of either.
          img_size: Target image size.
          batch_size: Batch size, used to assign every image a batch index.
          augment: Enable augmentation (and mosaic, unless ``rect`` is active).
          hyp: Hyperparameter dict read by the augmentation code.
          rect: Rectangular batches; ignored when ``image_weights`` is set.
          image_weights: Stored on the instance; also disables ``rect``.
          cache_images: Falsy for no caching, ``'disk'`` to write *.npy files,
              any other truthy value to cache in RAM.
          single_cls: Set the class column of every label to 0.
          stride: Model stride, used to round rectangular batch shapes.
          pad: Padding (in stride units) added to rectangular batch shapes.
          prefix: Prefix for log and error messages.

      Raises:
          Exception: If ``path`` cannot be read or contains no images.
          AssertionError: If ``augment`` is set and no labels were found.
      """
      self.img_size = img_size
      self.augment = augment
      self.hyp = hyp
      self.image_weights = image_weights
      self.rect = False if image_weights else rect
      self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
      self.mosaic_border = [-img_size // 2, -img_size // 2]
      self.stride = stride
      self.path = path
      self.albumentations = Albumentations() if augment else None

      try:
          f = []  # image files
          for p in path if isinstance(path, list) else [path]:
              p = Path(p)  # os-agnostic
              if p.is_dir():  # dir: collect every file below it
                  f += glob.glob(str(p / '**' / '*.*'), recursive=True)
              elif p.is_file():  # file: one image path per line
                  with open(p) as t:
                      t = t.read().strip().splitlines()
                      parent = str(p.parent) + os.sep
                      # local to global path; count=1 so only the leading './' is rewritten
                      f += [x.replace('./', parent, 1) if x.startswith('./') else x for x in t]
              else:
                  raise Exception(f'{prefix}{p} does not exist')
          self.im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
          assert self.im_files, f'{prefix}No images found'
      except Exception as e:
          # chain the original exception so its traceback is not lost
          raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {HELP_URL}') from e

      # Check cache
      self.label_files = img2label_paths(self.im_files)  # labels
      cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
      try:
          # NOTE(review): allow_pickle=True unpickles the cache file; only load caches from trusted locations
          cache, exists = np.load(cache_path, allow_pickle=True).item(), True  # load dict
          assert cache['version'] == self.cache_version  # same version
          assert cache['hash'] == get_hash(self.label_files + self.im_files)  # same hash
      except Exception:
          cache, exists = self.cache_labels(cache_path, prefix), False  # missing or stale: rebuild

      # Display cache
      nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupt, total
      if exists and LOCAL_RANK in (-1, 0):
          d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupt"
          tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=BAR_FORMAT)  # display cache results
          if cache['msgs']:
              LOGGER.info('\n'.join(cache['msgs']))  # display warnings
      assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {HELP_URL}'

      # Read cache: drop the bookkeeping keys so only per-image entries remain
      for k in ('hash', 'version', 'msgs'):
          cache.pop(k)
      labels, shapes, segments = zip(*cache.values())
      self.labels = list(labels)
      self.shapes = np.array(shapes, dtype=np.float64)
      self.segments = list(segments)  # list, not tuple, so entries can be replaced below
      self.im_files = list(cache.keys())  # update
      self.label_files = img2label_paths(cache.keys())  # update
      n = len(shapes)  # number of images
      bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index (np.int was removed in NumPy 1.24)
      nb = bi[-1] + 1  # number of batches
      self.batch = bi  # batch index of image
      self.n = n
      self.indices = range(n)

      # Update labels
      include_class = []  # filter labels to include only these classes (optional)
      include_class_array = np.array(include_class).reshape(1, -1)
      for i, (label, segment) in enumerate(zip(self.labels, self.segments)):
          if include_class:
              j = (label[:, 0:1] == include_class_array).any(1)
              self.labels[i] = label[j]
              if segment:
                  # each entry is iterated per object in load_mosaic, so filter element-wise
                  self.segments[i] = [seg for seg, keep in zip(segment, j) if keep]
          if single_cls:  # single-class training, merge all classes into 0
              # presumably segments hold coordinates only (no class column), so only labels change
              self.labels[i][:, 0] = 0

      # Rectangular Training
      if self.rect:
          # Sort by aspect ratio
          s = self.shapes  # wh
          ar = s[:, 1] / s[:, 0]  # aspect ratio
          irect = ar.argsort()
          # Reorder every per-image list with the same permutation
          self.im_files = [self.im_files[i] for i in irect]
          self.label_files = [self.label_files[i] for i in irect]
          self.labels = [self.labels[i] for i in irect]
          self.segments = [self.segments[i] for i in irect]
          self.shapes = s[irect]  # wh
          ar = ar[irect]

          # Set training image shapes
          shapes = [[1, 1]] * nb
          for i in range(nb):
              ari = ar[bi == i]
              mini, maxi = ari.min(), ari.max()
              if maxi < 1:
                  shapes[i] = [maxi, 1]
              elif mini > 1:
                  shapes[i] = [1, 1 / mini]

          self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride

      # Cache images into RAM/disk for faster training (WARNING: large datasets may exceed system resources)
      self.ims = [None] * n
      self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
      if cache_images:
          gb = 0  # bytes of cached images (reported in GB)
          self.im_hw0, self.im_hw = [None] * n, [None] * n
          fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
          # imap is lazy, so the results must be consumed while the pool is still open
          with ThreadPool(NUM_THREADS) as pool:
              results = pool.imap(fcn, range(n))
              pbar = tqdm(enumerate(results), total=n, bar_format=BAR_FORMAT, disable=LOCAL_RANK > 0)
              for i, x in pbar:
                  if cache_images == 'disk':
                      gb += self.npy_files[i].stat().st_size
                  else:  # 'ram'
                      self.ims[i], self.im_hw0[i], self.im_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i)
                      gb += self.ims[i].nbytes
                  pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB {cache_images})'
              pbar.close()
  def cache_labels(self, path=Path('./labels.cache'), prefix=''):
      """Verify every image/label pair, collect the results and save them as a cache file.

      Args:
          path: Destination of the cache file.
          prefix: Prefix for log messages.

      Returns:
          Dict mapping each accepted image file to ``[labels, shape, segments]``,
          plus the keys 'hash', 'results', 'msgs' and 'version'.
      """
      cache = {}
      n_missing = n_found = n_empty = n_corrupt = 0
      msgs = []  # warning messages reported by the verifier
      desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
      with Pool(NUM_THREADS) as pool:
          jobs = zip(self.im_files, self.label_files, repeat(prefix))
          pbar = tqdm(pool.imap(verify_image_label, jobs),
                      desc=desc,
                      total=len(self.im_files),
                      bar_format=BAR_FORMAT)
          for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
              n_missing += nm_f
              n_found += nf_f
              n_empty += ne_f
              n_corrupt += nc_f
              if im_file:  # a falsy im_file means the entry is not cached
                  cache[im_file] = [lb, shape, segments]
              if msg:
                  msgs.append(msg)
              pbar.desc = f"{desc}{n_found} found, {n_missing} missing, {n_empty} empty, {n_corrupt} corrupt"

      pbar.close()
      if msgs:
          LOGGER.info('\n'.join(msgs))
      if n_found == 0:
          LOGGER.warning(f'{prefix}WARNING: No labels found in {path}. See {HELP_URL}')

      # Bookkeeping entries stored next to the per-image records
      cache['hash'] = get_hash(self.label_files + self.im_files)
      cache['results'] = n_found, n_missing, n_empty, n_corrupt, len(self.im_files)
      cache['msgs'] = msgs  # warnings
      cache['version'] = self.cache_version  # cache version
      try:
          np.save(path, cache)  # save cache for next time
          path.with_suffix('.cache.npy').rename(path)  # remove .npy suffix
          LOGGER.info(f'{prefix}New cache created: {path}')
      except Exception as e:
          LOGGER.warning(f'{prefix}WARNING: Cache directory {path.parent} is not writeable: {e}')  # not writeable
      return cache
  494. def __len__(self):
  495. return len(self.im_files)
  496. # def __iter__(self):
  497. # self.count = -1
  498. # print('ran dataset iter')
  499. # #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
  500. # return self
  def __getitem__(self, index):
      """Load, augment and format one sample.

      Returns:
          Tuple ``(image, labels_out, image_file, shapes)`` where ``image`` is a
          CHW RGB tensor, ``labels_out`` has shape ``(nl, 6)`` with columns 1..5
          holding the labels (column 0 is left at zero), and ``shapes`` is
          ``None`` for mosaic samples.
      """
      index = self.indices[index]  # linear, shuffled, or image_weights
      hyp = self.hyp

      use_mosaic = self.mosaic and random.random() < hyp['mosaic']
      if use_mosaic:
          # Pick the 4-image mosaic when the draw is above 5, the 9-image mosaic otherwise
          mosaic_loader = self.load_mosaic if random.randint(0, 10) > 5 else self.load_mosaic9
          img, labels = mosaic_loader(index)
          shapes = None

          # MixUp augmentation
          if random.random() < hyp['mixup']:
              img, labels = mixup(img, labels, *self.load_mosaic(random.randint(0, self.n - 1)))
      else:
          # Load image
          img, (h0, w0), (h, w) = self.load_image(index)

          # Letterbox to the batch shape (rect) or to the square image size
          target_shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size
          img, ratio, pad = letterbox(img, target_shape, auto=False, scaleup=self.augment)
          shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

          labels = self.labels[index].copy()
          if labels.size:  # normalized xywh to pixel xyxy format
              labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])

          if self.augment:
              perspective_keys = ('degrees', 'translate', 'scale', 'shear', 'perspective')
              img, labels = random_perspective(img, labels, **{key: hyp[key] for key in perspective_keys})

      nl = len(labels)  # number of labels
      if nl:
          # pixel xyxy back to normalized xywh
          labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1E-3)

      if self.augment:
          # Albumentations
          img, labels = self.albumentations(img, labels)
          nl = len(labels)  # update after albumentations

          # HSV color-space
          augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

          # Flip up-down: mirror the image and column 2 of the labels
          if random.random() < hyp['flipud']:
              img = np.flipud(img)
              if nl:
                  labels[:, 2] = 1 - labels[:, 2]

          # Flip left-right: mirror the image and column 1 of the labels
          if random.random() < hyp['fliplr']:
              img = np.fliplr(img)
              if nl:
                  labels[:, 1] = 1 - labels[:, 1]

      labels_out = torch.zeros((nl, 6))
      if nl:
          labels_out[:, 1:] = torch.from_numpy(labels)

      # HWC to CHW, BGR to RGB, then force a contiguous memory layout
      img = np.ascontiguousarray(img.transpose((2, 0, 1))[::-1])

      return torch.from_numpy(img), labels_out, self.im_files[index], shapes
  564. def load_image(self, i):
  565. # Loads 1 image from dataset index 'i', returns (im, original hw, resized hw)
  566. im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i],
  567. if im is None: # not cached in RAM
  568. if fn.exists(): # load npy
  569. im = np.load(fn)
  570. else: # read image
  571. im = cv2.imread(f) # BGR
  572. assert im is not None, f'Image Not Found {f}'
  573. h0, w0 = im.shape[:2] # orig hw
  574. r = self.img_size / max(h0, w0) # ratio
  575. if r != 1: # if sizes are not equal
  576. im = cv2.resize(im, (int(w0 * r), int(h0 * r)),
  577. interpolation=cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA)
  578. return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
  579. else:
  580. return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized
  581. def cache_images_to_disk(self, i):
  582. # Saves an image as an *.npy file for faster loading
  583. f = self.npy_files[i]
  584. if not f.exists():
  585. np.save(f.as_posix(), cv2.imread(self.im_files[i]))
  def load_mosaic(self, index):
      """YOLOv5 4-mosaic loader: combine image ``index`` and 3 random images into one mosaic.

      The four tiles are pasted around a random centre on a canvas twice the
      image size; labels and segments are shifted to canvas coordinates,
      clipped, and the result goes through copy_paste and random_perspective.

      Returns:
          Tuple ``(img4, labels4)`` as returned by random_perspective.
      """
      s = self.img_size
      yc, xc = (int(random.uniform(-b, 2 * s + b)) for b in self.mosaic_border)  # mosaic center x, y
      tile_ids = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
      random.shuffle(tile_ids)

      labels4, segments4 = [], []
      for slot, tile_idx in enumerate(tile_ids):
          # Load image
          tile, _, (h, w) = self.load_image(tile_idx)

          # Destination rectangle (…a, on the canvas) and source rectangle (…b, on the tile)
          if slot == 0:  # top left
              img4 = np.full((s * 2, s * 2, tile.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
              x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc
              x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h
          elif slot == 1:  # top right
              x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
              x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
          elif slot == 2:  # bottom left
              x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
              x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
          elif slot == 3:  # bottom right
              x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
              x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

          img4[y1a:y2a, x1a:x2a] = tile[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
          # Offset of the tile origin on the canvas, used to shift its labels
          padw = x1a - x1b
          padh = y1a - y1b

          # Labels
          tile_labels = self.labels[tile_idx].copy()
          tile_segments = self.segments[tile_idx].copy()
          if tile_labels.size:
              tile_labels[:, 1:] = xywhn2xyxy(tile_labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy
              tile_segments = [xyn2xy(seg, w, h, padw, padh) for seg in tile_segments]
          labels4.append(tile_labels)
          segments4.extend(tile_segments)

      # Concat/clip labels so that every coordinate lies on the canvas
      labels4 = np.concatenate(labels4, 0)
      for coords in (labels4[:, 1:], *segments4):
          np.clip(coords, 0, 2 * s, out=coords)  # clip when using random_perspective()

      # Augment
      img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=self.hyp['copy_paste'])
      hyp = self.hyp
      img4, labels4 = random_perspective(img4,
                                         labels4,
                                         segments4,
                                         degrees=hyp['degrees'],
                                         translate=hyp['translate'],
                                         scale=hyp['scale'],
                                         shear=hyp['shear'],
                                         perspective=hyp['perspective'],
                                         border=self.mosaic_border)  # border to remove

      return img4, labels4
  def load_mosaic9(self, index):
      """Assemble a 9-tile mosaic from ``index`` plus 8 randomly drawn images.

      Tiles are pasted around a central tile on a (3s, 3s) canvas filled with
      the value 114, where ``s = self.img_size``. A random (2s, 2s) window is
      then cropped out, labels and segments are shifted into the window's
      frame and clipped to it, and the result is passed through
      ``random_perspective``.

      Args:
          index: Dataset index of the image that must appear in the mosaic.

      Returns:
          The ``(image, labels)`` pair produced by ``random_perspective``.
      """
      s = self.img_size  # side length of a single tile slot
      labels9, segments9 = [], []
      indices = [index] + random.choices(self.indices, k=8)  # 8 additional image indices
      random.shuffle(indices)
      prev_h, prev_w = -1, -1  # size of the tile placed in the previous iteration

      for slot, img_idx in enumerate(indices):
          img, _, (h, w) = self.load_image(img_idx)

          if slot == 0:
              # The first tile fixes the channel count of the canvas and the
              # reference size (h0, w0) that later tiles are aligned against.
              img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)
              h0, w0 = h, w

          # Candidate boxes (xmin, ymin, xmax, ymax) on the canvas, one per slot.
          layout = (
              (s, s, s + w, s + h),  # 0: center
              (s, s - h, s + w, s),  # 1: top
              (s + prev_w, s - h, s + prev_w + w, s),  # 2: top right
              (s + w0, s, s + w0 + w, s + h),  # 3: right
              (s + w0, s + prev_h, s + w0 + w, s + prev_h + h),  # 4: bottom right
              (s + w0 - w, s + h0, s + w0, s + h0 + h),  # 5: bottom
              (s + w0 - prev_w - w, s + h0, s + w0 - prev_w, s + h0 + h),  # 6: bottom left
              (s - w, s + h0 - h, s, s + h0),  # 7: left
              (s - w, s + h0 - prev_h - h, s, s + h0 - prev_h),  # 8: top left
          )
          box = layout[slot]
          padx, pady = box[0], box[1]  # unclipped top-left corner, used as the label offset
          x1, y1, x2, y2 = [max(v, 0) for v in box]  # clip the box to the canvas origin

          # Labels: normalized xywh -> pixel xyxy in canvas coordinates
          labels = self.labels[img_idx].copy()
          segments = self.segments[img_idx].copy()
          if labels.size:
              labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady)
              segments = [xyn2xy(seg, w, h, padx, pady) for seg in segments]
          labels9.append(labels)
          segments9.extend(segments)

          # Image: drop the part of the tile that fell off the top/left edge
          img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:]
          prev_h, prev_w = h, w

      # Random crop window: one offset is drawn per entry of self.mosaic_border
      offsets = [int(random.uniform(0, s)) for _ in self.mosaic_border]
      yc, xc = offsets
      img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]

      # Move labels/segments into the window's frame, then clip them to it
      labels9 = np.concatenate(labels9, 0)
      labels9[:, [1, 3]] -= xc
      labels9[:, [2, 4]] -= yc
      shift = np.array([xc, yc])
      segments9 = [seg - shift for seg in segments9]
      for arr in (labels9[:, 1:], *segments9):
          np.clip(arr, 0, 2 * s, out=arr)  # clip when using random_perspective()

      # Augment
      hyp = self.hyp
      img9, labels9 = random_perspective(img9,
                                         labels9,
                                         segments9,
                                         degrees=hyp['degrees'],
                                         translate=hyp['translate'],
                                         scale=hyp['scale'],
                                         shear=hyp['shear'],
                                         perspective=hyp['perspective'],
                                         border=self.mosaic_border)  # border to remove
      return img9, labels9
  def load_mosaic16(self, index):
      """Assemble a 16-tile (4x4) mosaic from ``index`` plus 15 random images.

      Each tile is pasted at the top-left corner of its own (s, s) cell on a
      (4s, 4s) canvas filled with the value 114, where ``s = self.img_size``.
      A random (2s, 2s) window is then cropped out, labels and segments are
      shifted into the window's frame and clipped to it, and the result is
      passed through ``random_perspective``.

      Fixes relative to the previous implementation:
        * every one of the 16 tiles now has its own position (previously only
          9 positions were defined, so tiles 9-15 reused the box of tile 8);
        * the crop offset is applied to ``segments16`` (it used to be stored
          in an unused ``segments9`` variable and silently dropped).

      Args:
          index: Dataset index of the image that must appear in the mosaic.

      Returns:
          The ``(image, labels)`` pair produced by ``random_perspective``.
      """
      grid = 4  # tiles per row / column
      labels16, segments16 = [], []
      s = self.img_size
      indices = [index] + random.choices(self.indices, k=grid * grid - 1)  # 15 additional image indices
      random.shuffle(indices)

      for i, img_idx in enumerate(indices):
          # Load image
          img, _, (h, w) = self.load_image(img_idx)

          # Place img in img16
          if i == 0:
              img16 = np.full((s * grid, s * grid, img.shape[2]), 114, dtype=np.uint8)  # base image with 16 cells
          row, col = divmod(i, grid)
          padx, pady = col * s, row * s  # top-left corner of this tile's cell

          # NOTE(review): assumes load_image() returns tiles no larger than img_size per side -
          # TODO confirm. Anything larger is cropped so it cannot spill into a neighbouring cell.
          hc, wc = min(h, s), min(w, s)
          img16[pady:pady + hc, padx:padx + wc] = img[:hc, :wc]

          # Labels: normalized xywh -> pixel xyxy in canvas coordinates
          labels, segments = self.labels[img_idx].copy(), self.segments[img_idx].copy()
          if labels.size:
              labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady)
              segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
          labels16.append(labels)
          segments16.extend(segments)

      # Offset: the (2s, 2s) window can start anywhere in [0, 2s] on a 4s canvas
      max_offset = (grid - 2) * s
      yc, xc = (int(random.uniform(0, max_offset)) for _ in self.mosaic_border)
      img16 = img16[yc:yc + 2 * s, xc:xc + 2 * s]

      # Concat/clip labels
      labels16 = np.concatenate(labels16, 0)
      labels16[:, [1, 3]] -= xc
      labels16[:, [2, 4]] -= yc
      shift = np.array([xc, yc])
      segments16 = [x - shift for x in segments16]
      for x in (labels16[:, 1:], *segments16):
          np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()

      # Augment
      img16, labels16 = random_perspective(img16,
                                           labels16,
                                           segments16,
                                           degrees=self.hyp['degrees'],
                                           translate=self.hyp['translate'],
                                           scale=self.hyp['scale'],
                                           shear=self.hyp['shear'],
                                           perspective=self.hyp['perspective'],
                                           border=self.mosaic_border)  # border to remove
      return img16, labels16
  771. @staticmethod
  772. def collate_fn(batch):
  773. im, label, path, shapes = zip(*batch) # transposed
  774. for i, lb in enumerate(label):
  775. lb[:, 0] = i # add target image index for build_targets()
  776. return torch.stack(im, 0), torch.cat(label, 0), path, shapes
  777. @staticmethod
  778. def collate_fn4(batch):
  779. img, label, path, shapes = zip(*batch) # transposed
  780. n = len(shapes) // 4
  781. im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
  782. ho = torch.tensor([[0.0, 0, 0, 1, 0, 0]])
  783. wo = torch.tensor([[0.0, 0, 1, 0, 0, 0]])
  784. s = torch.tensor([[1, 1, 0.5, 0.5, 0.5, 0.5]]) # scale
  785. for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW
  786. i *= 4
  787. if random.random() < 0.5:
  788. im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear',
  789. align_corners=False)[0].type(img[i].type())
  790. lb = label[i]
  791. else:
  792. im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
  793. lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
  794. im4.append(im)
  795. label4.append(lb)
  796. for i, lb in enumerate(label4):
  797. lb[:, 0] = i # add target image index for build_targets()
  798. return torch.stack(im4, 0), torch.cat(label4, 0), path4, shapes4
  799. # Ancillary functions --------------------------------------------------------------------------------------------------
  def create_folder(path='./new'):
      """Create an empty directory at ``path``, deleting whatever tree was there before."""
      stale = os.path.exists(path)
      if stale:
          shutil.rmtree(path)  # wipe the previous output folder and its contents
      os.makedirs(path)  # recreate it empty (parents are created as needed)
  def flatten_recursive(path=DATASETS_DIR / 'coco128'):
      """Copy every ``*.*`` entry found under ``path`` into a sibling ``<path>_flat`` folder.

      The destination folder is recreated from scratch via create_folder().
      Entries are copied under their bare file name, so files that share a
      name in different sub-directories overwrite one another.
      """
      flat_dir = Path(str(path) + '_flat')
      create_folder(flat_dir)
      pattern = str(Path(path)) + '/**/*.*'
      matches = glob.glob(pattern, recursive=True)
      for src in tqdm(matches):
          dst = flat_dir / Path(src).name
          shutil.copyfile(src, dst)
  def extract_boxes(path=DATASETS_DIR / 'coco128'):  # from utils.datasets import *; extract_boxes()
      """Convert a detection dataset into a classification dataset, one directory per class.

      Every labelled box of every image under ``path`` is padded (x1.2 + 3 px),
      clipped to the image and written to
      ``path/classifier/<class>/<path.stem>_<image stem>_<box index>.jpg``.
      An existing ``path/classifier`` directory is removed first.

      Args:
          path: Dataset directory that is searched recursively for images.

      Raises:
          AssertionError: If OpenCV fails to write a crop.
      """
      path = Path(path)  # images dir
      classifier_dir = path / 'classifier'
      if classifier_dir.is_dir():
          shutil.rmtree(classifier_dir)  # remove existing
      files = list(path.rglob('*.*'))
      for im_file in tqdm(files, total=len(files)):
          # Case-insensitive suffix match, consistent with autosplit()
          if im_file.suffix[1:].lower() not in IMG_FORMATS:
              continue

          # Image
          # NOTE(review): the channel-reversed array is what gets written below, while cv2.imwrite
          # conventionally expects BGR input - kept as in the original, confirm this is intended.
          im = cv2.imread(str(im_file))[..., ::-1]  # BGR to RGB
          h, w = im.shape[:2]

          # Labels
          lb_file = Path(img2label_paths([str(im_file)])[0])
          if not lb_file.exists():
              continue
          with open(lb_file) as f:
              lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32)  # labels

          for j, x in enumerate(lb):
              c = int(x[0])  # class
              crop_file = classifier_dir / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg'  # new filename
              crop_file.parent.mkdir(parents=True, exist_ok=True)

              b = x[1:] * [w, h, w, h]  # box, normalized xywh -> pixels
              # b[2:] = b[2:].max()  # rectangle to square
              b[2:] = b[2:] * 1.2 + 3  # pad
              # Builtin int replaces the removed np.int alias (same dtype)
              b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(int)

              b[[0, 2]] = np.clip(b[[0, 2]], 0, w)  # clip boxes outside of image
              b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
              # The write is kept outside an assert so it still happens under `python -O`
              if not cv2.imwrite(str(crop_file), im[b[1]:b[3], b[0]:b[2]]):
                  raise AssertionError(f'box failure in {crop_file}')
  def autosplit(path=DATASETS_DIR / 'coco128/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
      """ Randomly assign the images under ``path`` to train/val/test and write the split lists
      Usage: from utils.datasets import *; autosplit()
      Arguments
          path:            Path to images directory
          weights:         Train, val, test weights (list, tuple)
          annotated_only:  Only use images with an annotated txt file

      Output files are path.parent/autosplit_{train,val,test}.txt; existing ones are deleted first.
      The global ``random`` generator is re-seeded with 0 so the assignment is reproducible.
      """
      root = Path(path)  # images dir
      out_dir = root.parent
      images = sorted([p for p in root.rglob('*.*') if p.suffix[1:].lower() in IMG_FORMATS])  # image files only
      total = len(images)

      random.seed(0)  # for reproducibility
      assignment = random.choices([0, 1, 2], weights=weights, k=total)  # split id per image

      split_files = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt']  # 3 txt files
      for name in split_files:
          (out_dir / name).unlink(missing_ok=True)  # remove existing

      print(f'Autosplitting images from {root}' + ', using *.txt labeled images only' * annotated_only)
      for split_id, image in tqdm(zip(assignment, images), total=total):
          if annotated_only and not Path(img2label_paths([str(image)])[0]).exists():
              continue  # label required but missing
          entry = './' + image.relative_to(out_dir).as_posix() + '\n'
          with open(out_dir / split_files[split_id], 'a') as f:
              f.write(entry)  # add image to txt file
  def verify_image_label(args):
      """Verify one image-label pair.

      Args:
          args: ``(im_file, lb_file, prefix)`` - image path, label path and a
              prefix prepended to warning messages.

      Returns:
          On success the tuple
          ``(im_file, lb, shape, segments, nm, nf, ne, nc, msg)`` where ``lb``
          is a float32 array with 5 columns (class, x, y, w, h), ``segments``
          is a list of per-object point arrays (empty for box labels) and
          ``nm``/``nf``/``ne``/``nc`` are 0/1 flags for label missing / found /
          empty / corrupt.
          On any exception the list
          ``[None, None, None, None, nm, nf, ne, 1, msg]`` with the error
          described in ``msg``.
      """
      im_file, lb_file, prefix = args
      nm, nf, ne, nc, msg, segments = 0, 0, 0, 0, '', []  # number (missing, found, empty, corrupt), message, segments
      try:
          # Verify images
          im = Image.open(im_file)
          im.verify()  # PIL verify
          shape = exif_size(im)  # image size
          assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
          assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
          if im.format.lower() in ('jpg', 'jpeg'):
              with open(im_file, 'rb') as f:
                  f.seek(-2, 2)  # last two bytes of the file
                  if f.read() != b'\xff\xd9':  # corrupt JPEG (end-of-image marker missing)
                      ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
                      msg = f'{prefix}WARNING: {im_file}: corrupt JPEG restored and saved'

          # Verify labels
          if os.path.isfile(lb_file):
              nf = 1  # label found
              with open(lb_file) as f:
                  lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
                  if any(len(x) > 6 for x in lb):  # is segment
                      classes = np.array([x[0] for x in lb], dtype=np.float32)
                      segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb]  # (cls, xy1...)
                      lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
                  lb = np.array(lb, dtype=np.float32)
              nl = len(lb)
              if nl:
                  assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
                  assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
                  assert (lb[:, 1:] <= 1).all(), f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}'
                  _, i = np.unique(lb, axis=0, return_index=True)
                  if len(i) < nl:  # duplicate row check
                      lb = lb[i]  # remove duplicates
                      if segments:
                          # segments is a Python list, so it cannot be fancy-indexed with the
                          # index array; pick the surviving entries one by one instead.
                          segments = [segments[x] for x in i]
                      msg = f'{prefix}WARNING: {im_file}: {nl - len(i)} duplicate labels removed'
              else:
                  ne = 1  # label empty
                  lb = np.zeros((0, 5), dtype=np.float32)
          else:
              nm = 1  # label missing
              lb = np.zeros((0, 5), dtype=np.float32)
          return im_file, lb, shape, segments, nm, nf, ne, nc, msg
      except Exception as e:
          # Any failure above (including the asserts) marks the pair as corrupt
          nc = 1
          msg = f'{prefix}WARNING: {im_file}: ignoring corrupt image/label: {e}'
          return [None, None, None, None, nm, nf, ne, nc, msg]
  def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False, profile=False, hub=False):
      """ Return dataset statistics dictionary with images and instances counts per split per class
      To run in parent directory: export PYTHONPATH="$PWD/yolov5"
      Usage1: from utils.datasets import *; dataset_stats('coco128.yaml', autodownload=True)
      Usage2: from utils.datasets import *; dataset_stats('path/to/coco128_with_yaml.zip')
      Arguments
          path:           Path to data.yaml or data.zip (with data.yaml inside data.zip)
          autodownload:   Attempt to download dataset if not found locally
          verbose:        Print stats dictionary
          profile:        Time saving and re-loading the stats as *.npy and *.json
          hub:            Write reduced-size copies of all images to '<dataset path>-hub/images'
                          and save the stats to '<dataset path>-hub/stats.json'
      """

      def round_labels(labels):
          # Update labels to integer class and 4 decimal place floats
          return [[int(c), *(round(x, 4) for x in points)] for c, *points in labels]

      def unzip(path):
          # Unzip data.zip TODO: CONSTRAINT: path/to/abc.zip MUST unzip to 'path/to/abc/'
          if str(path).endswith('.zip'):  # path is data.zip
              assert Path(path).is_file(), f'Error unzipping {path}, file not found'
              with ZipFile(path) as archive:  # closes the archive handle once extracted
                  archive.extractall(path=path.parent)  # unzip
              data_root = path.with_suffix('')  # dataset directory == zip name
              return True, str(data_root), next(data_root.rglob('*.yaml'))  # zipped, data_dir, yaml_path
          else:  # path is data.yaml
              return False, None, path

      def hub_ops(f, max_dim=1920):
          # HUB ops for 1 image 'f': resize and save at reduced quality in /dataset-hub for web/app viewing
          # im_dir is read from the enclosing scope; it is assigned before this helper is first called
          f_new = im_dir / Path(f).name  # dataset-hub image filename
          try:  # use PIL
              im = Image.open(f)
              r = max_dim / max(im.height, im.width)  # ratio
              if r < 1.0:  # image too large
                  im = im.resize((int(im.width * r), int(im.height * r)))
              im.save(f_new, 'JPEG', quality=75, optimize=True)  # save
          except Exception as e:  # use OpenCV
              print(f'WARNING: HUB ops PIL failure {f}: {e}')
              im = cv2.imread(f)
              im_height, im_width = im.shape[:2]
              r = max_dim / max(im_height, im_width)  # ratio
              if r < 1.0:  # image too large
                  im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
              cv2.imwrite(str(f_new), im)

      zipped, data_dir, yaml_path = unzip(Path(path))
      with open(check_yaml(yaml_path), errors='ignore') as f:
          data = yaml.safe_load(f)  # data dict
          if zipped:
              data['path'] = data_dir  # TODO: should this be dir.resolve()?
      check_dataset(data, autodownload)  # download dataset if missing
      hub_dir = Path(data['path'] + ('-hub' if hub else ''))
      stats = {'nc': data['nc'], 'names': data['names']}  # statistics dictionary
      for split in 'train', 'val', 'test':
          if data.get(split) is None:
              stats[split] = None  # i.e. no test set
              continue
          x = []
          dataset = LoadImagesAndLabels(data[split])  # load dataset
          for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'):
              x.append(np.bincount(label[:, 0].astype(int), minlength=data['nc']))
          x = np.array(x)  # one row per image, one column per class
          stats[split] = {
              'instance_stats': {
                  'total': int(x.sum()),
                  'per_class': x.sum(0).tolist()},
              'image_stats': {
                  'total': dataset.n,
                  'unlabelled': int(np.all(x == 0, 1).sum()),
                  'per_class': (x > 0).sum(0).tolist()},
              'labels': [{
                  str(Path(k).name): round_labels(v.tolist())} for k, v in zip(dataset.im_files, dataset.labels)]}

          if hub:
              im_dir = hub_dir / 'images'
              im_dir.mkdir(parents=True, exist_ok=True)
              # The pool is shut down once every image of this split has been processed
              with ThreadPool(NUM_THREADS) as pool:
                  for _ in tqdm(pool.imap(hub_ops, dataset.im_files), total=dataset.n, desc='HUB Ops'):
                      pass

      # Profile
      stats_path = hub_dir / 'stats.json'
      if profile:
          file = stats_path.with_suffix('.npy')
          t1 = time.time()
          np.save(file, stats)
          t2 = time.time()
          x = np.load(file, allow_pickle=True)  # loaded only to time the read
          print(f'stats.npy times: {time.time() - t2:.3f}s read, {t2 - t1:.3f}s write')

          file = stats_path.with_suffix('.json')
          t1 = time.time()
          with open(file, 'w') as f:
              json.dump(stats, f)  # save stats *.json
          t2 = time.time()
          with open(file) as f:
              x = json.load(f)  # load stats dict back, only to time the read
          print(f'stats.json times: {time.time() - t2:.3f}s read, {t2 - t1:.3f}s write')

      # Save, print and return
      if hub:
          print(f'Saving {stats_path.resolve()}...')
          with open(stats_path, 'w') as f:
              json.dump(stats, f)  # save stats.json
      if verbose:
          print(json.dumps(stats, indent=2, sort_keys=False))
      return stats