general.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995
  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. General utils
  4. """
  5. import contextlib
  6. import glob
  7. import inspect
  8. import logging
  9. import math
  10. import os
  11. import platform
  12. import random
  13. import re
  14. import shutil
  15. import signal
  16. import time
  17. import urllib
  18. from datetime import datetime
  19. from itertools import repeat
  20. from multiprocessing.pool import ThreadPool
  21. from pathlib import Path
  22. from subprocess import check_output
  23. from typing import Optional
  24. from zipfile import ZipFile
  25. import cv2
  26. import numpy as np
  27. import pandas as pd
  28. import pkg_resources as pkg
  29. import torch
  30. import torchvision
  31. import yaml
  32. from utils.downloads import gsutil_getsize
  33. from utils.metrics import box_iou, fitness
  34. # Settings
  35. FILE = Path(__file__).resolve()
  36. ROOT = FILE.parents[1] # YOLOv5 root directory
  37. DATASETS_DIR = ROOT.parent / 'datasets' # YOLOv5 datasets directory
  38. NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
  39. AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
  40. VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true' # global verbose mode
  41. FONT = 'Arial.ttf' # https://ultralytics.com/assets/Arial.ttf
  42. torch.set_printoptions(linewidth=320, precision=5, profile='long')
  43. np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
  44. pd.options.display.max_columns = 10
  45. cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader) 使用cv2单线程
  46. os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
  47. os.environ['OMP_NUM_THREADS'] = str(NUM_THREADS) # OpenMP max threads (PyTorch and SciPy)
  48. def is_kaggle():
  49. # Is environment a Kaggle Notebook?
  50. try:
  51. assert os.environ.get('PWD') == '/kaggle/working'
  52. assert os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
  53. return True
  54. except AssertionError:
  55. return False
  56. def is_writeable(dir, test=False):
  57. # Return True if directory has write permissions, test opening a file with write permissions if test=True
  58. if test: # method 1
  59. file = Path(dir) / 'tmp.txt'
  60. try:
  61. with open(file, 'w'): # open file with write permissions
  62. pass
  63. file.unlink() # remove file
  64. return True
  65. except OSError:
  66. return False
  67. else: # method 2
  68. return os.access(dir, os.R_OK) # possible issues on Windows
  69. def set_logging(name=None, verbose=VERBOSE):
  70. # Sets level and returns logger
  71. if is_kaggle():
  72. for h in logging.root.handlers:
  73. logging.root.removeHandler(h) # remove all handlers associated with the root logger object
  74. rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
  75. level = logging.INFO if (verbose and rank in (-1, 0)) else logging.WARNING
  76. log = logging.getLogger(name)
  77. log.setLevel(level)
  78. handler = logging.StreamHandler()
  79. handler.setFormatter(logging.Formatter("%(message)s"))
  80. handler.setLevel(level)
  81. log.addHandler(handler)
set_logging()  # run before defining LOGGER; with name=None this configures the root logger
LOGGER = logging.getLogger("yolov5")  # define globally (used in train.py, val.py, detect.py, etc.)
  84. def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
  85. # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
  86. env = os.getenv(env_var)
  87. if env:
  88. path = Path(env) # use environment variable
  89. else:
  90. cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'} # 3 OS dirs
  91. path = Path.home() / cfg.get(platform.system(), '') # OS-specific config dir
  92. path = (path if is_writeable(path) else Path('/tmp')) / dir # GCP and AWS lambda fix, only /tmp is writeable
  93. path.mkdir(exist_ok=True) # make if required
  94. return path
CONFIG_DIR = user_config_dir()  # Ultralytics settings dir (user_config_dir creates it if it does not exist)
  96. class Profile(contextlib.ContextDecorator):
  97. # Usage: @Profile() decorator or 'with Profile():' context manager
  98. def __enter__(self):
  99. self.start = time.time()
  100. def __exit__(self, type, value, traceback):
  101. print(f'Profile results: {time.time() - self.start:.5f}s')
  102. class Timeout(contextlib.ContextDecorator):
  103. # Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
  104. def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
  105. self.seconds = int(seconds)
  106. self.timeout_message = timeout_msg
  107. self.suppress = bool(suppress_timeout_errors)
  108. def _timeout_handler(self, signum, frame):
  109. raise TimeoutError(self.timeout_message)
  110. def __enter__(self):
  111. if platform.system() != 'Windows': # not supported on Windows
  112. signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM
  113. signal.alarm(self.seconds) # start countdown for SIGALRM to be raised
  114. def __exit__(self, exc_type, exc_val, exc_tb):
  115. if platform.system() != 'Windows':
  116. signal.alarm(0) # Cancel SIGALRM if it's scheduled
  117. if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError
  118. return True
  119. class WorkingDirectory(contextlib.ContextDecorator):
  120. # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
  121. def __init__(self, new_dir):
  122. self.dir = new_dir # new dir
  123. self.cwd = Path.cwd().resolve() # current dir
  124. def __enter__(self):
  125. os.chdir(self.dir)
  126. def __exit__(self, exc_type, exc_val, exc_tb):
  127. os.chdir(self.cwd)
  128. def try_except(func):
  129. # try-except function. Usage: @try_except decorator
  130. def handler(*args, **kwargs):
  131. try:
  132. func(*args, **kwargs)
  133. except Exception as e:
  134. print(e)
  135. return handler
  136. def methods(instance):
  137. # Get class/instance methods
  138. return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
  139. def print_args(args: Optional[dict] = None, show_file=True, show_fcn=False):
  140. # Print function arguments (optional args dict)
  141. x = inspect.currentframe().f_back # previous frame
  142. file, _, fcn, _, _ = inspect.getframeinfo(x)
  143. if args is None: # get args automatically
  144. args, _, _, frm = inspect.getargvalues(x)
  145. args = {k: v for k, v in frm.items() if k in args}
  146. s = (f'{Path(file).stem}: ' if show_file else '') + (f'{fcn}: ' if show_fcn else '')
  147. LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
  148. def init_seeds(seed=0):
  149. # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
  150. # cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
  151. import torch.backends.cudnn as cudnn
  152. random.seed(seed)
  153. np.random.seed(seed)
  154. torch.manual_seed(seed)
  155. cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
  156. def intersect_dicts(da, db, exclude=()):
  157. # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
  158. return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
  159. def get_latest_run(search_dir='.'):
  160. # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
  161. last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
  162. return max(last_list, key=os.path.getctime) if last_list else ''
  163. def is_docker():
  164. # Is environment a Docker container?
  165. return Path('/workspace').exists() # or Path('/.dockerenv').exists()
  166. def is_colab():
  167. # Is environment a Google Colab instance?
  168. try:
  169. import google.colab
  170. return True
  171. except ImportError:
  172. return False
  173. def is_pip():
  174. # Is file in a pip package?
  175. return 'site-packages' in Path(__file__).resolve().parts
  176. def is_ascii(s=''):
  177. # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
  178. s = str(s) # convert list, tuple, None, etc. to str
  179. return len(s.encode().decode('ascii', 'ignore')) == len(s)
  180. def is_chinese(s='人工智能'):
  181. # Is string composed of any Chinese characters?
  182. return True if re.search('[\u4e00-\u9fff]', str(s)) else False
  183. def emojis(str=''):
  184. # Return platform-dependent emoji-safe version of string
  185. return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
  186. def file_age(path=__file__):
  187. # Return days since last file update
  188. dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
  189. return dt.days # + dt.seconds / 86400 # fractional days
  190. def file_update_date(path=__file__):
  191. # Return human-readable file modification date, i.e. '2021-3-26'
  192. t = datetime.fromtimestamp(Path(path).stat().st_mtime)
  193. return f'{t.year}-{t.month}-{t.day}'
  194. def file_size(path):
  195. # Return file/dir size (MB)
  196. mb = 1 << 20 # bytes to MiB (1024 ** 2)
  197. path = Path(path)
  198. if path.is_file():
  199. return path.stat().st_size / mb
  200. elif path.is_dir():
  201. return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
  202. else:
  203. return 0.0
  204. def check_online():
  205. # Check internet connectivity
  206. import socket
  207. try:
  208. socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
  209. return True
  210. except OSError:
  211. return False
  212. def git_describe(path=ROOT): # path must be a directory
  213. # Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
  214. try:
  215. assert (Path(path) / '.git').is_dir()
  216. return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
  217. except Exception:
  218. return ''
  219. @try_except
  220. @WorkingDirectory(ROOT)
  221. def check_git_status():
  222. # Recommend 'git pull' if code is out of date
  223. msg = ', for updates see https://github.com/ultralytics/yolov5'
  224. s = colorstr('github: ') # string
  225. assert Path('.git').exists(), s + 'skipping check (not a git repository)' + msg
  226. assert not is_docker(), s + 'skipping check (Docker image)' + msg
  227. assert check_online(), s + 'skipping check (offline)' + msg
  228. cmd = 'git fetch && git config --get remote.origin.url'
  229. url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git') # git fetch
  230. branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
  231. n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
  232. if n > 0:
  233. s += f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `git pull` or `git clone {url}` to update."
  234. else:
  235. s += f'up to date with {url} ✅'
  236. LOGGER.info(emojis(s)) # emoji-safe
  237. def check_python(minimum='3.7.0'):
  238. # Check current python version vs. required python version
  239. check_version(platform.python_version(), minimum, name='Python ', hard=True)
  240. def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
  241. # Check version vs. required version
  242. current, minimum = (pkg.parse_version(x) for x in (current, minimum))
  243. result = (current == minimum) if pinned else (current >= minimum) # bool
  244. s = f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed' # string
  245. if hard:
  246. assert result, s # assert min requirements met
  247. if verbose and not result:
  248. LOGGER.warning(s)
  249. return result
@try_except
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()):
    # Check installed dependencies meet requirements (pass *.txt file or list of packages).
    # requirements: path to a requirements.txt file (str/Path) or an iterable of requirement strings
    # exclude: package names to skip
    # install: try 'pip install' for unmet requirements (only when the AUTOINSTALL setting is also true)
    # cmds: optional extra pip arguments, indexed in parallel with the requirements list
    # Exceptions raised here (e.g. a missing requirements file) are caught and printed by @try_except.
    prefix = colorstr('red', 'bold', 'requirements:')
    check_python()  # check python version
    if isinstance(requirements, (str, Path)):  # requirements.txt file
        file = Path(requirements)
        assert file.exists(), f"{prefix} {file.resolve()} not found, check failed."
        with file.open() as f:
            requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
    else:  # list or tuple of packages
        requirements = [x for x in requirements if x not in exclude]
    n = 0  # number of packages updates
    for i, r in enumerate(requirements):
        try:
            pkg.require(r)
        except Exception:  # DistributionNotFound or VersionConflict if requirements not met
            s = f"{prefix} {r} not found and is required by YOLOv5"
            if install and AUTOINSTALL:  # check environment variable
                LOGGER.info(f"{s}, attempting auto-update...")
                try:
                    assert check_online(), f"'pip install {r}' skipped (offline)"
                    # NOTE(review): r and cmds[i] are interpolated into a shell command; safe only for trusted input
                    LOGGER.info(check_output(f"pip install '{r}' {cmds[i] if cmds else ''}", shell=True).decode())
                    n += 1
                except Exception as e:
                    LOGGER.warning(f'{prefix} {e}')
            else:
                LOGGER.info(f'{s}. Please install and rerun your command.')
    if n:  # if packages updated
        source = file.resolve() if 'file' in locals() else requirements  # 'file' only exists for a .txt input
        s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
            f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
        LOGGER.info(emojis(s))
  283. def check_img_size(imgsz, s=32, floor=0): #检查图像尺寸是否是s的整数倍;不是则调整图像尺寸
  284. # Verify image size is a multiple of stride s in each dimension
  285. if isinstance(imgsz, int): # integer i.e. img_size=640
  286. new_size = max(make_divisible(imgsz, int(s)), floor) # 返回能被s整除的最近的整数
  287. else: # list i.e. img_size=[640, 480]
  288. imgsz = list(imgsz) # convert to list if tuple
  289. new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
  290. if new_size != imgsz:
  291. LOGGER.warning(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
  292. return new_size
  293. def check_imshow():
  294. # Check if environment supports image displays
  295. try:
  296. assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
  297. assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
  298. cv2.imshow('test', np.zeros((1, 1, 3)))
  299. cv2.waitKey(1)
  300. cv2.destroyAllWindows()
  301. cv2.waitKey(1)
  302. return True
  303. except Exception as e:
  304. LOGGER.warning(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
  305. return False
  306. def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
  307. # Check file(s) for acceptable suffix
  308. if file and suffix:
  309. if isinstance(suffix, str):
  310. suffix = [suffix]
  311. for f in file if isinstance(file, (list, tuple)) else [file]:
  312. s = Path(f).suffix.lower() # file suffix
  313. if len(s):
  314. assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"
  315. def check_yaml(file, suffix=('.yaml', '.yml')):
  316. # Search/download YAML file (if necessary) and return path, checking suffix
  317. return check_file(file, suffix)
  318. def check_file(file, suffix=''):
  319. # Search/download file (if necessary) and return path
  320. check_suffix(file, suffix) # optional
  321. file = str(file) # convert to str()
  322. if Path(file).is_file() or file == '': # exists
  323. return file
  324. elif file.startswith(('http:/', 'https:/')): # download
  325. url = str(Path(file)).replace(':/', '://') # Pathlib turns :// -> :/
  326. file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth
  327. if Path(file).is_file():
  328. LOGGER.info(f'Found {url} locally at {file}') # file already exists
  329. else:
  330. LOGGER.info(f'Downloading {url} to {file}...')
  331. torch.hub.download_url_to_file(url, file)
  332. assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
  333. return file
  334. else: # search
  335. files = []
  336. for d in 'data', 'models', 'utils': # search directories
  337. files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
  338. assert len(files), f'File not found: {file}' # assert file was found
  339. assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
  340. return files[0] # return file
  341. def check_font(font=FONT, progress=False):
  342. # Download font to CONFIG_DIR if necessary
  343. font = Path(font)
  344. file = CONFIG_DIR / font.name
  345. if not font.exists() and not file.exists():
  346. url = "https://ultralytics.com/assets/" + font.name
  347. LOGGER.info(f'Downloading {url} to {file}...')
  348. torch.hub.download_url_to_file(url, str(file), progress=progress)
def check_dataset(data, autodownload=True):
    # Load a dataset definition and make sure its data is available locally; return it as a dict.
    # data: path to a dataset .yaml, path/URL of a .zip that contains one, or an already-parsed dict
    # autodownload: run the dict's 'download' entry when any 'val' path is missing
    # Returns the dict with 'train'/'val'/'test' prefixed by the resolved dataset root and 'names' filled in.
    # Raises AssertionError if 'nc' is missing, and Exception('Dataset not found') if 'val' paths are missing and
    # no download is attempted.
    # Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip
    # Download (optional): fetch/unzip the archive into DATASETS_DIR and use the first *.yaml found inside it
    extract_dir = ''
    if isinstance(data, (str, Path)) and str(data).endswith('.zip'):  # i.e. gs://bucket/dir/coco128.zip
        download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False, threads=1)
        data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml'))
        extract_dir, autodownload = data.parent, False  # the archive already supplied the data
    # Read yaml (optional)
    if isinstance(data, (str, Path)):
        with open(data, errors='ignore') as f:
            data = yaml.safe_load(f)  # dictionary
    # Resolve paths: a relative dataset root is interpreted relative to ROOT
    path = Path(extract_dir or data.get('path') or '')  # optional 'path' default to '.'
    if not path.is_absolute():
        path = (ROOT / path).resolve()
    for k in 'train', 'val', 'test':
        if data.get(k):  # prepend path
            data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]
    # Parse yaml
    assert 'nc' in data, "Dataset 'nc' key missing."
    if 'names' not in data:
        data['names'] = [f'class{i}' for i in range(data['nc'])]  # assign class names if missing
    train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            LOGGER.info(emojis('\nDataset not found ⚠️, missing paths %s' % [str(x) for x in val if not x.exists()]))
            if s and autodownload:  # download script
                t = time.time()
                root = path.parent if 'path' in data else '..'  # unzip directory i.e. '../'
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
                    LOGGER.info(f'Downloading {s} to {f}...')
                    torch.hub.download_url_to_file(s, f)
                    Path(root).mkdir(parents=True, exist_ok=True)  # create root
                    ZipFile(f).extractall(path=root)  # unzip
                    Path(f).unlink()  # remove zip
                    r = None  # success
                elif s.startswith('bash '):  # bash script
                    LOGGER.info(f'Running {s} ...')
                    r = os.system(s)  # NOTE(review): runs a shell command taken from the dataset yaml
                else:  # python script
                    # NOTE(review): executes arbitrary Python from the dataset yaml; only use trusted yaml files
                    r = exec(s, {'yaml': data})  # return None
                dt = f'({round(time.time() - t, 1)}s)'
                s = f"success ✅ {dt}, saved to {colorstr('bold', root)}" if r in (0, None) else f"failure {dt} ❌"
                LOGGER.info(emojis(f"Dataset download {s}"))
            else:
                raise Exception(emojis('Dataset not found ❌'))
    return data  # dictionary
  400. def url2file(url):
  401. # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
  402. url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
  403. file = Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
  404. return file
  405. def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
  406. # Multi-threaded file download and unzip function, used in data.yaml for autodownload
  407. def download_one(url, dir):
  408. # Download 1 file
  409. success = True
  410. f = dir / Path(url).name # filename
  411. if Path(url).is_file(): # exists in current path
  412. Path(url).rename(f) # move to dir
  413. elif not f.exists():
  414. LOGGER.info(f'Downloading {url} to {f}...')
  415. for i in range(retry + 1):
  416. if curl:
  417. s = 'sS' if threads > 1 else '' # silent
  418. r = os.system(f"curl -{s}L '{url}' -o '{f}' --retry 9 -C -") # curl download
  419. success = r == 0
  420. else:
  421. torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download
  422. success = f.is_file()
  423. if success:
  424. break
  425. elif i < retry:
  426. LOGGER.warning(f'Download failure, retrying {i + 1}/{retry} {url}...')
  427. else:
  428. LOGGER.warning(f'Failed to download {url}...')
  429. if unzip and success and f.suffix in ('.zip', '.gz'):
  430. LOGGER.info(f'Unzipping {f}...')
  431. if f.suffix == '.zip':
  432. ZipFile(f).extractall(path=dir) # unzip
  433. elif f.suffix == '.gz':
  434. os.system(f'tar xfz {f} --directory {f.parent}') # unzip
  435. if delete:
  436. f.unlink() # remove zip
  437. dir = Path(dir)
  438. dir.mkdir(parents=True, exist_ok=True) # make directory
  439. if threads > 1:
  440. pool = ThreadPool(threads)
  441. pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multi-threaded
  442. pool.close()
  443. pool.join()
  444. else:
  445. for u in [url] if isinstance(url, (str, Path)) else url:
  446. download_one(u, dir)
  447. def make_divisible(x, divisor): # 返回能被除数整除且最接近x的整数
  448. # Returns nearest x divisible by divisor
  449. if isinstance(divisor, torch.Tensor):
  450. divisor = int(divisor.max()) # to int
  451. return math.ceil(x / divisor) * divisor
  452. def clean_str(s):
  453. # Cleans a string by replacing special characters with underscore _
  454. return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
  455. def one_cycle(y1=0.0, y2=1.0, steps=100):
  456. # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
  457. return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
  458. def colorstr(*input):
  459. # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
  460. *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
  461. colors = {
  462. 'black': '\033[30m', # basic colors
  463. 'red': '\033[31m',
  464. 'green': '\033[32m',
  465. 'yellow': '\033[33m',
  466. 'blue': '\033[34m',
  467. 'magenta': '\033[35m',
  468. 'cyan': '\033[36m',
  469. 'white': '\033[37m',
  470. 'bright_black': '\033[90m', # bright colors
  471. 'bright_red': '\033[91m',
  472. 'bright_green': '\033[92m',
  473. 'bright_yellow': '\033[93m',
  474. 'bright_blue': '\033[94m',
  475. 'bright_magenta': '\033[95m',
  476. 'bright_cyan': '\033[96m',
  477. 'bright_white': '\033[97m',
  478. 'end': '\033[0m', # misc
  479. 'bold': '\033[1m',
  480. 'underline': '\033[4m'}
  481. return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
  482. def labels_to_class_weights(labels, nc=80):
  483. # Get class weights (inverse frequency) from training labels
  484. if labels[0] is None: # no labels loaded
  485. return torch.Tensor()
  486. labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
  487. classes = labels[:, 0].astype(np.int) # labels = [class xywh]
  488. weights = np.bincount(classes, minlength=nc) # occurrences per class
  489. # Prepend gridpoint count (for uCE training)
  490. # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
  491. # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
  492. weights[weights == 0] = 1 # replace empty bins with 1
  493. weights = 1 / weights # number of targets per class
  494. weights /= weights.sum() # normalize
  495. return torch.from_numpy(weights)
  496. def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
  497. # Produces image weights based on class_weights and image contents
  498. class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels]) # 统计每个类别的样本数目
  499. image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
  500. # index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
  501. return image_weights
  502. def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
  503. # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
  504. # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
  505. # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
  506. # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
  507. # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
  508. x = [
  509. 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
  510. 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
  511. 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  512. return x
  513. def xyxy2xywh(x):
  514. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
  515. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  516. y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
  517. y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
  518. y[:, 2] = x[:, 2] - x[:, 0] # width
  519. y[:, 3] = x[:, 3] - x[:, 1] # height
  520. return y
  521. def xywh2xyxy(x):
  522. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  523. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  524. y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
  525. y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
  526. y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
  527. y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
  528. return y
  529. def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
  530. # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  531. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  532. y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x
  533. y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y
  534. y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x
  535. y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y
  536. return y
  537. def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
  538. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
  539. if clip:
  540. clip_coords(x, (h - eps, w - eps)) # warning: inplace clip
  541. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  542. y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w # x center
  543. y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h # y center
  544. y[:, 2] = (x[:, 2] - x[:, 0]) / w # width
  545. y[:, 3] = (x[:, 3] - x[:, 1]) / h # height
  546. return y
  547. def xyn2xy(x, w=640, h=640, padw=0, padh=0):
  548. # Convert normalized segments into pixel segments, shape (n,2)
  549. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  550. y[:, 0] = w * x[:, 0] + padw # top left x
  551. y[:, 1] = h * x[:, 1] + padh # top left y
  552. return y
  553. def segment2box(segment, width=640, height=640):
  554. # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
  555. x, y = segment.T # segment xy
  556. inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
  557. x, y, = x[inside], y[inside]
  558. return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
  559. def segments2boxes(segments):
  560. # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
  561. boxes = []
  562. for s in segments:
  563. x, y = s.T # segment xy
  564. boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
  565. return xyxy2xywh(np.array(boxes)) # cls, xywh
  566. def resample_segments(segments, n=1000):
  567. # Up-sample an (n,2) segment
  568. for i, s in enumerate(segments):
  569. x = np.linspace(0, len(s) - 1, n)
  570. xp = np.arange(len(s))
  571. segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
  572. return segments
  573. def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
  574. # Rescale coords (xyxy) from img1_shape to img0_shape
  575. if ratio_pad is None: # calculate from img0_shape
  576. gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
  577. pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
  578. else:
  579. gain = ratio_pad[0][0]
  580. pad = ratio_pad[1]
  581. coords[:, [0, 2]] -= pad[0] # x padding
  582. coords[:, [1, 3]] -= pad[1] # y padding
  583. coords[:, :4] /= gain
  584. clip_coords(coords, img0_shape)
  585. return coords
  586. def clip_coords(boxes, shape):
  587. # Clip bounding xyxy bounding boxes to image shape (height, width)
  588. if isinstance(boxes, torch.Tensor): # faster individually
  589. boxes[:, 0].clamp_(0, shape[1]) # x1
  590. boxes[:, 1].clamp_(0, shape[0]) # y1
  591. boxes[:, 2].clamp_(0, shape[1]) # x2
  592. boxes[:, 3].clamp_(0, shape[0]) # y2
  593. else: # np.array (faster grouped)
  594. boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
  595. boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
# Non-maximum suppression (NMS)
def non_max_suppression(prediction,
                        conf_thres=0.25,
                        iou_thres=0.45,
                        classes=None,
                        agnostic=False,
                        multi_label=False,
                        labels=(),
                        max_det=300):
    """Run Non-Maximum Suppression (NMS) on raw inference results to reject overlapping bounding boxes.

    Args:
        prediction: raw network output, indexed below as (batch, candidates, 5 + nc). Columns 0:4 are the
            box as (center x, center y, width, height), column 4 is the objectness score and columns 5:
            are the per-class scores (e.g. shape (1, 18900, 85) for 80 classes; the exact shape comes
            from the caller).
        conf_thres: confidence threshold in [0, 1]. It is applied first to objectness, then again to the
            combined score obj_conf * cls_conf.
        iou_thres: IoU threshold in [0, 1]. torchvision's NMS discards a box whose IoU with a higher-scoring
            kept box exceeds this value. Only boxes of the same class compete unless ``agnostic`` is set.
        classes: optional iterable of class indices; if given, only detections of these classes are kept.
        agnostic: if True, run class-agnostic NMS, so boxes of different classes can suppress each other.
        multi_label: if True (and nc > 1), a box yields one detection per class whose combined score
            exceeds conf_thres; otherwise only its best-scoring class is kept.
        labels: optional per-image a-priori labels for autolabelling. Rows are (cls, box) with the box in
            the same (center x, center y, width, height) format as ``prediction``. They are appended to
            the candidates with confidence 1.0.
        max_det: maximum number of detections kept per image.

    Returns:
        A list with one entry per image. Each entry is an (n, 6) tensor [x1, y1, x2, y2, conf, cls].
        Images with no surviving boxes, or not reached before the time limit, get an empty (0, 6) tensor.
    """
    """Non-Maximum Suppression (NMS) on inference results to reject overlapping bounding boxes
    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """
    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates: objectness above conf_thres, shape (batch, candidates)

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    max_wh = 7680  # (pixels) maximum box width and height; also used as the per-class coordinate offset below
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 0.1 + 0.03 * bs  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img); forced off for single-class models
    merge = False  # use merge-NMS

    t = time.time()
    # The same empty tensor object is repeated bs times; entries are only ever replaced, never modified in place.
    output = [torch.zeros((0, 6), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence: keep only candidates whose objectness exceeds conf_thres

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + 5), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:  # no candidate boxes for this image: skip to the next one
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # final per-class score: conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:  # one row per (box, class) pair whose combined score exceeds conf_thres
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)  # highest class score per box and its class index
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]  # second conf_thres filter; rows [x1, y1, x2, y2, conf, cls]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes: keep only the max_nms highest-confidence ones
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        # Instead of looping over classes and running NMS once per class, every box is shifted by
        # cls * max_wh (class 1 by 7680, class 2 by 15360, ...). Boxes of different classes then never
        # overlap, so a single class-unaware NMS call behaves like per-class NMS. This assumes box
        # coordinates stay below max_wh pixels. With agnostic=True the offset is 0 and all classes compete.
        # The offset is applied only to ``boxes``; ``x`` keeps the original coordinates for the output.
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS; indices of kept boxes, sorted by decreasing score
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:  # remaining images keep their empty (0, 6) result
            LOGGER.warning(f'WARNING: NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded

    return output
  706. def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
  707. # Strip optimizer from 'f' to finalize training, optionally save as 's'
  708. x = torch.load(f, map_location=torch.device('cpu'))
  709. if x.get('ema'):
  710. x['model'] = x['ema'] # replace model with ema
  711. for k in 'optimizer', 'best_fitness', 'wandb_id', 'ema', 'updates': # keys
  712. x[k] = None
  713. x['epoch'] = -1
  714. x['model'].half() # to FP16
  715. for p in x['model'].parameters():
  716. p.requires_grad = False
  717. torch.save(x, s or f)
  718. mb = os.path.getsize(s or f) / 1E6 # filesize
  719. LOGGER.info(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
  720. def print_mutation(results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')):
  721. evolve_csv = save_dir / 'evolve.csv'
  722. evolve_yaml = save_dir / 'hyp_evolve.yaml'
  723. keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', 'val/box_loss',
  724. 'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys()) # [results + hyps]
  725. keys = tuple(x.strip() for x in keys)
  726. vals = results + tuple(hyp.values())
  727. n = len(keys)
  728. # Download (optional)
  729. if bucket:
  730. url = f'gs://{bucket}/evolve.csv'
  731. if gsutil_getsize(url) > (evolve_csv.stat().st_size if evolve_csv.exists() else 0):
  732. os.system(f'gsutil cp {url} {save_dir}') # download evolve.csv if larger than local
  733. # Log to evolve.csv
  734. s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n') # add header
  735. with open(evolve_csv, 'a') as f:
  736. f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')
  737. # Save yamlF
  738. with open(evolve_yaml, 'w') as f:
  739. data = pd.read_csv(evolve_csv)
  740. data = data.rename(columns=lambda x: x.strip()) # strip keys
  741. i = np.argmax(fitness(data.values[:, :4])) #
  742. generations = len(data)
  743. f.write('# YOLOv5 Hyperparameter Evolution Results\n' + f'# Best generation: {i}\n' +
  744. f'# Last generation: {generations - 1}\n' + '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) +
  745. '\n' + '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
  746. yaml.safe_dump(data.loc[i][7:].to_dict(), f, sort_keys=False)
  747. # Print to screen
  748. LOGGER.info(prefix + f'{generations} generations finished, current result:\n' + prefix +
  749. ', '.join(f'{x.strip():>20s}' for x in keys) + '\n' + prefix + ', '.join(f'{x:20.5g}'
  750. for x in vals) + '\n\n')
  751. if bucket:
  752. os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}') # upload
  753. def apply_classifier(x, model, img, im0):
  754. # Apply a second stage classifier to YOLO outputs
  755. # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
  756. im0 = [im0] if isinstance(im0, np.ndarray) else im0
  757. for i, d in enumerate(x): # per image
  758. if d is not None and len(d):
  759. d = d.clone()
  760. # Reshape and pad cutouts
  761. b = xyxy2xywh(d[:, :4]) # boxes
  762. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
  763. b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
  764. d[:, :4] = xywh2xyxy(b).long()
  765. # Rescale boxes from img_size to im0 size
  766. scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
  767. # Classes
  768. pred_cls1 = d[:, 5].long()
  769. ims = []
  770. for j, a in enumerate(d): # per item
  771. cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
  772. im = cv2.resize(cutout, (224, 224)) # BGR
  773. # cv2.imwrite('example%i.jpg' % j, cutout)
  774. im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  775. im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
  776. im /= 255 # 0 - 255 to 0.0 - 1.0
  777. ims.append(im)
  778. pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
  779. x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
  780. return x
  781. def increment_path(path, exist_ok=False, sep='', mkdir=False):# 生成增量文件夹路径
  782. # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
  783. path = Path(path) # os-agnostic
  784. if path.exists() and not exist_ok:
  785. path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
  786. dirs = glob.glob(f"{path}{sep}*") # similar paths
  787. matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
  788. i = [int(m.groups()[0]) for m in matches if m] # indices
  789. n = max(i) + 1 if i else 2 # increment number
  790. path = Path(f"{path}{sep}{n}{suffix}") # increment path
  791. if mkdir:
  792. path.mkdir(parents=True, exist_ok=True) # make directory
  793. return path
# OpenCV Chinese-friendly functions ------------------------------------------------------------------------------------
# Keep a reference to the original cv2.imshow before it is rebound to imshow() further down. imshow() calls
# through this alias, so it must be captured first; otherwise imshow() would end up calling itself.
imshow_ = cv2.imshow  # copy to avoid recursion errors
  796. def imread(path, flags=cv2.IMREAD_COLOR):
  797. return cv2.imdecode(np.fromfile(path, np.uint8), flags)
  798. def imwrite(path, im):
  799. try:
  800. cv2.imencode(Path(path).suffix, im)[1].tofile(path)
  801. return True
  802. except Exception:
  803. return False
  804. def imshow(path, im):
  805. imshow_(path.encode('unicode_escape').decode(), im)
  806. cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow # redefine
  807. # Variables ------------------------------------------------------------------------------------------------------------
  808. NCOLS = 0 if is_docker() else shutil.get_terminal_size().columns # terminal window size for tqdm