# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
General utils
"""

import contextlib
import functools
import glob
import inspect
import logging
import math
import os
import platform
import random
import re
import shutil
import signal
import time
import urllib
import urllib.parse
from datetime import datetime
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from subprocess import check_output
from typing import Optional
from zipfile import ZipFile

import cv2
import numpy as np
import pandas as pd
import pkg_resources as pkg
import torch
import torchvision
import yaml

from utils.downloads import gsutil_getsize
from utils.metrics import box_iou, fitness

# Settings
FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLOv5 root directory
DATASETS_DIR = ROOT.parent / 'datasets'  # YOLOv5 datasets directory
# os.cpu_count() may return None when the count is undetermined, so fall back to 1
NUM_THREADS = min(8, max(1, (os.cpu_count() or 1) - 1))  # number of YOLOv5 multiprocessing threads
AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true'  # global auto-install mode
VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true'  # global verbose mode
FONT = 'Arial.ttf'  # https://ultralytics.com/assets/Arial.ttf

torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
pd.options.display.max_columns = 10
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader), i.e. single-threaded
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS)  # NumExpr max threads
os.environ['OMP_NUM_THREADS'] = str(NUM_THREADS)  # OpenMP max threads (PyTorch and SciPy)


def is_kaggle():
    # Is environment a Kaggle Notebook?
    # Plain boolean expression (not asserts) so the result is still correct under `python -O`
    return (os.environ.get('PWD') == '/kaggle/working'
            and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com')


def is_writeable(dir, test=False):
    # Return True if directory has write permissions, test opening a file with write permissions if test=True
    if not test:  # method 2
        return os.access(dir, os.W_OK)  # possible issues on Windows
    file = Path(dir) / 'tmp.txt'  # method 1
    try:
        with open(file, 'w'):  # open file with write permissions
            pass
        file.unlink()  # remove file
        return True
    except OSError:
        return False


def set_logging(name=None, verbose=VERBOSE):
    # Sets level and attaches a plain-message stream handler to logger 'name' (root logger if None)
    if is_kaggle():
        for h in list(logging.root.handlers):  # iterate over a copy, removeHandler() mutates the list
            logging.root.removeHandler(h)  # remove all handlers associated with the root logger object
    rank = int(os.getenv('RANK', -1))  # rank in world for Multi-GPU trainings
    level = logging.INFO if (verbose and rank in (-1, 0)) else logging.WARNING
    log = logging.getLogger(name)
    log.setLevel(level)
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(message)s"))
    handler.setLevel(level)
    log.addHandler(handler)


set_logging()  # run before defining LOGGER
LOGGER = logging.getLogger("yolov5")  # define globally (used in train.py, val.py, detect.py, etc.)


def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
    # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
    env = os.getenv(env_var)
    if env:
        path = Path(env)  # use environment variable
    else:
        cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'}  # 3 OS dirs
        path = Path.home() / cfg.get(platform.system(), '')  # OS-specific config dir
        path = (path if is_writeable(path) else Path('/tmp')) / dir  # GCP and AWS lambda fix, only /tmp is writeable
    path.mkdir(exist_ok=True)  # make if required
    return path


CONFIG_DIR = user_config_dir()  # Ultralytics settings dir


class Profile(contextlib.ContextDecorator):
    # Usage: @Profile() decorator or 'with Profile():' context manager. Prints elapsed wall-clock time on exit.
    def __enter__(self):
        self.start = time.time()

    def __exit__(self, type, value, traceback):
        print(f'Profile results: {time.time() - self.start:.5f}s')


class Timeout(contextlib.ContextDecorator):
    # Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
    # Uses SIGALRM, so it is a no-op on Windows
    def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
        self.seconds = int(seconds)
        self.timeout_message = timeout_msg
        self.suppress = bool(suppress_timeout_errors)

    def _timeout_handler(self, signum, frame):
        raise TimeoutError(self.timeout_message)

    def __enter__(self):
        if platform.system() != 'Windows':  # not supported on Windows
            signal.signal(signal.SIGALRM, self._timeout_handler)  # Set handler for SIGALRM
            signal.alarm(self.seconds)  # start countdown for SIGALRM to be raised

    def __exit__(self, exc_type, exc_val, exc_tb):
        if platform.system() != 'Windows':
            signal.alarm(0)  # Cancel SIGALRM if it's scheduled
            if self.suppress and exc_type is TimeoutError:  # Suppress TimeoutError
                return True


class WorkingDirectory(contextlib.ContextDecorator):
    # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
    def __init__(self, new_dir):
        self.dir = new_dir  # new dir
        self.cwd = Path.cwd().resolve()  # current dir (captured at construction time)

    def __enter__(self):
        os.chdir(self.dir)

    def __exit__(self, exc_type, exc_val, exc_tb):
        os.chdir(self.cwd)


def try_except(func):
    # try-except function. Usage: @try_except decorator
    # Any Exception raised by func is printed and swallowed (the wrapper then returns None)
    @functools.wraps(func)
    def handler(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            print(e)

    return handler


def methods(instance):
    # Get class/instance methods (callable attributes whose names do not start with a double underscore)
    return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]


def print_args(args: Optional[dict] = None, show_file=True, show_fcn=False):
    # Print function arguments (optional args dict). With args=None the caller's own arguments are logged.
    x = inspect.currentframe().f_back  # previous frame
    file, _, fcn, _, _ = inspect.getframeinfo(x)
    if args is None:  # get args automatically
        args, _, _, frm = inspect.getargvalues(x)
        args = {k: v for k, v in frm.items() if k in args}
    s = (f'{Path(file).stem}: ' if show_file else '') + (f'{fcn}: ' if show_fcn else '')
    LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))


def init_seeds(seed=0):
    # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
    # cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
    import torch.backends.cudnn as cudnn
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)


def intersect_dicts(da, db, exclude=()):
    # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
    return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}


def get_latest_run(search_dir='.'):
    # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
    return max(last_list, key=os.path.getctime) if last_list else ''


def is_docker():
    # Is environment a Docker container?
    return Path('/workspace').exists()  # or Path('/.dockerenv').exists()


def is_colab():
    # Is environment a Google Colab instance?
    try:
        import google.colab
        return True
    except ImportError:
        return False


def is_pip():
    # Is file in a pip package?
    return 'site-packages' in Path(__file__).resolve().parts


def is_ascii(s=''):
    # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
    s = str(s)  # convert list, tuple, None, etc. to str
    return len(s.encode().decode('ascii', 'ignore')) == len(s)


def is_chinese(s='人工智能'):
    # Is string composed of any Chinese characters?
    return bool(re.search('[\u4e00-\u9fff]', str(s)))


def emojis(str=''):
    # Return platform-dependent emoji-safe version of string (non-ASCII characters dropped on Windows)
    return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str


def file_age(path=__file__):
    # Return days since last file update
    dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime))  # delta
    return dt.days  # + dt.seconds / 86400  # fractional days


def file_update_date(path=__file__):
    # Return human-readable file modification date, i.e. '2021-3-26'
    t = datetime.fromtimestamp(Path(path).stat().st_mtime)
    return f'{t.year}-{t.month}-{t.day}'


def file_size(path):
    # Return file/dir size (MB), 0.0 if path is neither a file nor a directory
    mb = 1 << 20  # bytes to MiB (1024 ** 2)
    path = Path(path)
    if path.is_file():
        return path.stat().st_size / mb
    elif path.is_dir():
        return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
    else:
        return 0.0


def check_online():
    # Check internet connectivity
    import socket
    try:
        with socket.create_connection(("1.1.1.1", 443), 5):  # check host accessibility, close the probe socket
            return True
    except OSError:
        return False


def git_describe(path=ROOT):  # path must be a directory
    # Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
    # Returns '' if path is not a git repository or the git command fails
    if not (Path(path) / '.git').is_dir():
        return ''
    try:
        return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
    except Exception:
        return ''


@try_except
@WorkingDirectory(ROOT)
def check_git_status():
    # Recommend 'git pull' if code is out of date
    # The asserts below are the skip mechanism: @try_except prints their message and returns
    msg = ', for updates see https://github.com/ultralytics/yolov5'
    s = colorstr('github: ')  # string
    assert Path('.git').exists(), s + 'skipping check (not a git repository)' + msg
    assert not is_docker(), s + 'skipping check (Docker image)' + msg
    assert check_online(), s + 'skipping check (offline)' + msg

    cmd = 'git fetch && git config --get remote.origin.url'
    url = check_output(cmd, shell=True, timeout=5).decode().strip()  # git fetch
    if url.endswith('.git'):  # strip the suffix only (str.rstrip('.git') would strip a character set)
        url = url[:-len('.git')]
    branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip()  # checked out
    n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True))  # commits behind
    if n > 0:
        s += f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `git pull` or `git clone {url}` to update."
    else:
        s += f'up to date with {url} ✅'
    LOGGER.info(emojis(s))  # emoji-safe


def check_python(minimum='3.7.0'):
    # Check current python version vs. required python version
    check_version(platform.python_version(), minimum, name='Python ', hard=True)


def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
    # Check version vs. required version. Returns bool; raises AssertionError if hard=True and check fails.
    current, minimum = (pkg.parse_version(x) for x in (current, minimum))
    result = (current == minimum) if pinned else (current >= minimum)  # bool
    s = f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'  # string
    if hard:
        assert result, s  # assert min requirements met
    if verbose and not result:
        LOGGER.warning(s)
    return result


@try_except
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()):
    # Check installed dependencies meet requirements (pass *.txt file or list of packages)
    prefix = colorstr('red', 'bold', 'requirements:')
    check_python()  # check python version
    file = None  # requirements file, stays None when a list/tuple of packages is passed
    if isinstance(requirements, (str, Path)):  # requirements.txt file
        file = Path(requirements)
        assert file.exists(), f"{prefix} {file.resolve()} not found, check failed."
        with file.open() as f:
            requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
    else:  # list or tuple of packages
        requirements = [x for x in requirements if x not in exclude]

    n = 0  # number of packages updates
    for i, r in enumerate(requirements):
        try:
            pkg.require(r)
        except Exception:  # DistributionNotFound or VersionConflict if requirements not met
            s = f"{prefix} {r} not found and is required by YOLOv5"
            if install and AUTOINSTALL:  # check environment variable
                LOGGER.info(f"{s}, attempting auto-update...")
                try:
                    assert check_online(), f"'pip install {r}' skipped (offline)"
                    LOGGER.info(check_output(f"pip install '{r}' {cmds[i] if cmds else ''}", shell=True).decode())
                    n += 1
                except Exception as e:
                    LOGGER.warning(f'{prefix} {e}')
            else:
                LOGGER.info(f'{s}. Please install and rerun your command.')

    if n:  # if packages updated
        source = file.resolve() if file is not None else requirements
        s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
            f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
        LOGGER.info(emojis(s))


def check_img_size(imgsz, s=32, floor=0):
    # Verify image size is a multiple of stride s in each dimension; if not, adjust it and warn
    if isinstance(imgsz, int):  # integer i.e. img_size=640
        new_size = max(make_divisible(imgsz, int(s)), floor)  # smallest multiple of s that is >= imgsz
    else:  # list i.e. img_size=[640, 480]
        imgsz = list(imgsz)  # convert to list if tuple
        new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
    if new_size != imgsz:
        LOGGER.warning(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
    return new_size


def check_imshow():
    # Check if environment supports image displays
    try:
        assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
        assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
        cv2.imshow('test', np.zeros((1, 1, 3)))
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
        return True
    except Exception as e:
        LOGGER.warning(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
        return False


def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
    # Check file(s) for acceptable suffix. Files without any suffix are accepted.
    if file and suffix:
        if isinstance(suffix, str):
            suffix = [suffix]
        for f in file if isinstance(file, (list, tuple)) else [file]:
            s = Path(f).suffix.lower()  # file suffix
            if len(s):
                assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"


def check_yaml(file, suffix=('.yaml', '.yml')):
    # Search/download YAML file (if necessary) and return path, checking suffix
    return check_file(file, suffix)


def check_file(file, suffix=''):
    # Search/download file (if necessary) and return path
    check_suffix(file, suffix)  # optional
    file = str(file)  # convert to str()
    if Path(file).is_file() or file == '':  # exists
        return file
    elif file.startswith(('http:/', 'https:/')):  # download
        url = str(Path(file)).replace(':/', '://')  # Pathlib turns :// -> :/
        file = Path(urllib.parse.unquote(file).split('?')[0]).name  # '%2F' to '/', split https://url.com/file.txt?auth
        if Path(file).is_file():
            LOGGER.info(f'Found {url} locally at {file}')  # file already exists
        else:
            LOGGER.info(f'Downloading {url} to {file}...')
            torch.hub.download_url_to_file(url, file)
            assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}'  # check
        return file
    else:  # search
        files = []
        for d in 'data', 'models', 'utils':  # search directories
            files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True))  # find file
        assert len(files), f'File not found: {file}'  # assert file was found
        assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}"  # assert unique
        return files[0]  # return file


def check_font(font=FONT, progress=False):
    # Download font to CONFIG_DIR if necessary
    font = Path(font)
    file = CONFIG_DIR / font.name
    if not font.exists() and not file.exists():
        url = "https://ultralytics.com/assets/" + font.name
        LOGGER.info(f'Downloading {url} to {file}...')
        torch.hub.download_url_to_file(url, str(file), progress=progress)


def check_dataset(data, autodownload=True):
    # Download and/or unzip dataset if not found locally
    # Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip

    # Download (optional)
    extract_dir = ''
    if isinstance(data, (str, Path)) and str(data).endswith('.zip'):  # i.e. gs://bucket/dir/coco128.zip
        download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False, threads=1)
        data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml'))
        extract_dir, autodownload = data.parent, False

    # Read yaml (optional)
    if isinstance(data, (str, Path)):
        with open(data, errors='ignore') as f:
            data = yaml.safe_load(f)  # dictionary

    # Resolve paths
    path = Path(extract_dir or data.get('path') or '')  # optional 'path' default to '.'
    if not path.is_absolute():
        path = (ROOT / path).resolve()
    for k in 'train', 'val', 'test':
        if data.get(k):  # prepend path
            data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]

    # Parse yaml
    assert 'nc' in data, "Dataset 'nc' key missing."
    if 'names' not in data:
        data['names'] = [f'class{i}' for i in range(data['nc'])]  # assign class names if missing
    val, s = data.get('val'), data.get('download')
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            LOGGER.info(emojis('\nDataset not found ⚠️, missing paths %s' % [str(x) for x in val if not x.exists()]))
            if s and autodownload:  # download script
                t = time.time()
                root = path.parent if 'path' in data else '..'  # unzip directory i.e. '../'
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
                    LOGGER.info(f'Downloading {s} to {f}...')
                    torch.hub.download_url_to_file(s, f)
                    Path(root).mkdir(parents=True, exist_ok=True)  # create root
                    with ZipFile(f) as zf:  # close the archive handle before deleting the file
                        zf.extractall(path=root)  # unzip
                    Path(f).unlink()  # remove zip
                    r = None  # success
                elif s.startswith('bash '):  # bash script
                    LOGGER.info(f'Running {s} ...')
                    r = os.system(s)
                else:  # python script
                    # NOTE(security): executes the 'download' field of the dataset yaml as Python code;
                    # only use dataset yaml files from trusted sources
                    r = exec(s, {'yaml': data})  # return None
                dt = f'({round(time.time() - t, 1)}s)'
                s = f"success ✅ {dt}, saved to {colorstr('bold', root)}" if r in (0, None) else f"failure {dt} ❌"
                LOGGER.info(emojis(f"Dataset download {s}"))
            else:
                raise Exception(emojis('Dataset not found ❌'))

    return data  # dictionary


def url2file(url):
    # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
    url = str(Path(url)).replace(':/', '://')  # Pathlib turns :// -> :/
    file = Path(urllib.parse.unquote(url)).name.split('?')[0]  # '%2F' to '/', split https://url.com/file.txt?auth
    return file


def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
    # Multi-threaded file download and unzip function, used in data.yaml for autodownload
    def download_one(url, dir):
        # Download 1 file
        success = True
        f = dir / Path(url).name  # filename
        if Path(url).is_file():  # exists in current path
            Path(url).rename(f)  # move to dir
        elif not f.exists():
            LOGGER.info(f'Downloading {url} to {f}...')
            for i in range(retry + 1):
                if curl:
                    s = 'sS' if threads > 1 else ''  # silent
                    r = os.system(f"curl -{s}L '{url}' -o '{f}' --retry 9 -C -")  # curl download
                    success = r == 0
                else:
                    torch.hub.download_url_to_file(url, f, progress=threads == 1)  # torch download
                    success = f.is_file()
                if success:
                    break
                elif i < retry:
                    LOGGER.warning(f'Download failure, retrying {i + 1}/{retry} {url}...')
                else:
                    LOGGER.warning(f'Failed to download {url}...')

        if unzip and success and f.suffix in ('.zip', '.gz'):
            LOGGER.info(f'Unzipping {f}...')
            if f.suffix == '.zip':
                with ZipFile(f) as zf:  # close the archive handle before a possible delete
                    zf.extractall(path=dir)  # unzip
            elif f.suffix == '.gz':
                os.system(f'tar xfz {f} --directory {f.parent}')  # unzip
            if delete:
                f.unlink()  # remove zip

    dir = Path(dir)
    dir.mkdir(parents=True, exist_ok=True)  # make directory
    if threads > 1:
        pool = ThreadPool(threads)
        pool.imap(lambda x: download_one(*x), zip(url, repeat(dir)))  # multi-threaded
        pool.close()
        pool.join()
    else:
        for u in [url] if isinstance(url, (str, Path)) else url:
            download_one(u, dir)


def make_divisible(x, divisor):
    # Returns x rounded up to the nearest multiple of divisor
    if isinstance(divisor, torch.Tensor):
        divisor = int(divisor.max())  # to int
    return math.ceil(x / divisor) * divisor


def clean_str(s):
    # Cleans a string by replacing special characters with underscore _
    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)


def one_cycle(y1=0.0, y2=1.0, steps=100):
    # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1


def colorstr(*input):
    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e.  colorstr('blue', 'hello world')
    # A single argument is rendered blue + bold
    *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
    colors = {
        'black': '\033[30m',  # basic colors
        'red': '\033[31m',
        'green': '\033[32m',
        'yellow': '\033[33m',
        'blue': '\033[34m',
        'magenta': '\033[35m',
        'cyan': '\033[36m',
        'white': '\033[37m',
        'bright_black': '\033[90m',  # bright colors
        'bright_red': '\033[91m',
        'bright_green': '\033[92m',
        'bright_yellow': '\033[93m',
        'bright_blue': '\033[94m',
        'bright_magenta': '\033[95m',
        'bright_cyan': '\033[96m',
        'bright_white': '\033[97m',
        'end': '\033[0m',  # misc
        'bold': '\033[1m',
        'underline': '\033[4m'}
    return ''.join(colors[x] for x in args) + f'{string}' + colors['end']


def labels_to_class_weights(labels, nc=80):
    # Get class weights (inverse frequency) from training labels
    if labels[0] is None:  # no labels loaded
        return torch.Tensor()

    labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
    classes = labels[:, 0].astype(int)  # labels = [class xywh]; builtin int (np.int was removed in NumPy 1.24)
    weights = np.bincount(classes, minlength=nc)  # occurrences per class

    # Prepend gridpoint count (for uCE training)
    # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
    # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start

    weights[weights == 0] = 1  # replace empty bins with 1
    weights = 1 / weights  # inverse of number of targets per class
    weights /= weights.sum()  # normalize
    return torch.from_numpy(weights)


def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    # Produces image weights based on class_weights and image contents
    # class_counts holds, per image, the number of labels of each class
    class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])
    image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
    # index = random.choices(range(n), weights=image_weights, k=1)  # weight image sample
    return image_weights


def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
    # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
    # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
    # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
    # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
    # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
    x = [
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
        35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
        64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
    return x


def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y


def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y


def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
    # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw  # top left x
    y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh  # top left y
    y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw  # bottom right x
    y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh  # bottom right y
    return y


def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
    if clip:
        clip_coords(x, (h - eps, w - eps))  # warning: inplace clip
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w  # x center
    y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h  # y center
    y[:, 2] = (x[:, 2] - x[:, 0]) / w  # width
    y[:, 3] = (x[:, 3] - x[:, 1]) / h  # height
    return y


def xyn2xy(x, w=640, h=640, padw=0, padh=0):
    # Convert normalized segments into pixel segments, shape (n,2)
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = w * x[:, 0] + padw  # top left x
    y[:, 1] = h * x[:, 1] + padh  # top left y
    return y


def segment2box(segment, width=640, height=640):
    # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
    x, y = segment.T  # segment xy
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
    x, y, = x[inside], y[inside]
    # Test for remaining points by length: any(x) would be False for valid points that all have x == 0
    return np.array([x.min(), y.min(), x.max(), y.max()]) if len(x) else np.zeros((1, 4))  # xyxy


def segments2boxes(segments):
    # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
    boxes = []
    for s in segments:
        x, y = s.T  # segment xy
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # cls, xyxy
    return xyxy2xywh(np.array(boxes))  # cls, xywh


def resample_segments(segments, n=1000):
    # Up-sample an (n,2) segment by linear interpolation; modifies and returns the input list
    for i, s in enumerate(segments):
        x = np.linspace(0, len(s) - 1, n)
        xp = np.arange(len(s))
        segments[i] = np.concatenate([np.interp(x, xp, s[:, j]) for j in range(2)]).reshape(2, -1).T  # segment xy
    return segments


def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    # Rescale coords (xyxy) from img1_shape to img0_shape (in place, coords is also returned)
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain  = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    coords[:, [0, 2]] -= pad[0]  # x padding
    coords[:, [1, 3]] -= pad[1]  # y padding
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)
    return coords


def clip_coords(boxes, shape):
    # Clip bounding xyxy bounding boxes to image shape (height, width), in place
    if isinstance(boxes, torch.Tensor):  # faster individually
        boxes[:, 0].clamp_(0, shape[1])  # x1
        boxes[:, 1].clamp_(0, shape[0])  # y1
        boxes[:, 2].clamp_(0, shape[1])  # x2
        boxes[:, 3].clamp_(0, shape[0])  # y2
    else:  # np.array (faster grouped)
        boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1])  # x1, x2
        boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0])  # y1, y2


def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=(), max_det=300):
    """Non-Maximum Suppression (NMS) on inference results to reject overlapping bounding boxes

    :param prediction: network predictions, indexed as (batch, boxes, 5 + nc) with rows [xywh, obj_conf, cls_conf...]
    :param conf_thres: confidence threshold
    :param iou_thres: IoU threshold passed to torchvision.ops.nms
    :param classes: if not None, keep only detections of these class indices
    :param agnostic: if True, NMS also suppresses overlapping boxes of different classes
    :param multi_label: allow multiple class labels per box
    :param labels: apriori labels per image (autolabelling), rows [cls, xywh]
    :param max_det: maximum number of detections kept per image
    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """
    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates, e.g. prediction [1, 18900, 85] -> xc [1, 18900]

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    max_wh = 7680  # (pixels) maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 0.1 + 0.03 * bs  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    # One independent empty tensor per image ([tensor] * bs would alias a single tensor bs times)
    output = [torch.zeros((0, 6), device=prediction.device) for _ in range(bs)]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence: keep rows whose objectness exceeds conf_thres

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + 5), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)  # highest class score per row and its class index
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]  # filter again by conf_thres

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        # Instead of looping over classes, every box is shifted by class_index * max_wh so that boxes of
        # different classes can never overlap; a single NMS call then acts per class (unless agnostic).
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS, kept indices sorted by decreasing score
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING: NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded

    return output


def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
    # Strip optimizer from 'f' to finalize training, optionally save as 's'
    x = torch.load(f, map_location=torch.device('cpu'))
    if x.get('ema'):
        x['model'] = x['ema']  # replace model with ema
    for k in 'optimizer', 'best_fitness', 'wandb_id', 'ema', 'updates':  # keys
        x[k] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1E6  # filesize
    LOGGER.info(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")


def print_mutation(results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')):
    # Append one evolution result to evolve.csv, rewrite hyp_evolve.yaml with the best generation, and log it
    evolve_csv = save_dir / 'evolve.csv'
    evolve_yaml = save_dir / 'hyp_evolve.yaml'
    keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', 'val/box_loss',
            'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys())  # [results + hyps]
    keys = tuple(x.strip() for x in keys)
    vals = results + tuple(hyp.values())
    n = len(keys)

    # Download (optional)
    if bucket:
        url = f'gs://{bucket}/evolve.csv'
        if gsutil_getsize(url) > (evolve_csv.stat().st_size if evolve_csv.exists() else 0):
            os.system(f'gsutil cp {url} {save_dir}')  # download evolve.csv if larger than local

    # Log to evolve.csv
    s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n')  # add header
    with open(evolve_csv, 'a') as f:
        f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')

    # Save yaml
    with open(evolve_yaml, 'w') as f:
        data = pd.read_csv(evolve_csv)
        data = data.rename(columns=lambda x: x.strip())  # strip keys
        i = np.argmax(fitness(data.values[:, :4]))  # index of best generation
        generations = len(data)
        f.write('# YOLOv5 Hyperparameter Evolution Results\n' + f'# Best generation: {i}\n' +
                f'# Last generation: {generations - 1}\n' + '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) +
                '\n' + '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
        yaml.safe_dump(data.loc[i][7:].to_dict(), f, sort_keys=False)

    # Print to screen
    LOGGER.info(prefix + f'{generations} generations finished, current result:\n' + prefix +
                ', '.join(f'{x.strip():>20s}' for x in keys) + '\n' + prefix +
                ', '.join(f'{x:20.5g}' for x in vals) + '\n\n')

    if bucket:
        os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}')  # upload


def apply_classifier(x, model, img, im0):
    # Apply a second stage classifier to YOLO outputs, keeping detections whose class the classifier confirms
    # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Reshape and pad cutouts
            b = xyxy2xywh(d[:, :4])  # boxes
            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
            b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
            d[:, :4] = xywh2xyxy(b).long()

            # Rescale boxes from img_size to im0 size
            scale_coords(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes
            pred_cls1 = d[:, 5].long()
            ims = []
            for a in d:  # per item
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
                im = cv2.resize(cutout, (224, 224))  # BGR
                # cv2.imwrite('example%i.jpg' % j, cutout)

                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x224x224
                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
                im /= 255  # 0 - 255 to 0.0 - 1.0
                ims.append(im)

            pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
            x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections

    return x


def increment_path(path, exist_ok=False, sep='', mkdir=False):
    # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
    path = Path(path)  # os-agnostic
    if path.exists() and not exist_ok:
        path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
        dirs = glob.glob(f"{path}{sep}*")  # similar paths
        # Escape stem and separator so regex metacharacters in the name are matched literally
        pattern = re.compile(rf"{re.escape(path.stem)}{re.escape(sep)}(\d+)")
        matches = [pattern.search(d) for d in dirs]
        i = [int(m.groups()[0]) for m in matches if m]  # indices
        n = max(i) + 1 if i else 2  # increment number
        path = Path(f"{path}{sep}{n}{suffix}")  # increment path
    if mkdir:
        path.mkdir(parents=True, exist_ok=True)  # make directory
    return path


# OpenCV Chinese-friendly functions ------------------------------------------------------------------------------------
imshow_ = cv2.imshow  # copy to avoid recursion errors


def imread(path, flags=cv2.IMREAD_COLOR):
    # Read image via np.fromfile + cv2.imdecode (instead of passing the path to OpenCV)
    return cv2.imdecode(np.fromfile(path, np.uint8), flags)


def imwrite(path, im):
    # Write image via cv2.imencode + ndarray.tofile; returns True on success, False on any exception
    try:
        cv2.imencode(Path(path).suffix, im)[1].tofile(path)
        return True
    except Exception:
        return False


def imshow(path, im):
    # Show image with a unicode-escaped window name
    imshow_(path.encode('unicode_escape').decode(), im)


cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow  # redefine

# Variables ------------------------------------------------------------------------------------------------------------
NCOLS = 0 if is_docker() else shutil.get_terminal_size().columns  # terminal window size for tqdm