common.py

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Common modules
  4. """
  5. import json
  6. import math
  7. import platform
  8. import warnings
  9. from collections import OrderedDict, namedtuple
  10. from copy import copy
  11. from pathlib import Path
  12. import cv2
  13. import numpy as np
  14. import pandas as pd
  15. import requests
  16. import torch
  17. import torch.nn as nn
  18. import yaml
  19. from PIL import Image
  20. from torch.cuda import amp
  21. from utils.datasets import exif_transpose, letterbox
  22. from utils.general import (LOGGER, check_requirements, check_suffix, check_version, colorstr, increment_path,
  23. make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
  24. from utils.plots import Annotator, colors, save_one_box
  25. from utils.torch_utils import copy_attr, time_sync
# Automatically compute 'same' padding for convolution or pooling
def autopad(k, p=None): # kernel, padding
"""
Used by Conv and Classify: derives the padding from the kernel size k when p is not given.
YOLOv5 uses only two kinds of convolution:
1. downsampling conv: 3x3, s=2, p=k//2=1
2. size-preserving conv: 1x1, s=1, p=k//2=0
:param k: kernel size of the convolution
:param p: padding; computed automatically when None
:return: padding value
"""
# Pad to 'same'
if p is None:
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
return p
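# --- Illustrative sketch (not part of the original file): quick check of autopad behaviour ---
def _demo_autopad():
    # odd kernels with stride 1 keep the feature-map size when padded by k // 2
    assert autopad(3) == 1 and autopad(1) == 0  # computed padding
    assert autopad(3, p=0) == 0                 # an explicit p is returned unchanged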
  44. class Conv(nn.Module):
# Standard convolution: Conv2d + BatchNorm2d + SiLU
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
"""
Basic building block used by Focus, Bottleneck, BottleneckCSP, C3, SPP, DWConv, TransformerBlock, etc.
:param c1: input channels
:param c2: output channels
:param k: kernel size
:param s: stride
:param p: padding; computed by autopad when None
:param g: groups; 1 gives a standard convolution, g > 1 a grouped convolution (depth-wise when g equals the channel count)
:param act: activation; True selects SiLU
"""
  64. super().__init__()
  65. self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
  66. self.bn = nn.BatchNorm2d(c2)
  67. self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
def forward(self, x): # the execution order of the network is defined by forward
  69. return self.act(self.bn(self.conv(x)))
  70. def forward_fuse(self, x):
  71. """
  72. 用于Model类的fuse函数
  73. 相较于forward函数去掉了BN层,加速推理,一般用于测试/验证阶段
  74. :param x:
  75. :type x:
  76. :return:
  77. :rtype:
  78. """
  79. return self.act(self.conv(x))
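# --- Illustrative sketch (not part of the original file): Conv shape check ---
def _demo_conv():
    m = Conv(16, 32, k=3, s=2)          # 3x3 conv, stride 2 -> spatial size halves
    x = torch.zeros(1, 16, 64, 64)
    assert m(x).shape == (1, 32, 32, 32)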
  80. class DWConv(Conv):
  81. # Depth-wise convolution class
  82. def __init__(self, c1, c2, k=1, s=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
  83. super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
  84. class TransformerLayer(nn.Module):
  85. # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
  86. def __init__(self, c, num_heads):
  87. super().__init__()
  88. self.q = nn.Linear(c, c, bias=False)
  89. self.k = nn.Linear(c, c, bias=False)
  90. self.v = nn.Linear(c, c, bias=False)
  91. self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
  92. self.fc1 = nn.Linear(c, c, bias=False)
  93. self.fc2 = nn.Linear(c, c, bias=False)
  94. def forward(self, x):
  95. x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
  96. x = self.fc2(self.fc1(x)) + x
  97. return x
  98. class TransformerBlock(nn.Module):
  99. # Vision Transformer https://arxiv.org/abs/2010.11929
  100. def __init__(self, c1, c2, num_heads, num_layers):
  101. super().__init__()
  102. self.conv = None
  103. if c1 != c2:
  104. self.conv = Conv(c1, c2)
  105. self.linear = nn.Linear(c2, c2) # learnable position embedding
  106. self.tr = nn.Sequential(*(TransformerLayer(c2, num_heads) for _ in range(num_layers)))
  107. self.c2 = c2
  108. def forward(self, x):
  109. if self.conv is not None:
  110. x = self.conv(x)
  111. b, _, w, h = x.shape
  112. p = x.flatten(2).permute(2, 0, 1)
  113. return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h)
  114. class Bottleneck(nn.Module):
  115. # Standard bottleneck
  116. """
  117. 由1*1conv、3*3conv、残差块组成
  118. """
  119. def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
  120. """
  121. 在BottleneckCSP、C3、parse_model中调用
  122. 组件分为两种情况,当shortcut为True时,bottleneck需要在经过1*1卷积和3*3卷积后在经过shortcut
  123. 当shortcut为False时,bottleneck只需要经过1*1卷积和3*3卷积即可
  124. :param c1:输入channel
  125. :type c1:
  126. :param c2:输出channel
  127. :type c2:
  128. :param shortcut:是否进行shortcut 默认为True
  129. :type shortcut:
  130. :param g: 卷积的groups数 等于1普通卷积 大于1深度可分离卷积
  131. :type g:
  132. :param e:膨胀系数
  133. :type e:
  134. """
  135. super().__init__()
  136. c_ = int(c2 * e) # hidden channels 中间层的channel数
  137. self.cv1 = Conv(c1, c_, 1, 1) # 第一层卷积输出的channel数为c_
  138. self.cv2 = Conv(c_, c2, 3, 1, g=g)# 第二层卷积输入的channel数为c_
  139. self.add = shortcut and c1 == c2
  140. def forward(self, x):
  141. return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
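# --- Illustrative sketch (not part of the original file): the residual add only fires when c1 == c2 ---
def _demo_bottleneck():
    x = torch.zeros(1, 64, 20, 20)
    assert Bottleneck(64, 64)(x).shape == (1, 64, 20, 20)    # with residual connection
    assert Bottleneck(64, 128)(x).shape == (1, 128, 20, 20)  # no residual (c1 != c2)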
  142. class BottleneckCSP(nn.Module):
  143. # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
  144. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
  145. """
  146. 该组件由bottleneck模块和CSP模块组成,此模块与C3模块等效。
  147. :param c1:输入channel
  148. :type c1:
  149. :param c2:输出channel
  150. :type c2:
  151. :param n:有n个bottleneck
  152. :type n:
  153. :param shortcut:bottleneck中是shortcut,默认为True
  154. :type shortcut:
  155. :param g: bottleneck中的groups 等于1,普通卷积 大于1,深度可分离卷积
  156. :type g:
  157. :param e:bottleneck中的膨胀系数
  158. :type e:
  159. """
  160. super().__init__()
  161. c_ = int(c2 * e) # hidden channels
  162. self.cv1 = Conv(c1, c_, 1, 1) #Conv+BN+SiLU
  163. self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
  164. self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
  165. self.cv4 = Conv(2 * c_, c2, 1, 1)
  166. self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
  167. self.act = nn.SiLU()
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n))) # stack n Bottleneck blocks
  169. def forward(self, x):
  170. y1 = self.cv3(self.m(self.cv1(x)))
  171. y2 = self.cv2(x)
  172. return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1))))
  173. class C3(nn.Module):
  174. # CSP Bottleneck with 3 convolutions
  175. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
  176. """
  177. 简化版的bottleneckCSP模块,除了bottleneck部分整个结构只有3个卷积,可以减少参数
  178. :param c1: 输入channel
  179. :type c1:
  180. :param c2: 输出channel
  181. :type c2:
  182. :param n: 有n个bottleneck
  183. :type n:
  184. :param shortcut: bottleneck中是否有shortcut,默认为True
  185. :type shortcut:
  186. :param g: bottleneck中的groups 等于1,普通卷积 大于1,深度可分离卷积
  187. :type g:
  188. :param e: bottleneck中的膨胀系数
  189. :type e:
  190. """
  191. super().__init__()
  192. c_ = int(c2 * e) # hidden channels
  193. self.cv1 = Conv(c1, c_, 1, 1)
  194. self.cv2 = Conv(c1, c_, 1, 1)
  195. self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
  196. self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
  197. # self.m = nn.Sequential(*(CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)))
  198. def forward(self, x):
  199. return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
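# --- Illustrative sketch (not part of the original file): C3 and BottleneckCSP are shape-for-shape equivalents ---
def _demo_c3():
    x = torch.zeros(1, 64, 20, 20)
    assert C3(64, 128, n=2)(x).shape == (1, 128, 20, 20)
    assert BottleneckCSP(64, 128, n=2)(x).shape == (1, 128, 20, 20)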
  200. class C3TR(C3):
  201. # C3 module with TransformerBlock()
  202. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  203. super().__init__(c1, c2, n, shortcut, g, e)
  204. c_ = int(c2 * e)
  205. self.m = TransformerBlock(c_, c_, 4, n)
  206. class C3SPP(C3):
  207. # C3 module with SPP()
  208. def __init__(self, c1, c2, k=(5, 9, 13), n=1, shortcut=True, g=1, e=0.5):
  209. super().__init__(c1, c2, n, shortcut, g, e)
  210. c_ = int(c2 * e)
  211. self.m = SPP(c_, c_, k)
  212. class C3Ghost(C3):
  213. # C3 module with GhostBottleneck()
  214. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  215. super().__init__(c1, c2, n, shortcut, g, e)
  216. c_ = int(c2 * e) # hidden channels
  217. self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n)))
  218. class SPP(nn.Module):
  219. # Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
  220. def __init__(self, c1, c2, k=(5, 9, 13)):
  221. """
  222. 空间金字塔池化
  223. :param c1: 输入channel
  224. :type c1:
  225. :param c2: 输出channel
  226. :type c2:
  227. :param k: 保存着三个maxpool卷积的kernel_size。默认是(5, 9, 13)
  228. :type k:
  229. """
  230. super().__init__()
  231. c_ = c1 // 2 # hidden channels
  232. self.cv1 = Conv(c1, c_, 1, 1) # 第一层卷积
  233. self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1) # 最后一层卷积
  234. self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k]) # 中间的maxpool层
  235. def forward(self, x):
  236. x = self.cv1(x)
  237. with warnings.catch_warnings():
  238. warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
  239. return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
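# --- Illustrative sketch (not part of the original file): SPP keeps the spatial size and maps c1 -> c2 ---
def _demo_spp():
    m = SPP(256, 256)                      # input is concatenated with three pooled copies, then fused by a 1x1 conv
    x = torch.zeros(1, 256, 20, 20)
    assert m(x).shape == (1, 256, 20, 20)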
  240. class SPPF(nn.Module):
  241. # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
  242. def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
  243. """
  244. SPP的升级改进版,将5*5,9*9,13*13三个amxpool并行输出的结果改成了3个5*5的maxpool串行输出的结果。结果是提升了计算速度
  245. :param c1: 输入channel
  246. :type c1:
  247. :param c2: 输出channel
  248. :type c2:
  249. :param k: 卷积的kernel_size
  250. :type k:
  251. """
  252. super().__init__()
  253. c_ = c1 // 2 # hidden channels
  254. self.cv1 = Conv(c1, c_, 1, 1)
  255. self.cv2 = Conv(c_ * 4, c2, 1, 1)
  256. self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
  257. def forward(self, x):
  258. x = self.cv1(x)
  259. with warnings.catch_warnings():
  260. warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
  261. y1 = self.m(x)
  262. y2 = self.m(y1)
  263. return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
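# --- Illustrative sketch (not part of the original file): chained 5x5 pools cover the same 5/9/13 receptive
# fields, so SPPF(k=5) is the faster drop-in for SPP(k=(5, 9, 13)); here we only sanity-check the shapes ---
def _demo_sppf():
    x = torch.zeros(1, 256, 20, 20)
    assert SPPF(256, 256, k=5)(x).shape == SPP(256, 256, k=(5, 9, 13))(x).shape == (1, 256, 20, 20)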
class Focus(nn.Module): # Focus: fold width and height information into channel space
"""
Focus exists to reduce computation and speed the network up; it does not improve accuracy.
It periodically samples pixels from the high-resolution input into a lower-resolution map: the four
neighbouring positions are stacked along the channel axis, concentrating w/h information into c and
enlarging each point's receptive field while losing none of the original information.
"""
  270. # Focus wh information into c-space
  271. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
  272. """
  273. :param c1: 输入的channel数
  274. :type c1:
  275. :param c2: Focus输出的channel数
  276. :type c2:
  277. :param k: 卷积的kernel_size
  278. :type k:
  279. :param s: 卷积的stride
  280. :type s:
  281. :param p: 卷积的padding
  282. :type p:
  283. :param g: 卷积的groups 等于1为普通卷积 大于1为深度可分离卷积
  284. :type g:
  285. :param act:激活函数类型 True:SiLU/Swish False:不使用激活函数
  286. :type act:
  287. """
  288. super().__init__()
  289. self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
  290. # self.contract = Contract(gain=2)
  291. def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
  292. return self.conv(torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]), 1))
  293. # return self.conv(self.contract(x))
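# --- Illustrative sketch (not part of the original file): the slicing turns (b, c, h, w) into (b, 4c, h/2, w/2) before the conv ---
def _demo_focus():
    m = Focus(3, 32, k=3)
    x = torch.zeros(1, 3, 64, 64)
    assert m(x).shape == (1, 32, 32, 32)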
  294. class GhostConv(nn.Module):
  295. # Ghost Convolution https://github.com/huawei-noah/ghostnet
  296. def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups
  297. super().__init__()
  298. c_ = c2 // 2 # hidden channels
  299. self.cv1 = Conv(c1, c_, k, s, None, g, act)
  300. self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)
  301. def forward(self, x):
  302. y = self.cv1(x)
  303. return torch.cat((y, self.cv2(y)), 1)
  304. class GhostBottleneck(nn.Module):
  305. # Ghost Bottleneck https://github.com/huawei-noah/ghostnet
  306. def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride
  307. super().__init__()
  308. c_ = c2 // 2
  309. self.conv = nn.Sequential(
  310. GhostConv(c1, c_, 1, 1), # pw
  311. DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw
  312. GhostConv(c_, c2, 1, 1, act=False)) # pw-linear
  313. self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1,
  314. act=False)) if s == 2 else nn.Identity()
  315. def forward(self, x):
  316. return self.conv(x) + self.shortcut(x)
  317. class Contract(nn.Module):
  318. # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
  319. def __init__(self, gain=2):
  320. """
  321. Focus模块的辅助函数,目的是改变输入特征的shape w和h维度的数据减半后将channel通道数提升4倍
  322. :param gain:
  323. :type gain:
  324. """
  325. super().__init__()
  326. self.gain = gain
  327. def forward(self, x):
b, c, h, w = x.size() # assert h % s == 0 and w % s == 0, 'Indivisible gain'
  329. s = self.gain
  330. x = x.view(b, c, h // s, s, w // s, s) # x(1,64,40,2,40,2)
  331. x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40)
  332. return x.view(b, c * s * s, h // s, w // s) # x(1,256,40,40)
  333. class Expand(nn.Module):
  334. # Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
  335. def __init__(self, gain=2):
  336. """
  337. Contract函数的还原函数,目的是将channel维度(缩小4倍)的数据扩展到W和H维度(扩大两倍)
  338. :param gain:
  339. :type gain:
  340. """
  341. super().__init__()
  342. self.gain = gain
  343. def forward(self, x):
b, c, h, w = x.size() # assert c % s ** 2 == 0, 'Indivisible gain'
  345. s = self.gain
  346. x = x.view(b, s, s, c // s ** 2, h, w) # x(1,2,2,16,80,80)
  347. x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2)
  348. return x.view(b, c // s ** 2, h * s, w * s) # x(1,16,160,160)
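# --- Illustrative sketch (not part of the original file): Contract and Expand are pure reshuffles, so one undoes the other ---
def _demo_contract_expand():
    x = torch.arange(1 * 4 * 8 * 8, dtype=torch.float32).view(1, 4, 8, 8)
    y = Contract(gain=2)(x)   # (1, 16, 4, 4)
    z = Expand(gain=2)(y)     # (1, 4, 8, 8)
    assert y.shape == (1, 16, 4, 4) and torch.equal(z, x)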
  349. class Concat(nn.Module):
  350. # Concatenate a list of tensors along dimension
  351. def __init__(self, dimension=1):
  352. """
  353. 按指定维度进行拼接
  354. :param dimension:维度
  355. :type dimension:
  356. """
  357. super().__init__()
  358. self.d = dimension
  359. def forward(self, x):
  360. return torch.cat(x, self.d)
class DetectMultiBackend(nn.Module): # multi-backend YOLOv5 model inference
  362. # YOLOv5 MultiBackend class for python inference on various backends
  363. def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False):
  364. # Usage:
  365. # PyTorch: weights = *.pt
  366. # TorchScript: *.torchscript
  367. # ONNX Runtime: *.onnx
  368. # ONNX OpenCV DNN: *.onnx with --dnn
  369. # OpenVINO: *.xml
  370. # CoreML: *.mlmodel
  371. # TensorRT: *.engine
  372. # TensorFlow SavedModel: *_saved_model
  373. # TensorFlow GraphDef: *.pb
  374. # TensorFlow Lite: *.tflite
  375. # TensorFlow Edge TPU: *_edgetpu.tflite
  376. from models.experimental import attempt_download, attempt_load # scoped to avoid circular import
  377. super().__init__()
w = str(weights[0] if isinstance(weights, list) else weights) # weights path as a string
pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = self.model_type(w) # get backend: one boolean flag per supported format
stride, names = 32, [f'class{i}' for i in range(1000)] # assign defaults: stride 32 and placeholder names for 1000 classes
w = attempt_download(w) # download if not local
fp16 &= (pt or jit or onnx or engine) and device.type != 'cpu' # FP16
if data: # data.yaml path (optional): read class names from it when provided
with open(data, errors='ignore') as f:
names = yaml.safe_load(f)['names'] # class names
if pt: # PyTorch
model = attempt_load(weights if isinstance(weights, list) else w, map_location=device) # load checkpoint
stride = max(int(model.stride.max()), 32) # model stride (maximum downsampling factor, at least 32)
names = model.module.names if hasattr(model, 'module') else model.names # get class names
model.half() if fp16 else model.float() # half or full precision
  391. self.model = model # explicitly assign for to(), cpu(), cuda(), half()
  392. elif jit: # TorchScript
  393. LOGGER.info(f'Loading {w} for TorchScript inference...')
  394. extra_files = {'config.txt': ''} # model metadata
  395. model = torch.jit.load(w, _extra_files=extra_files)
  396. model.half() if fp16 else model.float()
  397. if extra_files['config.txt']:
  398. d = json.loads(extra_files['config.txt']) # extra_files dict
  399. stride, names = int(d['stride']), d['names']
  400. elif dnn: # ONNX OpenCV DNN
  401. LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
  402. check_requirements(('opencv-python>=4.5.4',))
  403. net = cv2.dnn.readNetFromONNX(w)
  404. elif onnx: # ONNX Runtime
  405. LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
  406. cuda = torch.cuda.is_available()
  407. check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
  408. import onnxruntime
  409. providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
  410. session = onnxruntime.InferenceSession(w, providers=providers)
  411. meta = session.get_modelmeta().custom_metadata_map # metadata
  412. if 'stride' in meta:
  413. stride, names = int(meta['stride']), eval(meta['names'])
  414. elif xml: # OpenVINO
  415. LOGGER.info(f'Loading {w} for OpenVINO inference...')
  416. check_requirements(('openvino-dev',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/
  417. import openvino.inference_engine as ie
  418. core = ie.IECore()
  419. if not Path(w).is_file(): # if not *.xml
  420. w = next(Path(w).glob('*.xml')) # get *.xml file from *_openvino_model dir
  421. network = core.read_network(model=w, weights=Path(w).with_suffix('.bin')) # *.xml, *.bin paths
  422. executable_network = core.load_network(network, device_name='CPU', num_requests=1)
  423. elif engine: # TensorRT
  424. LOGGER.info(f'Loading {w} for TensorRT inference...')
  425. import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
  426. check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
  427. Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
  428. logger = trt.Logger(trt.Logger.INFO)
  429. with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
  430. model = runtime.deserialize_cuda_engine(f.read())
  431. bindings = OrderedDict()
  432. fp16 = False # default updated below
  433. for index in range(model.num_bindings):
  434. name = model.get_binding_name(index)
  435. dtype = trt.nptype(model.get_binding_dtype(index))
  436. shape = tuple(model.get_binding_shape(index))
  437. data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
  438. bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
  439. if model.binding_is_input(index) and dtype == np.float16:
  440. fp16 = True
  441. binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
  442. context = model.create_execution_context()
  443. batch_size = bindings['images'].shape[0]
  444. elif coreml: # CoreML
  445. LOGGER.info(f'Loading {w} for CoreML inference...')
  446. import coremltools as ct
  447. model = ct.models.MLModel(w)
  448. else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
  449. if saved_model: # SavedModel
  450. LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
  451. import tensorflow as tf
  452. keras = False # assume TF1 saved_model
  453. model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
  454. elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
  455. LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
  456. import tensorflow as tf
  457. def wrap_frozen_graph(gd, inputs, outputs):
  458. x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped
  459. ge = x.graph.as_graph_element
  460. return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
  461. gd = tf.Graph().as_graph_def() # graph_def
  462. with open(w, 'rb') as f:
  463. gd.ParseFromString(f.read())
  464. frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs="Identity:0")
  465. elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
  466. try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
  467. from tflite_runtime.interpreter import Interpreter, load_delegate
  468. except ImportError:
  469. import tensorflow as tf
  470. Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
  471. if edgetpu: # Edge TPU https://coral.ai/software/#edgetpu-runtime
  472. LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
  473. delegate = {
  474. 'Linux': 'libedgetpu.so.1',
  475. 'Darwin': 'libedgetpu.1.dylib',
  476. 'Windows': 'edgetpu.dll'}[platform.system()]
  477. interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
  478. else: # Lite
  479. LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
  480. interpreter = Interpreter(model_path=w) # load TFLite model
  481. interpreter.allocate_tensors() # allocate
  482. input_details = interpreter.get_input_details() # inputs
  483. output_details = interpreter.get_output_details() # outputs
  484. elif tfjs:
  485. raise Exception('ERROR: YOLOv5 TF.js inference is not supported')
  486. self.__dict__.update(locals()) # assign all variables to self
  487. def forward(self, im, augment=False, visualize=False, val=False):
# YOLOv5 MultiBackend inference: dispatch to whichever backend was loaded
  489. b, ch, h, w = im.shape # batch, channel, height, width
  490. if self.pt: # PyTorch
  491. y = self.model(im, augment=augment, visualize=visualize)[0]
  492. elif self.jit: # TorchScript
  493. y = self.model(im)[0]
  494. elif self.dnn: # ONNX OpenCV DNN
  495. im = im.cpu().numpy() # torch to numpy
  496. self.net.setInput(im)
  497. y = self.net.forward()
  498. elif self.onnx: # ONNX Runtime
  499. im = im.cpu().numpy() # torch to numpy
  500. y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0]
  501. elif self.xml: # OpenVINO
  502. im = im.cpu().numpy() # FP32
  503. desc = self.ie.TensorDesc(precision='FP32', dims=im.shape, layout='NCHW') # Tensor Description
  504. request = self.executable_network.requests[0] # inference request
  505. request.set_blob(blob_name='images', blob=self.ie.Blob(desc, im)) # name=next(iter(request.input_blobs))
  506. request.infer()
  507. y = request.output_blobs['output'].buffer # name=next(iter(request.output_blobs))
  508. elif self.engine: # TensorRT
  509. assert im.shape == self.bindings['images'].shape, (im.shape, self.bindings['images'].shape)
  510. self.binding_addrs['images'] = int(im.data_ptr())
  511. self.context.execute_v2(list(self.binding_addrs.values()))
  512. y = self.bindings['output'].data
  513. elif self.coreml: # CoreML
  514. im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
  515. im = Image.fromarray((im[0] * 255).astype('uint8'))
  516. # im = im.resize((192, 320), Image.ANTIALIAS)
  517. y = self.model.predict({'image': im}) # coordinates are xywh normalized
  518. if 'confidence' in y:
  519. box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(float) # builtin float, not the removed np.float alias
  521. y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
  522. else:
  523. k = 'var_' + str(sorted(int(k.replace('var_', '')) for k in y)[-1]) # output key
  524. y = y[k] # output
  525. else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
  526. im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
  527. if self.saved_model: # SavedModel
  528. y = (self.model(im, training=False) if self.keras else self.model(im)).numpy()
  529. elif self.pb: # GraphDef
  530. y = self.frozen_func(x=self.tf.constant(im)).numpy()
  531. else: # Lite or Edge TPU
  532. input, output = self.input_details[0], self.output_details[0]
  533. int8 = input['dtype'] == np.uint8 # is TFLite quantized uint8 model
  534. if int8:
  535. scale, zero_point = input['quantization']
  536. im = (im / scale + zero_point).astype(np.uint8) # de-scale
  537. self.interpreter.set_tensor(input['index'], im)
  538. self.interpreter.invoke()
  539. y = self.interpreter.get_tensor(output['index'])
  540. if int8:
  541. scale, zero_point = output['quantization']
  542. y = (y.astype(np.float32) - zero_point) * scale # re-scale
  543. y[..., :4] *= [w, h, w, h] # xywh normalized to pixels
  544. if isinstance(y, np.ndarray):
  545. y = torch.tensor(y, device=self.device)
  546. return (y, []) if val else y
def warmup(self, imgsz=(1, 3, 640, 640)): # warm-up inference
# Warmup model by running inference once
if any((self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb)): # warmup types
if self.device.type != 'cpu': # only warmup GPU models
im = torch.zeros(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # all-zeros dummy input
for _ in range(2 if self.jit else 1):
self.forward(im) # warmup
  554. @staticmethod
def model_type(p='path/to/model.pt'): # infer the model type from its path
# Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
from export import export_formats
suffixes = list(export_formats().Suffix) + ['.xml'] # export suffixes supported by YOLOv5
check_suffix(p, suffixes) # check the model suffix
p = Path(p).name # eliminate trailing separators / directory part
  561. pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, xml2 = (s in p for s in suffixes)
  562. xml |= xml2 # *_openvino_model or *.xml
  563. tflite &= not edgetpu # *.tflite
  564. return pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs
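# --- Illustrative usage sketch (not part of the original file); assumes a local 'yolov5s.pt' and the YOLOv5 repo on PYTHONPATH ---
def _demo_detect_multibackend():
    model = DetectMultiBackend('yolov5s.pt', device=torch.device('cpu'))  # backend chosen from the file suffix
    model.warmup(imgsz=(1, 3, 640, 640))   # no-op on CPU, warms up GPU backends
    im = torch.zeros(1, 3, 640, 640)       # BCHW float input in [0, 1]
    pred = model(im)                       # raw predictions, ready for non_max_suppression()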
class AutoShape(nn.Module): # automatically adapts the input shape and type
  566. # YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
  567. conf = 0.25 # NMS confidence threshold
  568. iou = 0.45 # NMS IoU threshold
  569. agnostic = False # NMS class-agnostic
  570. multi_label = False # NMS multiple labels per box
  571. classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
  572. max_det = 1000 # maximum number of detections per image
  573. amp = False # Automatic Mixed Precision (AMP) inference
  574. def __init__(self, model):
  575. super().__init__()
  576. LOGGER.info('Adding AutoShape... ')
  577. copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes
  578. self.dmb = isinstance(model, DetectMultiBackend) # DetectMultiBackend() instance
  579. self.pt = not self.dmb or model.pt # PyTorch model
  580. self.model = model.eval()
  581. def _apply(self, fn):
  582. # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
  583. self = super()._apply(fn)
  584. if self.pt:
  585. m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
  586. m.stride = fn(m.stride)
  587. m.grid = list(map(fn, m.grid))
  588. if isinstance(m.anchor_grid, list):
  589. m.anchor_grid = list(map(fn, m.anchor_grid))
  590. return self
  591. @torch.no_grad()
  592. def forward(self, imgs, size=640, augment=False, profile=False):
  593. # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
  594. # file: imgs = 'data/images/zidane.jpg' # str or PosixPath
  595. # URI: = 'https://ultralytics.com/images/zidane.jpg'
  596. # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
  597. # PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
  598. # numpy: = np.zeros((640,1280,3)) # HWC
  599. # torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
  600. # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
  601. t = [time_sync()]
  602. p = next(self.model.parameters()) if self.pt else torch.zeros(1) # for device and type
  603. autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
  604. if isinstance(imgs, torch.Tensor): # torch
  605. with amp.autocast(autocast):
  606. return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
  607. # Pre-process
  608. n, imgs = (len(imgs), list(imgs)) if isinstance(imgs, (list, tuple)) else (1, [imgs]) # number, list of images
  609. shape0, shape1, files = [], [], [] # image and inference shapes, filenames
  610. for i, im in enumerate(imgs):
  611. f = f'image{i}' # filename
  612. if isinstance(im, (str, Path)): # filename or uri
  613. im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
  614. im = np.asarray(exif_transpose(im))
  615. elif isinstance(im, Image.Image): # PIL Image
  616. im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
  617. files.append(Path(f).with_suffix('.jpg').name)
  618. if im.shape[0] < 5: # image in CHW
  619. im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
  620. im = im[..., :3] if im.ndim == 3 else np.tile(im[..., None], 3) # enforce 3ch input
  621. s = im.shape[:2] # HWC
  622. shape0.append(s) # image shape
  623. g = (size / max(s)) # gain
  624. shape1.append([y * g for y in s])
  625. imgs[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
  626. shape1 = [make_divisible(x, self.stride) if self.pt else size for x in np.array(shape1).max(0)] # inf shape
  627. x = [letterbox(im, shape1, auto=False)[0] for im in imgs] # pad
  628. x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
  629. x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
  630. t.append(time_sync())
  631. with amp.autocast(autocast):
  632. # Inference
  633. y = self.model(x, augment, profile) # forward
  634. t.append(time_sync())
  635. # Post-process
  636. y = non_max_suppression(y if self.dmb else y[0],
  637. self.conf,
  638. self.iou,
  639. self.classes,
  640. self.agnostic,
  641. self.multi_label,
  642. max_det=self.max_det) # NMS
  643. for i in range(n):
  644. scale_coords(shape1, y[i][:, :4], shape0[i])
  645. t.append(time_sync())
  646. return Detections(imgs, y, files, t, self.names, x.shape)
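# --- Illustrative usage sketch (not part of the original file); assumes network access for torch.hub, which
# returns the model already wrapped in AutoShape so paths, URLs, PIL, numpy and torch inputs all work ---
def _demo_autoshape():
    model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
    results = model('https://ultralytics.com/images/zidane.jpg', size=640)  # Detections object
    results.print()                  # speed and per-class counts
    df = results.pandas().xyxy[0]    # DataFrame: xmin, ymin, xmax, ymax, confidence, class, name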
  647. class Detections:
  648. # YOLOv5 detections class for inference results
  649. def __init__(self, imgs, pred, files, times=(0, 0, 0, 0), names=None, shape=None):
  650. super().__init__()
  651. d = pred[0].device # device
  652. gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in imgs] # normalizations
  653. self.imgs = imgs # list of images as numpy arrays
  654. self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
  655. self.names = names # class names
  656. self.files = files # image filenames
  657. self.times = times # profiling times
  658. self.xyxy = pred # xyxy pixels
  659. self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
  660. self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
  661. self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
  662. self.n = len(self.pred) # number of images (batch size)
  663. self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3)) # timestamps (ms)
  664. self.s = shape # inference BCHW shape
  665. def display(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
  666. crops = []
  667. for i, (im, pred) in enumerate(zip(self.imgs, self.pred)):
  668. s = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string
  669. if pred.shape[0]:
  670. for c in pred[:, -1].unique():
  671. n = (pred[:, -1] == c).sum() # detections per class
  672. s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, " # add to string
  673. if show or save or render or crop:
  674. annotator = Annotator(im, example=str(self.names))
  675. for *box, conf, cls in reversed(pred): # xyxy, confidence, class
  676. label = f'{self.names[int(cls)]} {conf:.2f}'
  677. if crop:
  678. file = save_dir / 'crops' / self.names[int(cls)] / self.files[i] if save else None
  679. crops.append({
  680. 'box': box,
  681. 'conf': conf,
  682. 'cls': cls,
  683. 'label': label,
  684. 'im': save_one_box(box, im, file=file, save=save)})
  685. else: # all others
  686. annotator.box_label(box, label if labels else '', color=colors(cls))
  687. im = annotator.im
  688. else:
  689. s += '(no detections)'
  690. im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np
  691. if pprint:
  692. LOGGER.info(s.rstrip(', '))
  693. if show:
  694. im.show(self.files[i]) # show
  695. if save:
  696. f = self.files[i]
  697. im.save(save_dir / f) # save
  698. if i == self.n - 1:
  699. LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
  700. if render:
  701. self.imgs[i] = np.asarray(im)
  702. if crop:
  703. if save:
  704. LOGGER.info(f'Saved results to {save_dir}\n')
  705. return crops
  706. def print(self):
  707. self.display(pprint=True) # print results
  708. LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' %
  709. self.t)
  710. def show(self, labels=True):
  711. self.display(show=True, labels=labels) # show results
  712. def save(self, labels=True, save_dir='runs/detect/exp'):
  713. save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True) # increment save_dir
  714. self.display(save=True, labels=labels, save_dir=save_dir) # save results
  715. def crop(self, save=True, save_dir='runs/detect/exp'):
  716. save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True) if save else None
  717. return self.display(crop=True, save=save, save_dir=save_dir) # crop results
  718. def render(self, labels=True):
  719. self.display(render=True, labels=labels) # render results
  720. return self.imgs
  721. def pandas(self):
  722. # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
  723. new = copy(self) # return copy
  724. ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns
  725. cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name' # xywh columns
  726. for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
  727. a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)] # update
  728. setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
  729. return new
  730. def tolist(self):
  731. # return a list of Detections objects, i.e. 'for result in results.tolist():'
  732. r = range(self.n) # iterable
  733. x = [Detections([self.imgs[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
  734. # for d in x:
  735. # for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
  736. # setattr(d, k, getattr(d, k)[0]) # pop out of list
  737. return x
  738. def __len__(self):
  739. return self.n
  740. class Classify(nn.Module):
  741. # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
  742. def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
  743. super().__init__()
  744. self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1)
  745. self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g) # to x(b,c2,1,1)
  746. self.flat = nn.Flatten()
  747. def forward(self, x):
  748. z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list
  749. return self.flat(self.conv(z)) # flatten to x(b,c2)
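# --- Illustrative sketch (not part of the original file): Classify pools any spatial size to 1x1, applies a 1x1 conv and flattens ---
def _demo_classify():
    m = Classify(1280, 1000)
    x = torch.zeros(2, 1280, 20, 20)
    assert m(x).shape == (2, 1000)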