dataloaders.py 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088
  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Dataloaders and dataset utils
  4. """
  5. import glob
  6. import hashlib
  7. import json
  8. import math
  9. import os
  10. import random
  11. import shutil
  12. import time
  13. from itertools import repeat
  14. from multiprocessing.pool import Pool, ThreadPool
  15. from pathlib import Path
  16. from threading import Thread
  17. from urllib.parse import urlparse
  18. from zipfile import ZipFile
  19. import numpy as np
  20. import torch
  21. import torch.nn.functional as F
  22. import yaml
  23. from PIL import ExifTags, Image, ImageOps
  24. from torch.utils.data import DataLoader, Dataset, dataloader, distributed
  25. from tqdm import tqdm
  26. from dependence.yolov5.utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, \
  27. random_perspective
  28. from dependence.yolov5.utils.general import (DATASETS_DIR, LOGGER, NUM_THREADS, check_dataset, check_requirements, check_yaml, clean_str,
  29. cv2, segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn)
  30. from dependence.yolov5.utils.torch_utils import torch_distributed_zero_first
  31. # Parameters
  32. HELP_URL = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
  33. IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp' # include image suffixes
  34. VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes
  35. BAR_FORMAT = '{l_bar}{bar:10}{r_bar}{bar:-10b}' # tqdm bar format
  36. LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
  37. # Get orientation exif tag
  38. for orientation in ExifTags.TAGS.keys():
  39. if ExifTags.TAGS[orientation] == 'Orientation':
  40. break
  41. def get_hash(paths):
  42. # Returns a single hash value of a list of paths (files or dirs)
  43. size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
  44. h = hashlib.md5(str(size).encode()) # hash sizes
  45. h.update(''.join(paths).encode()) # hash paths
  46. return h.hexdigest() # return hash
  47. def exif_size(img):
  48. # Returns exif-corrected PIL size
  49. s = img.size # (width, height)
  50. try:
  51. rotation = dict(img._getexif().items())[orientation]
  52. if rotation in [6, 8]: # rotation 270 or 90
  53. s = (s[1], s[0])
  54. except Exception:
  55. pass
  56. return s
  57. def exif_transpose(image):
  58. """
  59. Transpose a PIL image accordingly if it has an EXIF Orientation tag.
  60. Inplace version of https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py exif_transpose()
  61. :param image: The image to transpose.
  62. :return: An image.
  63. """
  64. exif = image.getexif()
  65. orientation = exif.get(0x0112, 1) # default 1
  66. if orientation > 1:
  67. method = {
  68. 2: Image.FLIP_LEFT_RIGHT,
  69. 3: Image.ROTATE_180,
  70. 4: Image.FLIP_TOP_BOTTOM,
  71. 5: Image.TRANSPOSE,
  72. 6: Image.ROTATE_270,
  73. 7: Image.TRANSVERSE,
  74. 8: Image.ROTATE_90,}.get(orientation)
  75. if method is not None:
  76. image = image.transpose(method)
  77. del exif[0x0112]
  78. image.info["exif"] = exif.tobytes()
  79. return image
  80. def create_dataloader(path,
  81. imgsz,
  82. batch_size,
  83. stride,
  84. single_cls=False,
  85. hyp=None,
  86. augment=False,
  87. cache=False,
  88. pad=0.0,
  89. rect=False,
  90. rank=-1,
  91. workers=8,
  92. image_weights=False,
  93. quad=False,
  94. prefix='',
  95. shuffle=False):
  96. if rect and shuffle:
  97. LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False')
  98. shuffle = False
  99. with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
  100. dataset = LoadImagesAndLabels(
  101. path,
  102. imgsz,
  103. batch_size,
  104. augment=augment, # augmentation
  105. hyp=hyp, # hyperparameters
  106. rect=rect, # rectangular batches
  107. cache_images=cache,
  108. single_cls=single_cls,
  109. stride=int(stride),
  110. pad=pad,
  111. image_weights=image_weights,
  112. prefix=prefix)
  113. batch_size = min(batch_size, len(dataset))
  114. nd = torch.cuda.device_count() # number of CUDA devices
  115. nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
  116. sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
  117. loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
  118. return loader(dataset,
  119. batch_size=batch_size,
  120. shuffle=shuffle and sampler is None,
  121. num_workers=nw,
  122. sampler=sampler,
  123. pin_memory=True,
  124. collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn), dataset
  125. class InfiniteDataLoader(dataloader.DataLoader):
  126. """ Dataloader that reuses workers
  127. Uses same syntax as vanilla DataLoader
  128. """
  129. def __init__(self, *args, **kwargs):
  130. super().__init__(*args, **kwargs)
  131. object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
  132. self.iterator = super().__iter__()
  133. def __len__(self):
  134. return len(self.batch_sampler.sampler)
  135. def __iter__(self):
  136. for _ in range(len(self)):
  137. yield next(self.iterator)
  138. class _RepeatSampler:
  139. """ Sampler that repeats forever
  140. Args:
  141. sampler (Sampler)
  142. """
  143. def __init__(self, sampler):
  144. self.sampler = sampler
  145. def __iter__(self):
  146. while True:
  147. yield from iter(self.sampler)
  148. class LoadImages:
  149. # YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
  150. def __init__(self, path, img_size=640, stride=32, auto=True):
  151. p = str(Path(path).resolve()) # os-agnostic absolute path
  152. if '*' in p:
  153. files = sorted(glob.glob(p, recursive=True)) # glob
  154. elif os.path.isdir(p):
  155. files = sorted(glob.glob(os.path.join(p, '*.*'))) # dir
  156. elif os.path.isfile(p):
  157. files = [p] # files
  158. else:
  159. raise Exception(f'ERROR: {p} does not exist')
  160. images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
  161. videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
  162. ni, nv = len(images), len(videos)
  163. self.img_size = img_size
  164. self.stride = stride
  165. self.files = images + videos
  166. self.nf = ni + nv # number of files
  167. self.video_flag = [False] * ni + [True] * nv
  168. self.mode = 'image'
  169. self.auto = auto
  170. if any(videos):
  171. self.new_video(videos[0]) # new video
  172. else:
  173. self.cap = None
  174. assert self.nf > 0, f'No images or videos found in {p}. ' \
  175. f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'
  176. def __iter__(self):
  177. self.count = 0
  178. return self
  179. def __next__(self):
  180. if self.count == self.nf:
  181. raise StopIteration
  182. path = self.files[self.count]
  183. if self.video_flag[self.count]:
  184. # Read video
  185. self.mode = 'video'
  186. ret_val, img0 = self.cap.read()
  187. while not ret_val:
  188. self.count += 1
  189. self.cap.release()
  190. if self.count == self.nf: # last video
  191. raise StopIteration
  192. path = self.files[self.count]
  193. self.new_video(path)
  194. ret_val, img0 = self.cap.read()
  195. self.frame += 1
  196. s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
  197. else:
  198. # Read image
  199. self.count += 1
  200. img0 = cv2.imread(path) # BGR
  201. assert img0 is not None, f'Image Not Found {path}'
  202. s = f'image {self.count}/{self.nf} {path}: '
  203. # Padded resize
  204. img = letterbox(img0, self.img_size, stride=self.stride, auto=self.auto)[0]
  205. # Convert
  206. img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
  207. img = np.ascontiguousarray(img)
  208. return path, img, img0, self.cap, s
  209. def new_video(self, path):
  210. self.frame = 0
  211. self.cap = cv2.VideoCapture(path)
  212. self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
  213. def __len__(self):
  214. return self.nf # number of files
  215. class LoadWebcam: # for inference
  216. # YOLOv5 local webcam dataloader, i.e. `python detect.py --source 0`
  217. def __init__(self, pipe='0', img_size=640, stride=32):
  218. self.img_size = img_size
  219. self.stride = stride
  220. self.pipe = eval(pipe) if pipe.isnumeric() else pipe
  221. self.cap = cv2.VideoCapture(self.pipe) # video capture object
  222. self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3) # set buffer size
  223. def __iter__(self):
  224. self.count = -1
  225. return self
  226. def __next__(self):
  227. self.count += 1
  228. if cv2.waitKey(1) == ord('q'): # q to quit
  229. self.cap.release()
  230. cv2.destroyAllWindows()
  231. raise StopIteration
  232. # Read frame
  233. ret_val, img0 = self.cap.read()
  234. img0 = cv2.flip(img0, 1) # flip left-right
  235. # Print
  236. assert ret_val, f'Camera Error {self.pipe}'
  237. img_path = 'webcam.jpg'
  238. s = f'webcam {self.count}: '
  239. # Padded resize
  240. img = letterbox(img0, self.img_size, stride=self.stride)[0]
  241. # Convert
  242. img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
  243. img = np.ascontiguousarray(img)
  244. return img_path, img, img0, None, s
  245. def __len__(self):
  246. return 0
  247. class LoadStreams:
  248. # YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
  249. def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True):
  250. self.mode = 'stream'
  251. self.img_size = img_size
  252. self.stride = stride
  253. if os.path.isfile(sources):
  254. with open(sources) as f:
  255. sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
  256. else:
  257. sources = [sources]
  258. n = len(sources)
  259. self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
  260. self.sources = [clean_str(x) for x in sources] # clean source names for later
  261. self.auto = auto
  262. for i, s in enumerate(sources): # index, source
  263. # Start thread to read frames from video stream
  264. st = f'{i + 1}/{n}: {s}... '
  265. if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'): # if source is YouTube video
  266. check_requirements(('pafy', 'youtube_dl==2020.12.2'))
  267. import pafy
  268. s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL
  269. s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
  270. cap = cv2.VideoCapture(s)
  271. assert cap.isOpened(), f'{st}Failed to open {s}'
  272. w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  273. h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  274. fps = cap.get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan
  275. self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf') # infinite stream fallback
  276. self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback
  277. _, self.imgs[i] = cap.read() # guarantee first frame
  278. self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
  279. LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
  280. self.threads[i].start()
  281. LOGGER.info('') # newline
  282. # check for common shapes
  283. s = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0].shape for x in self.imgs])
  284. self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
  285. if not self.rect:
  286. LOGGER.warning('WARNING: Stream shapes differ. For optimal performance supply similarly-shaped streams.')
  287. def update(self, i, cap, stream):
  288. # Read stream `i` frames in daemon thread
  289. n, f, read = 0, self.frames[i], 1 # frame number, frame array, inference every 'read' frame
  290. while cap.isOpened() and n < f:
  291. n += 1
  292. # _, self.imgs[index] = cap.read()
  293. cap.grab()
  294. if n % read == 0:
  295. success, im = cap.retrieve()
  296. if success:
  297. self.imgs[i] = im
  298. else:
  299. LOGGER.warning('WARNING: Video stream unresponsive, please check your IP camera connection.')
  300. self.imgs[i] = np.zeros_like(self.imgs[i])
  301. cap.open(stream) # re-open stream if signal was lost
  302. time.sleep(1 / self.fps[i]) # wait time
  303. def __iter__(self):
  304. self.count = -1
  305. return self
  306. def __next__(self):
  307. self.count += 1
  308. if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
  309. cv2.destroyAllWindows()
  310. raise StopIteration
  311. # Letterbox
  312. img0 = self.imgs.copy()
  313. img = [letterbox(x, self.img_size, stride=self.stride, auto=self.rect and self.auto)[0] for x in img0]
  314. # Stack
  315. img = np.stack(img, 0)
  316. # Convert
  317. img = img[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
  318. img = np.ascontiguousarray(img)
  319. return self.sources, img, img0, None, ''
  320. def __len__(self):
  321. return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
  322. def img2label_paths(img_paths):
  323. # Define label paths as a function of image paths
  324. sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}' # /images/, /labels/ substrings
  325. return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
  326. class LoadImagesAndLabels(Dataset):
  327. # YOLOv5 train_loader/val_loader, loads images and labels for training and validation
  328. cache_version = 0.6 # dataset labels *.cache version
  329. rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]
  330. def __init__(self,
  331. path,
  332. img_size=640,
  333. batch_size=16,
  334. augment=False,
  335. hyp=None,
  336. rect=False,
  337. image_weights=False,
  338. cache_images=False,
  339. single_cls=False,
  340. stride=32,
  341. pad=0.0,
  342. prefix=''):
  343. self.img_size = img_size
  344. self.augment = augment
  345. self.hyp = hyp
  346. self.image_weights = image_weights
  347. self.rect = False if image_weights else rect
  348. self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training)
  349. self.mosaic_border = [-img_size // 2, -img_size // 2]
  350. self.stride = stride
  351. self.path = path
  352. self.albumentations = Albumentations() if augment else None
  353. try:
  354. f = [] # image files
  355. for p in path if isinstance(path, list) else [path]:
  356. p = Path(p) # os-agnostic
  357. if p.is_dir(): # dir
  358. f += glob.glob(str(p / '**' / '*.*'), recursive=True)
  359. # f = list(p.rglob('*.*')) # pathlib
  360. elif p.is_file(): # file
  361. with open(p) as t:
  362. t = t.read().strip().splitlines()
  363. parent = str(p.parent) + os.sep
  364. f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path
  365. # f += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
  366. else:
  367. raise Exception(f'{prefix}{p} does not exist')
  368. self.im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
  369. # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
  370. assert self.im_files, f'{prefix}No images found'
  371. except Exception as e:
  372. raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {HELP_URL}')
  373. # Check cache
  374. self.label_files = img2label_paths(self.im_files) # labels
  375. cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
  376. try:
  377. cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
  378. assert cache['version'] == self.cache_version # same version
  379. assert cache['hash'] == get_hash(self.label_files + self.im_files) # same hash
  380. except Exception:
  381. cache, exists = self.cache_labels(cache_path, prefix), False # cache
  382. # Display cache
  383. nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupt, total
  384. if exists and LOCAL_RANK in {-1, 0}:
  385. d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupt"
  386. tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=BAR_FORMAT) # display cache results
  387. if cache['msgs']:
  388. LOGGER.info('\n'.join(cache['msgs'])) # display warnings
  389. assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {HELP_URL}'
  390. # Read cache
  391. [cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
  392. labels, shapes, self.segments = zip(*cache.values())
  393. self.labels = list(labels)
  394. self.shapes = np.array(shapes, dtype=np.float64)
  395. self.im_files = list(cache.keys()) # update
  396. self.label_files = img2label_paths(cache.keys()) # update
  397. n = len(shapes) # number of images
  398. bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
  399. nb = bi[-1] + 1 # number of batches
  400. self.batch = bi # batch index of image
  401. self.n = n
  402. self.indices = range(n)
  403. # Update labels
  404. include_class = [] # filter labels to include only these classes (optional)
  405. include_class_array = np.array(include_class).reshape(1, -1)
  406. for i, (label, segment) in enumerate(zip(self.labels, self.segments)):
  407. if include_class:
  408. j = (label[:, 0:1] == include_class_array).any(1)
  409. self.labels[i] = label[j]
  410. if segment:
  411. self.segments[i] = segment[j]
  412. if single_cls: # single-class training, merge all classes into 0
  413. self.labels[i][:, 0] = 0
  414. if segment:
  415. self.segments[i][:, 0] = 0
  416. # Rectangular Training
  417. if self.rect:
  418. # Sort by aspect ratio
  419. s = self.shapes # wh
  420. ar = s[:, 1] / s[:, 0] # aspect ratio
  421. irect = ar.argsort()
  422. self.im_files = [self.im_files[i] for i in irect]
  423. self.label_files = [self.label_files[i] for i in irect]
  424. self.labels = [self.labels[i] for i in irect]
  425. self.shapes = s[irect] # wh
  426. ar = ar[irect]
  427. # Set training image shapes
  428. shapes = [[1, 1]] * nb
  429. for i in range(nb):
  430. ari = ar[bi == i]
  431. mini, maxi = ari.min(), ari.max()
  432. if maxi < 1:
  433. shapes[i] = [maxi, 1]
  434. elif mini > 1:
  435. shapes[i] = [1, 1 / mini]
  436. self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
  437. # Cache images into RAM/disk for faster training (WARNING: large datasets may exceed system resources)
  438. self.ims = [None] * n
  439. self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
  440. if cache_images:
  441. gb = 0 # Gigabytes of cached images
  442. self.im_hw0, self.im_hw = [None] * n, [None] * n
  443. fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
  444. results = ThreadPool(NUM_THREADS).imap(fcn, range(n))
  445. pbar = tqdm(enumerate(results), total=n, bar_format=BAR_FORMAT, disable=LOCAL_RANK > 0)
  446. for i, x in pbar:
  447. if cache_images == 'disk':
  448. gb += self.npy_files[i].stat().st_size
  449. else: # 'ram'
  450. self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
  451. gb += self.ims[i].nbytes
  452. pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB {cache_images})'
  453. pbar.close()
  454. def cache_labels(self, path=Path('./labels.cache'), prefix=''):
  455. # Cache dataset labels, check images and read shapes
  456. x = {} # dict
  457. nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
  458. desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
  459. with Pool(NUM_THREADS) as pool:
  460. pbar = tqdm(pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix))),
  461. desc=desc,
  462. total=len(self.im_files),
  463. bar_format=BAR_FORMAT)
  464. for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
  465. nm += nm_f
  466. nf += nf_f
  467. ne += ne_f
  468. nc += nc_f
  469. if im_file:
  470. x[im_file] = [lb, shape, segments]
  471. if msg:
  472. msgs.append(msg)
  473. pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupt"
  474. pbar.close()
  475. if msgs:
  476. LOGGER.info('\n'.join(msgs))
  477. if nf == 0:
  478. LOGGER.warning(f'{prefix}WARNING: No labels found in {path}. See {HELP_URL}')
  479. x['hash'] = get_hash(self.label_files + self.im_files)
  480. x['results'] = nf, nm, ne, nc, len(self.im_files)
  481. x['msgs'] = msgs # warnings
  482. x['version'] = self.cache_version # cache version
  483. try:
  484. np.save(path, x) # save cache for next time
  485. path.with_suffix('.cache.npy').rename(path) # remove .npy suffix
  486. LOGGER.info(f'{prefix}New cache created: {path}')
  487. except Exception as e:
  488. LOGGER.warning(f'{prefix}WARNING: Cache directory {path.parent} is not writeable: {e}') # not writeable
  489. return x
  490. def __len__(self):
  491. return len(self.im_files)
  492. # def __iter__(self):
  493. # self.count = -1
  494. # print('ran dataset iter')
  495. # #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
  496. # return self
  497. def __getitem__(self, index):
  498. index = self.indices[index] # linear, shuffled, or image_weights
  499. hyp = self.hyp
  500. mosaic = self.mosaic and random.random() < hyp['mosaic']
  501. if mosaic:
  502. # Load mosaic
  503. img, labels = self.load_mosaic(index)
  504. shapes = None
  505. # MixUp augmentation
  506. if random.random() < hyp['mixup']:
  507. img, labels = mixup(img, labels, *self.load_mosaic(random.randint(0, self.n - 1)))
  508. else:
  509. # Load image
  510. img, (h0, w0), (h, w) = self.load_image(index)
  511. # Letterbox
  512. shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size # final letterboxed shape
  513. img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
  514. shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling
  515. labels = self.labels[index].copy()
  516. if labels.size: # normalized xywh to pixel xyxy format
  517. labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
  518. if self.augment:
  519. img, labels = random_perspective(img,
  520. labels,
  521. degrees=hyp['degrees'],
  522. translate=hyp['translate'],
  523. scale=hyp['scale'],
  524. shear=hyp['shear'],
  525. perspective=hyp['perspective'])
  526. nl = len(labels) # number of labels
  527. if nl:
  528. labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1E-3)
  529. if self.augment:
  530. # Albumentations
  531. img, labels = self.albumentations(img, labels)
  532. nl = len(labels) # update after albumentations
  533. # HSV color-space
  534. augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
  535. # Flip up-down
  536. if random.random() < hyp['flipud']:
  537. img = np.flipud(img)
  538. if nl:
  539. labels[:, 2] = 1 - labels[:, 2]
  540. # Flip left-right
  541. if random.random() < hyp['fliplr']:
  542. img = np.fliplr(img)
  543. if nl:
  544. labels[:, 1] = 1 - labels[:, 1]
  545. # Cutouts
  546. # labels = cutout(img, labels, p=0.5)
  547. # nl = len(labels) # update after cutout
  548. labels_out = torch.zeros((nl, 6))
  549. if nl:
  550. labels_out[:, 1:] = torch.from_numpy(labels)
  551. # Convert
  552. img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
  553. img = np.ascontiguousarray(img)
  554. return torch.from_numpy(img), labels_out, self.im_files[index], shapes
  555. def load_image(self, i):
  556. # Loads 1 image from dataset index 'i', returns (im, original hw, resized hw)
  557. im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i],
  558. if im is None: # not cached in RAM
  559. if fn.exists(): # load npy
  560. im = np.load(fn)
  561. else: # read image
  562. im = cv2.imread(f) # BGR
  563. assert im is not None, f'Image Not Found {f}'
  564. h0, w0 = im.shape[:2] # orig hw
  565. r = self.img_size / max(h0, w0) # ratio
  566. if r != 1: # if sizes are not equal
  567. interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
  568. im = cv2.resize(im, (int(w0 * r), int(h0 * r)), interpolation=interp)
  569. return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
  570. else:
  571. return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized
  572. def cache_images_to_disk(self, i):
  573. # Saves an image as an *.npy file for faster loading
  574. f = self.npy_files[i]
  575. if not f.exists():
  576. np.save(f.as_posix(), cv2.imread(self.im_files[i]))
  577. def load_mosaic(self, index):
  578. # YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic
  579. labels4, segments4 = [], []
  580. s = self.img_size
  581. yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border) # mosaic center x, y
  582. indices = [index] + random.choices(self.indices, k=3) # 3 additional image indices
  583. random.shuffle(indices)
  584. for i, index in enumerate(indices):
  585. # Load image
  586. img, _, (h, w) = self.load_image(index)
  587. # place img in img4
  588. if i == 0: # top left
  589. img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
  590. x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image)
  591. x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image)
  592. elif i == 1: # top right
  593. x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
  594. x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
  595. elif i == 2: # bottom left
  596. x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
  597. x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
  598. elif i == 3: # bottom right
  599. x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
  600. x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
  601. img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
  602. padw = x1a - x1b
  603. padh = y1a - y1b
  604. # Labels
  605. labels, segments = self.labels[index].copy(), self.segments[index].copy()
  606. if labels.size:
  607. labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh) # normalized xywh to pixel xyxy format
  608. segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
  609. labels4.append(labels)
  610. segments4.extend(segments)
  611. # Concat/clip labels
  612. labels4 = np.concatenate(labels4, 0)
  613. for x in (labels4[:, 1:], *segments4):
  614. np.clip(x, 0, 2 * s, out=x) # clip when using random_perspective()
  615. # img4, labels4 = replicate(img4, labels4) # replicate
  616. # Augment
  617. img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=self.hyp['copy_paste'])
  618. img4, labels4 = random_perspective(img4,
  619. labels4,
  620. segments4,
  621. degrees=self.hyp['degrees'],
  622. translate=self.hyp['translate'],
  623. scale=self.hyp['scale'],
  624. shear=self.hyp['shear'],
  625. perspective=self.hyp['perspective'],
  626. border=self.mosaic_border) # border to remove
  627. return img4, labels4
  628. def load_mosaic9(self, index):
  629. # YOLOv5 9-mosaic loader. Loads 1 image + 8 random images into a 9-image mosaic
  630. labels9, segments9 = [], []
  631. s = self.img_size
  632. indices = [index] + random.choices(self.indices, k=8) # 8 additional image indices
  633. random.shuffle(indices)
  634. hp, wp = -1, -1 # height, width previous
  635. for i, index in enumerate(indices):
  636. # Load image
  637. img, _, (h, w) = self.load_image(index)
  638. # place img in img9
  639. if i == 0: # center
  640. img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
  641. h0, w0 = h, w
  642. c = s, s, s + w, s + h # xmin, ymin, xmax, ymax (base) coordinates
  643. elif i == 1: # top
  644. c = s, s - h, s + w, s
  645. elif i == 2: # top right
  646. c = s + wp, s - h, s + wp + w, s
  647. elif i == 3: # right
  648. c = s + w0, s, s + w0 + w, s + h
  649. elif i == 4: # bottom right
  650. c = s + w0, s + hp, s + w0 + w, s + hp + h
  651. elif i == 5: # bottom
  652. c = s + w0 - w, s + h0, s + w0, s + h0 + h
  653. elif i == 6: # bottom left
  654. c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
  655. elif i == 7: # left
  656. c = s - w, s + h0 - h, s, s + h0
  657. elif i == 8: # top left
  658. c = s - w, s + h0 - hp - h, s, s + h0 - hp
  659. padx, pady = c[:2]
  660. x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coords
  661. # Labels
  662. labels, segments = self.labels[index].copy(), self.segments[index].copy()
  663. if labels.size:
  664. labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady) # normalized xywh to pixel xyxy format
  665. segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
  666. labels9.append(labels)
  667. segments9.extend(segments)
  668. # Image
  669. img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:] # img9[ymin:ymax, xmin:xmax]
  670. hp, wp = h, w # height, width previous
  671. # Offset
  672. yc, xc = (int(random.uniform(0, s)) for _ in self.mosaic_border) # mosaic center x, y
  673. img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]
  674. # Concat/clip labels
  675. labels9 = np.concatenate(labels9, 0)
  676. labels9[:, [1, 3]] -= xc
  677. labels9[:, [2, 4]] -= yc
  678. c = np.array([xc, yc]) # centers
  679. segments9 = [x - c for x in segments9]
  680. for x in (labels9[:, 1:], *segments9):
  681. np.clip(x, 0, 2 * s, out=x) # clip when using random_perspective()
  682. # img9, labels9 = replicate(img9, labels9) # replicate
  683. # Augment
  684. img9, labels9 = random_perspective(img9,
  685. labels9,
  686. segments9,
  687. degrees=self.hyp['degrees'],
  688. translate=self.hyp['translate'],
  689. scale=self.hyp['scale'],
  690. shear=self.hyp['shear'],
  691. perspective=self.hyp['perspective'],
  692. border=self.mosaic_border) # border to remove
  693. return img9, labels9
  694. @staticmethod
  695. def collate_fn(batch):
  696. im, label, path, shapes = zip(*batch) # transposed
  697. for i, lb in enumerate(label):
  698. lb[:, 0] = i # add target image index for build_targets()
  699. return torch.stack(im, 0), torch.cat(label, 0), path, shapes
  700. @staticmethod
  701. def collate_fn4(batch):
  702. img, label, path, shapes = zip(*batch) # transposed
  703. n = len(shapes) // 4
  704. im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
  705. ho = torch.tensor([[0.0, 0, 0, 1, 0, 0]])
  706. wo = torch.tensor([[0.0, 0, 1, 0, 0, 0]])
  707. s = torch.tensor([[1, 1, 0.5, 0.5, 0.5, 0.5]]) # scale
  708. for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW
  709. i *= 4
  710. if random.random() < 0.5:
  711. im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear',
  712. align_corners=False)[0].type(img[i].type())
  713. lb = label[i]
  714. else:
  715. im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
  716. lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
  717. im4.append(im)
  718. label4.append(lb)
  719. for i, lb in enumerate(label4):
  720. lb[:, 0] = i # add target image index for build_targets()
  721. return torch.stack(im4, 0), torch.cat(label4, 0), path4, shapes4
  722. # Ancillary functions --------------------------------------------------------------------------------------------------
  723. def create_folder(path='./new'):
  724. # Create folder
  725. if os.path.exists(path):
  726. shutil.rmtree(path) # delete output folder
  727. os.makedirs(path) # make new output folder
  728. def flatten_recursive(path=DATASETS_DIR / 'coco128'):
  729. # Flatten a recursive directory by bringing all files to top level
  730. new_path = Path(str(path) + '_flat')
  731. create_folder(new_path)
  732. for file in tqdm(glob.glob(str(Path(path)) + '/**/*.*', recursive=True)):
  733. shutil.copyfile(file, new_path / Path(file).name)
  734. def extract_boxes(path=DATASETS_DIR / 'coco128'): # from utils.dataloaders import *; extract_boxes()
  735. # Convert detection dataset into classification dataset, with one directory per class
  736. path = Path(path) # images dir
  737. shutil.rmtree(path / 'classifier') if (path / 'classifier').is_dir() else None # remove existing
  738. files = list(path.rglob('*.*'))
  739. n = len(files) # number of files
  740. for im_file in tqdm(files, total=n):
  741. if im_file.suffix[1:] in IMG_FORMATS:
  742. # image
  743. im = cv2.imread(str(im_file))[..., ::-1] # BGR to RGB
  744. h, w = im.shape[:2]
  745. # labels
  746. lb_file = Path(img2label_paths([str(im_file)])[0])
  747. if Path(lb_file).exists():
  748. with open(lb_file) as f:
  749. lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32) # labels
  750. for j, x in enumerate(lb):
  751. c = int(x[0]) # class
  752. f = (path / 'classifier') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg' # new filename
  753. if not f.parent.is_dir():
  754. f.parent.mkdir(parents=True)
  755. b = x[1:] * [w, h, w, h] # box
  756. # b[2:] = b[2:].max() # rectangle to square
  757. b[2:] = b[2:] * 1.2 + 3 # pad
  758. b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(np.int)
  759. b[[0, 2]] = np.clip(b[[0, 2]], 0, w) # clip boxes outside of image
  760. b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
  761. assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'
  762. def autosplit(path=DATASETS_DIR / 'coco128/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
  763. """ Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
  764. Usage: from utils.dataloaders import *; autosplit()
  765. Arguments
  766. path: Path to images directory
  767. weights: Train, val, test weights (list, tuple)
  768. annotated_only: Only use images with an annotated txt file
  769. """
  770. path = Path(path) # images dir
  771. files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS) # image files only
  772. n = len(files) # number of files
  773. random.seed(0) # for reproducibility
  774. indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
  775. txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
  776. [(path.parent / x).unlink(missing_ok=True) for x in txt] # remove existing
  777. print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
  778. for i, img in tqdm(zip(indices, files), total=n):
  779. if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
  780. with open(path.parent / txt[i], 'a') as f:
  781. f.write('./' + img.relative_to(path.parent).as_posix() + '\n') # add image to txt file
  782. def verify_image_label(args):
  783. # Verify one image-label pair
  784. im_file, lb_file, prefix = args
  785. nm, nf, ne, nc, msg, segments = 0, 0, 0, 0, '', [] # number (missing, found, empty, corrupt), message, segments
  786. try:
  787. # verify images
  788. im = Image.open(im_file)
  789. im.verify() # PIL verify
  790. shape = exif_size(im) # image size
  791. assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
  792. assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
  793. if im.format.lower() in ('jpg', 'jpeg'):
  794. with open(im_file, 'rb') as f:
  795. f.seek(-2, 2)
  796. if f.read() != b'\xff\xd9': # corrupt JPEG
  797. ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
  798. msg = f'{prefix}WARNING: {im_file}: corrupt JPEG restored and saved'
  799. # verify labels
  800. if os.path.isfile(lb_file):
  801. nf = 1 # label found
  802. with open(lb_file) as f:
  803. lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
  804. if any(len(x) > 6 for x in lb): # is segment
  805. classes = np.array([x[0] for x in lb], dtype=np.float32)
  806. segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb] # (cls, xy1...)
  807. lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
  808. lb = np.array(lb, dtype=np.float32)
  809. nl = len(lb)
  810. if nl:
  811. assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
  812. assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
  813. assert (lb[:, 1:] <= 1).all(), f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}'
  814. _, i = np.unique(lb, axis=0, return_index=True)
  815. if len(i) < nl: # duplicate row check
  816. lb = lb[i] # remove duplicates
  817. if segments:
  818. segments = segments[i]
  819. msg = f'{prefix}WARNING: {im_file}: {nl - len(i)} duplicate labels removed'
  820. else:
  821. ne = 1 # label empty
  822. lb = np.zeros((0, 5), dtype=np.float32)
  823. else:
  824. nm = 1 # label missing
  825. lb = np.zeros((0, 5), dtype=np.float32)
  826. return im_file, lb, shape, segments, nm, nf, ne, nc, msg
  827. except Exception as e:
  828. nc = 1
  829. msg = f'{prefix}WARNING: {im_file}: ignoring corrupt image/label: {e}'
  830. return [None, None, None, None, nm, nf, ne, nc, msg]
  831. def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False, profile=False, hub=False):
  832. """ Return dataset statistics dictionary with images and instances counts per split per class
  833. To run in parent directory: export PYTHONPATH="$PWD/yolov5"
  834. Usage1: from utils.dataloaders import *; dataset_stats('coco128.yaml', autodownload=True)
  835. Usage2: from utils.dataloaders import *; dataset_stats('path/to/coco128_with_yaml.zip')
  836. Arguments
  837. path: Path to data.yaml or data.zip (with data.yaml inside data.zip)
  838. autodownload: Attempt to download dataset if not found locally
  839. verbose: Print stats dictionary
  840. """
  841. def _round_labels(labels):
  842. # Update labels to integer class and 6 decimal place floats
  843. return [[int(c), *(round(x, 4) for x in points)] for c, *points in labels]
  844. def _find_yaml(dir):
  845. # Return data.yaml file
  846. files = list(dir.glob('*.yaml')) or list(dir.rglob('*.yaml')) # try root level first and then recursive
  847. assert files, f'No *.yaml file found in {dir}'
  848. if len(files) > 1:
  849. files = [f for f in files if f.stem == dir.stem] # prefer *.yaml files that match dir name
  850. assert files, f'Multiple *.yaml files found in {dir}, only 1 *.yaml file allowed'
  851. assert len(files) == 1, f'Multiple *.yaml files found: {files}, only 1 *.yaml file allowed in {dir}'
  852. return files[0]
  853. def _unzip(path):
  854. # Unzip data.zip
  855. if str(path).endswith('.zip'): # path is data.zip
  856. assert Path(path).is_file(), f'Error unzipping {path}, file not found'
  857. ZipFile(path).extractall(path=path.parent) # unzip
  858. dir = path.with_suffix('') # dataset directory == zip name
  859. assert dir.is_dir(), f'Error unzipping {path}, {dir} not found. path/to/abc.zip MUST unzip to path/to/abc/'
  860. return True, str(dir), _find_yaml(dir) # zipped, data_dir, yaml_path
  861. else: # path is data.yaml
  862. return False, None, path
  863. def _hub_ops(f, max_dim=1920):
  864. # HUB ops for 1 image 'f': resize and save at reduced quality in /dataset-hub for web/app viewing
  865. f_new = im_dir / Path(f).name # dataset-hub image filename
  866. try: # use PIL
  867. im = Image.open(f)
  868. r = max_dim / max(im.height, im.width) # ratio
  869. if r < 1.0: # image too large
  870. im = im.resize((int(im.width * r), int(im.height * r)))
  871. im.save(f_new, 'JPEG', quality=75, optimize=True) # save
  872. except Exception as e: # use OpenCV
  873. print(f'WARNING: HUB ops PIL failure {f}: {e}')
  874. im = cv2.imread(f)
  875. im_height, im_width = im.shape[:2]
  876. r = max_dim / max(im_height, im_width) # ratio
  877. if r < 1.0: # image too large
  878. im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
  879. cv2.imwrite(str(f_new), im)
  880. zipped, data_dir, yaml_path = _unzip(Path(path))
  881. with open(check_yaml(yaml_path), errors='ignore') as f:
  882. data = yaml.safe_load(f) # data dict
  883. if zipped:
  884. data['path'] = data_dir # TODO: should this be dir.resolve()?
  885. check_dataset(data, autodownload) # download dataset if missing
  886. hub_dir = Path(data['path'] + ('-hub' if hub else ''))
  887. stats = {'nc': data['nc'], 'names': data['names']} # statistics dictionary
  888. for split in 'train', 'val', 'test':
  889. if data.get(split) is None:
  890. stats[split] = None # i.e. no test set
  891. continue
  892. x = []
  893. dataset = LoadImagesAndLabels(data[split]) # load dataset
  894. for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'):
  895. x.append(np.bincount(label[:, 0].astype(int), minlength=data['nc']))
  896. x = np.array(x) # shape(128x80)
  897. stats[split] = {
  898. 'instance_stats': {
  899. 'total': int(x.sum()),
  900. 'per_class': x.sum(0).tolist()},
  901. 'image_stats': {
  902. 'total': dataset.n,
  903. 'unlabelled': int(np.all(x == 0, 1).sum()),
  904. 'per_class': (x > 0).sum(0).tolist()},
  905. 'labels': [{
  906. str(Path(k).name): _round_labels(v.tolist())} for k, v in zip(dataset.im_files, dataset.labels)]}
  907. if hub:
  908. im_dir = hub_dir / 'images'
  909. im_dir.mkdir(parents=True, exist_ok=True)
  910. for _ in tqdm(ThreadPool(NUM_THREADS).imap(_hub_ops, dataset.im_files), total=dataset.n, desc='HUB Ops'):
  911. pass
  912. # Profile
  913. stats_path = hub_dir / 'stats.json'
  914. if profile:
  915. for _ in range(1):
  916. file = stats_path.with_suffix('.npy')
  917. t1 = time.time()
  918. np.save(file, stats)
  919. t2 = time.time()
  920. x = np.load(file, allow_pickle=True)
  921. print(f'stats.npy times: {time.time() - t2:.3f}s read, {t2 - t1:.3f}s write')
  922. file = stats_path.with_suffix('.json')
  923. t1 = time.time()
  924. with open(file, 'w') as f:
  925. json.dump(stats, f) # save stats *.json
  926. t2 = time.time()
  927. with open(file) as f:
  928. x = json.load(f) # load hyps dict
  929. print(f'stats.json times: {time.time() - t2:.3f}s read, {t2 - t1:.3f}s write')
  930. # Save, print and return
  931. if hub:
  932. print(f'Saving {stats_path.resolve()}...')
  933. with open(stats_path, 'w') as f:
  934. json.dump(stats, f) # save stats.json
  935. if verbose:
  936. print(json.dumps(stats, indent=2, sort_keys=False))
  937. return stats