metrics.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Model validation metrics
  4. """
  5. import math
  6. import warnings
  7. from pathlib import Path
  8. import matplotlib.pyplot as plt
  9. import numpy as np
  10. import torch
  11. def fitness(x):
  12. # Model fitness as a weighted combination of metrics
  13. w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
  14. return (x[:, :4] * w).sum(1)
  15. def smooth(y, f=0.05):
  16. # Box filter of fraction f
  17. nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd)
  18. p = np.ones(nf // 2) # ones padding
  19. yp = np.concatenate((p * y[0], y, p * y[-1]), 0) # y padded
  20. return np.convolve(yp, np.ones(nf) / nf, mode='valid') # y-smoothed
  21. def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=(), eps=1e-16):
  22. """ Compute the average precision, given the recall and precision curves.
  23. Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
  24. # Arguments
  25. tp: True positives (nparray, nx1 or nx10).
  26. conf: Objectness value from 0-1 (nparray).
  27. pred_cls: Predicted object classes (nparray).
  28. target_cls: True object classes (nparray).
  29. plot: Plot precision-recall curve at mAP@0.5
  30. save_dir: Plot save directory
  31. # Returns
  32. The average precision as computed in py-faster-rcnn.
  33. """
  34. # Sort by objectness
  35. i = np.argsort(-conf)
  36. tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
  37. # Find unique classes
  38. unique_classes, nt = np.unique(target_cls, return_counts=True)
  39. nc = unique_classes.shape[0] # number of classes, number of detections
  40. # Create Precision-Recall curve and compute AP for each class
  41. px, py = np.linspace(0, 1, 1000), [] # for plotting
  42. ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
  43. for ci, c in enumerate(unique_classes):
  44. i = pred_cls == c
  45. n_l = nt[ci] # number of labels
  46. n_p = i.sum() # number of predictions
  47. if n_p == 0 or n_l == 0:
  48. continue
  49. # Accumulate FPs and TPs
  50. fpc = (1 - tp[i]).cumsum(0)
  51. tpc = tp[i].cumsum(0)
  52. # Recall
  53. recall = tpc / (n_l + eps) # recall curve
  54. r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases
  55. # Precision
  56. precision = tpc / (tpc + fpc) # precision curve
  57. p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1) # p at pr_score
  58. # AP from recall-precision curve
  59. for j in range(tp.shape[1]):
  60. ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
  61. if plot and j == 0:
  62. py.append(np.interp(px, mrec, mpre)) # precision at mAP@0.5
  63. # Compute F1 (harmonic mean of precision and recall)
  64. f1 = 2 * p * r / (p + r + eps)
  65. names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data
  66. names = dict(enumerate(names)) # to dict
  67. if plot:
  68. plot_pr_curve(px, py, ap, Path(save_dir) / 'PR_curve.png', names)
  69. plot_mc_curve(px, f1, Path(save_dir) / 'F1_curve.png', names, ylabel='F1')
  70. plot_mc_curve(px, p, Path(save_dir) / 'P_curve.png', names, ylabel='Precision')
  71. plot_mc_curve(px, r, Path(save_dir) / 'R_curve.png', names, ylabel='Recall')
  72. i = smooth(f1.mean(0), 0.1).argmax() # max F1 index
  73. p, r, f1 = p[:, i], r[:, i], f1[:, i]
  74. tp = (r * nt).round() # true positives
  75. fp = (tp / (p + eps) - tp).round() # false positives
  76. return tp, fp, p, r, f1, ap, unique_classes.astype(int)
  77. def compute_ap(recall, precision):
  78. """ Compute the average precision, given the recall and precision curves
  79. # Arguments
  80. recall: The recall curve (list)
  81. precision: The precision curve (list)
  82. # Returns
  83. Average precision, precision curve, recall curve
  84. """
  85. # Append sentinel values to beginning and end
  86. mrec = np.concatenate(([0.0], recall, [1.0]))
  87. mpre = np.concatenate(([1.0], precision, [0.0]))
  88. # Compute the precision envelope
  89. mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))
  90. # Integrate area under curve
  91. method = 'interp' # methods: 'continuous', 'interp'
  92. if method == 'interp':
  93. x = np.linspace(0, 1, 101) # 101-point interp (COCO)
  94. ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate
  95. else: # 'continuous'
  96. i = np.where(mrec[1:] != mrec[:-1])[0] # points where x axis (recall) changes
  97. ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve
  98. return ap, mpre, mrec
  99. class ConfusionMatrix:
  100. # Updated version of https://github.com/kaanakan/object_detection_confusion_matrix
  101. def __init__(self, nc, conf=0.25, iou_thres=0.45):
  102. self.matrix = np.zeros((nc + 1, nc + 1))
  103. self.nc = nc # number of classes
  104. self.conf = conf
  105. self.iou_thres = iou_thres
  106. def process_batch(self, detections, labels):
  107. """
  108. Return intersection-over-union (Jaccard index) of boxes.
  109. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
  110. Arguments:
  111. detections (Array[N, 6]), x1, y1, x2, y2, conf, class
  112. labels (Array[M, 5]), class, x1, y1, x2, y2
  113. Returns:
  114. None, updates confusion matrix accordingly
  115. """
  116. detections = detections[detections[:, 4] > self.conf]
  117. gt_classes = labels[:, 0].int()
  118. detection_classes = detections[:, 5].int()
  119. iou = box_iou(labels[:, 1:], detections[:, :4])
  120. x = torch.where(iou > self.iou_thres)
  121. if x[0].shape[0]:
  122. matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
  123. if x[0].shape[0] > 1:
  124. matches = matches[matches[:, 2].argsort()[::-1]]
  125. matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
  126. matches = matches[matches[:, 2].argsort()[::-1]]
  127. matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
  128. else:
  129. matches = np.zeros((0, 3))
  130. n = matches.shape[0] > 0
  131. m0, m1, _ = matches.transpose().astype(int)
  132. for i, gc in enumerate(gt_classes):
  133. j = m0 == i
  134. if n and sum(j) == 1:
  135. self.matrix[detection_classes[m1[j]], gc] += 1 # correct
  136. else:
  137. self.matrix[self.nc, gc] += 1 # background FP
  138. if n:
  139. for i, dc in enumerate(detection_classes):
  140. if not any(m1 == i):
  141. self.matrix[dc, self.nc] += 1 # background FN
  142. def matrix(self):
  143. return self.matrix
  144. def tp_fp(self):
  145. tp = self.matrix.diagonal() # true positives
  146. fp = self.matrix.sum(1) - tp # false positives
  147. # fn = self.matrix.sum(0) - tp # false negatives (missed detections)
  148. return tp[:-1], fp[:-1] # remove background class
  149. def plot(self, normalize=True, save_dir='', names=()):
  150. try:
  151. import seaborn as sn
  152. array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
  153. array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
  154. fig = plt.figure(figsize=(12, 9), tight_layout=True)
  155. nc, nn = self.nc, len(names) # number of classes, names
  156. sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
  157. labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
  158. with warnings.catch_warnings():
  159. warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
  160. sn.heatmap(array,
  161. annot=nc < 30,
  162. annot_kws={
  163. "size": 8},
  164. cmap='Blues',
  165. fmt='.2f',
  166. square=True,
  167. vmin=0.0,
  168. xticklabels=names + ['background FP'] if labels else "auto",
  169. yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
  170. fig.axes[0].set_xlabel('True')
  171. fig.axes[0].set_ylabel('Predicted')
  172. fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
  173. plt.close()
  174. except Exception as e:
  175. print(f'WARNING: ConfusionMatrix plot failure: {e}')
  176. def print(self):
  177. for i in range(self.nc + 1):
  178. print(' '.join(map(str, self.matrix[i])))
  179. def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
  180. # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)
  181. # Get the coordinates of bounding boxes
  182. if xywh: # transform from xywh to xyxy
  183. (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, 1), box2.chunk(4, 1)
  184. w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
  185. b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
  186. b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
  187. else: # x1, y1, x2, y2 = box1
  188. b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, 1)
  189. b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, 1)
  190. w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
  191. w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
  192. # Intersection area
  193. inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
  194. (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
  195. # Union Area
  196. union = w1 * h1 + w2 * h2 - inter + eps
  197. # IoU
  198. iou = inter / union
  199. if CIoU or DIoU or GIoU:
  200. cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width
  201. ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
  202. if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
  203. c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
  204. rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2
  205. if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
  206. v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
  207. with torch.no_grad():
  208. alpha = v / (v - iou + (1 + eps))
  209. return iou - (rho2 / c2 + v * alpha) # CIoU
  210. return iou - rho2 / c2 # DIoU
  211. c_area = cw * ch + eps # convex area
  212. return iou - (c_area - union) / c_area # GIoU https://arxiv.org/pdf/1902.09630.pdf
  213. return iou # IoU
  214. def box_area(box):
  215. # box = xyxy(4,n)
  216. return (box[2] - box[0]) * (box[3] - box[1])
  217. def box_iou(box1, box2):
  218. # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
  219. """
  220. Return intersection-over-union (Jaccard index) of boxes.
  221. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
  222. Arguments:
  223. box1 (Tensor[N, 4])
  224. box2 (Tensor[M, 4])
  225. Returns:
  226. iou (Tensor[N, M]): the NxM matrix containing the pairwise
  227. IoU values for every element in boxes1 and boxes2
  228. """
  229. # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
  230. (a1, a2), (b1, b2) = box1[:, None].chunk(2, 2), box2.chunk(2, 1)
  231. inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)
  232. # IoU = inter / (area1 + area2 - inter)
  233. return inter / (box_area(box1.T)[:, None] + box_area(box2.T) - inter)
  234. def bbox_ioa(box1, box2, eps=1E-7):
  235. """ Returns the intersection over box2 area given box1, box2. Boxes are x1y1x2y2
  236. box1: np.array of shape(4)
  237. box2: np.array of shape(nx4)
  238. returns: np.array of shape(n)
  239. """
  240. # Get the coordinates of bounding boxes
  241. b1_x1, b1_y1, b1_x2, b1_y2 = box1
  242. b2_x1, b2_y1, b2_x2, b2_y2 = box2.T
  243. # Intersection area
  244. inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
  245. (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)
  246. # box2 area
  247. box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + eps
  248. # Intersection over box2 area
  249. return inter_area / box2_area
  250. def wh_iou(wh1, wh2):
  251. # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
  252. wh1 = wh1[:, None] # [N,1,2]
  253. wh2 = wh2[None] # [1,M,2]
  254. inter = torch.min(wh1, wh2).prod(2) # [N,M]
  255. return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter)
  256. # Plots ----------------------------------------------------------------------------------------------------------------
  257. def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
  258. # Precision-recall curve
  259. fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
  260. py = np.stack(py, axis=1)
  261. if 0 < len(names) < 21: # display per-class legend if < 21 classes
  262. for i, y in enumerate(py.T):
  263. ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}') # plot(recall, precision)
  264. else:
  265. ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision)
  266. ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean())
  267. ax.set_xlabel('Recall')
  268. ax.set_ylabel('Precision')
  269. ax.set_xlim(0, 1)
  270. ax.set_ylim(0, 1)
  271. plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
  272. fig.savefig(save_dir, dpi=250)
  273. plt.close()
  274. def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
  275. # Metric-confidence curve
  276. fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
  277. if 0 < len(names) < 21: # display per-class legend if < 21 classes
  278. for i, y in enumerate(py):
  279. ax.plot(px, y, linewidth=1, label=f'{names[i]}') # plot(confidence, metric)
  280. else:
  281. ax.plot(px, py.T, linewidth=1, color='grey') # plot(confidence, metric)
  282. y = smooth(py.mean(0), 0.05)
  283. ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}')
  284. ax.set_xlabel(xlabel)
  285. ax.set_ylabel(ylabel)
  286. ax.set_xlim(0, 1)
  287. ax.set_ylim(0, 1)
  288. plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
  289. fig.savefig(save_dir, dpi=250)
  290. plt.close()