metrics.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360
  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Model validation metrics
  4. """
  5. import math
  6. import warnings
  7. from pathlib import Path
  8. import matplotlib.pyplot as plt
  9. import numpy as np
  10. import torch
  11. def fitness(x):
  12. # Model fitness as a weighted combination of metrics
  13. w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
  14. return (x[:, :4] * w).sum(1)
  15. def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=(), eps=1e-16):
  16. """ Compute the average precision, given the recall and precision curves.
  17. Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
  18. # Arguments
  19. tp: True positives (nparray, nx1 or nx10).
  20. conf: Objectness value from 0-1 (nparray).
  21. pred_cls: Predicted object classes (nparray).
  22. target_cls: True object classes (nparray).
  23. plot: Plot precision-recall curve at mAP@0.5
  24. save_dir: Plot save directory
  25. # Returns
  26. The average precision as computed in py-faster-rcnn.
  27. """
  28. # Sort by objectness
  29. i = np.argsort(-conf)# 对conf排序,加负号,conf大的元素在前面 i为索引值
  30. tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
  31. # Find unique classes
  32. unique_classes, nt = np.unique(target_cls, return_counts=True)
  33. nc = unique_classes.shape[0] # number of classes, number of detections #总类别数
  34. # Create Precision-Recall curve and compute AP for each class
  35. px, py = np.linspace(0, 1, 1000), [] # for plotting
  36. ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
  37. for ci, c in enumerate(unique_classes):
  38. i = pred_cls == c
  39. n_l = nt[ci] # number of labels 真实标签数
  40. n_p = i.sum() # number of predictions 预测出来的表签数
  41. if n_p == 0 or n_l == 0:
  42. continue
  43. else:
  44. # 混淆矩阵计算
  45. # Accumulate FPs and TPs
  46. fpc = (1 - tp[i]).cumsum(0)
  47. tpc = tp[i].cumsum(0)
  48. # Recall 召回
  49. recall = tpc / (n_l + eps) # recall curve
  50. r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases
  51. # Precision 精度
  52. precision = tpc / (tpc + fpc) # precision curve
  53. p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1) # p at pr_score
  54. # AP from recall-precision curve 召回-精度图表
  55. for j in range(tp.shape[1]):
  56. ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
  57. if plot and j == 0:
  58. py.append(np.interp(px, mrec, mpre)) # precision at mAP@0.5
  59. # Compute F1 (harmonic mean of precision and recall) 计算F1 score
  60. f1 = 2 * p * r / (p + r + eps)
  61. names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data
  62. names = {i: v for i, v in enumerate(names)} # to dict
  63. if plot: # 画图
  64. plot_pr_curve(px, py, ap, Path(save_dir) / 'PR_curve.png', names)
  65. plot_mc_curve(px, f1, Path(save_dir) / 'F1_curve.png', names, ylabel='F1')
  66. plot_mc_curve(px, p, Path(save_dir) / 'P_curve.png', names, ylabel='Precision')
  67. plot_mc_curve(px, r, Path(save_dir) / 'R_curve.png', names, ylabel='Recall')
  68. i = f1.mean(0).argmax() # max F1 index
  69. p, r, f1 = p[:, i], r[:, i], f1[:, i]
  70. tp = (r * nt).round() # true positives
  71. fp = (tp / (p + eps) - tp).round() # false positives
  72. return tp, fp, p, r, f1, ap, unique_classes.astype('int32')
  73. def compute_ap(recall, precision):
  74. """ Compute the average precision, given the recall and precision curves
  75. # Arguments
  76. recall: The recall curve (list)
  77. precision: The precision curve (list)
  78. # Returns
  79. Average precision, precision curve, recall curve
  80. """
  81. # Append sentinel values to beginning and end
  82. mrec = np.concatenate(([0.0], recall, [1.0])) #把recall开放的区间给补上,补成了闭合的区间
  83. mpre = np.concatenate(([1.0], precision, [0.0])) # mpre也是做了对应的补偿
  84. # Compute the precision envelope
  85. # 人为把pre-rec曲线变成单调递减。将曲线填顺滑
  86. mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))
  87. # Integrate area under curve
  88. method = 'interp' # methods: 'continuous', 'interp'
  89. if method == 'interp':
  90. x = np.linspace(0, 1, 101) # 101-point interp (COCO)
  91. ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate
  92. else: # 'continuous'
  93. i = np.where(mrec[1:] != mrec[:-1])[0] # points where x axis (recall) changes
  94. ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve
  95. return ap, mpre, mrec
  96. class ConfusionMatrix:
  97. # Updated version of https://github.com/kaanakan/object_detection_confusion_matrix
  98. def __init__(self, nc, conf=0.25, iou_thres=0.45):
  99. self.matrix = np.zeros((nc + 1, nc + 1))
  100. self.nc = nc # number of classes
  101. self.conf = conf
  102. self.iou_thres = iou_thres
  103. def process_batch(self, detections, labels):
  104. """
  105. Return intersection-over-union (Jaccard index) of boxes.
  106. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
  107. Arguments:
  108. detections (Array[N, 6]), x1, y1, x2, y2, conf, class
  109. labels (Array[M, 5]), class, x1, y1, x2, y2
  110. Returns:
  111. None, updates confusion matrix accordingly
  112. """
  113. detections = detections[detections[:, 4] > self.conf]
  114. gt_classes = labels[:, 0].int()
  115. detection_classes = detections[:, 5].int()
  116. iou = box_iou(labels[:, 1:], detections[:, :4])
  117. x = torch.where(iou > self.iou_thres)
  118. if x[0].shape[0]:
  119. matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
  120. if x[0].shape[0] > 1:
  121. matches = matches[matches[:, 2].argsort()[::-1]]
  122. matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
  123. matches = matches[matches[:, 2].argsort()[::-1]]
  124. matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
  125. else:
  126. matches = np.zeros((0, 3))
  127. n = matches.shape[0] > 0
  128. m0, m1, _ = matches.transpose().astype(np.int16)
  129. for i, gc in enumerate(gt_classes):
  130. j = m0 == i
  131. if n and sum(j) == 1:
  132. self.matrix[detection_classes[m1[j]], gc] += 1 # correct
  133. else:
  134. self.matrix[self.nc, gc] += 1 # background FP
  135. if n:
  136. for i, dc in enumerate(detection_classes):
  137. if not any(m1 == i):
  138. self.matrix[dc, self.nc] += 1 # background FN
  139. def matrix(self):
  140. return self.matrix
  141. def tp_fp(self):
  142. tp = self.matrix.diagonal() # true positives
  143. fp = self.matrix.sum(1) - tp # false positives
  144. # fn = self.matrix.sum(0) - tp # false negatives (missed detections)
  145. return tp[:-1], fp[:-1] # remove background class
  146. def plot(self, normalize=True, save_dir='', names=()):
  147. try:
  148. import seaborn as sn
  149. array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
  150. array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
  151. fig = plt.figure(figsize=(12, 9), tight_layout=True)
  152. nc, nn = self.nc, len(names) # number of classes, names
  153. sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
  154. labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
  155. with warnings.catch_warnings():
  156. warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
  157. sn.heatmap(array,
  158. annot=nc < 30,
  159. annot_kws={
  160. "size": 8},
  161. cmap='Blues',
  162. fmt='.2f',
  163. square=True,
  164. vmin=0.0,
  165. xticklabels=names + ['background FP'] if labels else "auto",
  166. yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
  167. fig.axes[0].set_xlabel('True')
  168. fig.axes[0].set_ylabel('Predicted')
  169. fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
  170. plt.close()
  171. except Exception as e:
  172. print(f'WARNING: ConfusionMatrix plot failure: {e}')
  173. def print(self):
  174. for i in range(self.nc + 1):
  175. print(' '.join(map(str, self.matrix[i])))
  176. # 计算两个框的特定iou(GIoU, DIoU, CIoU)
  177. def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
  178. # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)包围框
  179. # Get the coordinates of bounding boxes
  180. if xywh: # transform from xywh to xyxy 将yolo格式转换为voc格式
  181. (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, 1), box2.chunk(4, 1)
  182. w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
  183. b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
  184. b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
  185. else: # x1, y1, x2, y2 = box1
  186. b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, 1)
  187. b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, 1)
  188. w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
  189. w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
  190. '''
  191. left_line = max(box1[1], box2[1])
  192. right_line = min(box1[3], box2[3])
  193. top_line = max(box1[0], box2[0])
  194. bottom_line = min(box1[2], box2[2])
  195. intersect = (right_line - left_line) * (bottom_line - top_line)
  196. '''
  197. # Intersection area 面积交集
  198. inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
  199. (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
  200. # Union Area 面积并集
  201. union = w1 * h1 + w2 * h2 - inter + eps
  202. # IoU
  203. iou = inter / union
  204. if CIoU or DIoU or GIoU:
  205. cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width 最小包围框的宽
  206. ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height 最小包围框的高
  207. if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
  208. c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared 包围框对角线
  209. rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2 两个框中心点距离的平方
  210. if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
  211. v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
  212. with torch.no_grad():
  213. alpha = v / (v - iou + (1 + eps))
  214. return iou - (rho2 / c2 + v * alpha) # CIoU
  215. return iou - rho2 / c2 # DIoU
  216. c_area = cw * ch + eps # convex area
  217. return iou - (c_area - union) / c_area # GIoU https://arxiv.org/pdf/1902.09630.pdf
  218. return iou # IoU
  219. def box_area(box):
  220. # box = xyxy(4,n)
  221. return (box[2] - box[0]) * (box[3] - box[1])
  222. # 计算两个框的IOU 普通
  223. def box_iou(box1, box2):
  224. # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
  225. """
  226. Return intersection-over-union (Jaccard index) of boxes.
  227. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
  228. Arguments:
  229. box1 (Tensor[N, 4])
  230. box2 (Tensor[M, 4])
  231. Returns:
  232. iou (Tensor[N, M]): the NxM matrix containing the pairwise
  233. IoU values for every element in boxes1 and boxes2
  234. """
  235. # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
  236. (a1, a2), (b1, b2) = box1[:, None].chunk(2, 2), box2.chunk(2, 1)
  237. inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)
  238. # IoU = inter / (area1 + area2 - inter)
  239. return inter / (box_area(box1.T)[:, None] + box_area(box2.T) - inter)
  240. def bbox_ioa(box1, box2, eps=1E-7):
  241. """ Returns the intersection over box2 area given box1, box2. Boxes are x1y1x2y2
  242. box1: np.array of shape(4)
  243. box2: np.array of shape(nx4)
  244. returns: np.array of shape(n)
  245. """
  246. # Get the coordinates of bounding boxes
  247. b1_x1, b1_y1, b1_x2, b1_y2 = box1
  248. b2_x1, b2_y1, b2_x2, b2_y2 = box2.T
  249. # Intersection area
  250. inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
  251. (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)
  252. # box2 area
  253. box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + eps
  254. # Intersection over box2 area
  255. return inter_area / box2_area
  256. # 根据两个框的宽高矩阵返回IOU
  257. def wh_iou(wh1, wh2):
  258. # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
  259. wh1 = wh1[:, None] # [N,1,2]
  260. wh2 = wh2[None] # [1,M,2]
  261. inter = torch.min(wh1, wh2).prod(2) # [N,M]
  262. return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter)
  263. # Plots ----------------------------------------------------------------------------------------------------------------
  264. def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=()):
  265. # Precision-recall curve
  266. fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
  267. py = np.stack(py, axis=1)
  268. if 0 < len(names) < 21: # display per-class legend if < 21 classes
  269. for i, y in enumerate(py.T):
  270. ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}') # plot(recall, precision)
  271. else:
  272. ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision)
  273. ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean())
  274. ax.set_xlabel('Recall')
  275. ax.set_ylabel('Precision')
  276. ax.set_xlim(0, 1)
  277. ax.set_ylim(0, 1)
  278. plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
  279. fig.savefig(Path(save_dir), dpi=250)
  280. plt.close()
  281. def plot_mc_curve(px, py, save_dir='mc_curve.png', names=(), xlabel='Confidence', ylabel='Metric'):
  282. # Metric-confidence curve
  283. fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
  284. if 0 < len(names) < 21: # display per-class legend if < 21 classes
  285. for i, y in enumerate(py):
  286. ax.plot(px, y, linewidth=1, label=f'{names[i]}') # plot(confidence, metric)
  287. else:
  288. ax.plot(px, py.T, linewidth=1, color='grey') # plot(confidence, metric)
  289. y = py.mean(0)
  290. ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}')
  291. ax.set_xlabel(xlabel)
  292. ax.set_ylabel(ylabel)
  293. ax.set_xlim(0, 1)
  294. ax.set_ylim(0, 1)
  295. plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
  296. fig.savefig(Path(save_dir), dpi=250)
  297. plt.close()