experimental.py 4.8 KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Experimental modules
  4. """
  5. import math
  6. import numpy as np
  7. import torch
  8. import torch.nn as nn
  9. from models.common import Conv
  10. from utils.downloads import attempt_download
  class CrossConv(nn.Module):
      """Cross convolution: a (1, k) Conv followed by a (k, 1) Conv, with an optional residual add.

      Args:
          c1: input channels.
          c2: output channels.
          k: kernel length, used as (1, k) for the first Conv and (k, 1) for the second.
          s: stride, passed as (1, s) to the first Conv and (s, 1) to the second.
          g: groups for the second Conv.
          e: expansion ratio; the hidden channel count is int(c2 * e).
          shortcut: add the input to the output; only honoured when c1 == c2.
      """

      def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
          super().__init__()
          hidden = int(c2 * e)  # hidden channels
          self.cv1 = Conv(c1, hidden, (1, k), (1, s))
          self.cv2 = Conv(hidden, c2, (k, 1), (s, 1), g=g)
          # A residual add requires identical channel counts, so the shortcut is dropped otherwise.
          self.add = shortcut and c1 == c2

      def forward(self, x):
          out = self.cv2(self.cv1(x))
          if self.add:
              return x + out
          return out
  class Sum(nn.Module):
      """Weighted sum of 2 or more layers (https://arxiv.org/abs/1911.09070).

      Args:
          n: number of inputs.
          weight: if True, inputs 1..n-1 are scaled by learnable gains 2 * sigmoid(w)
              before being added to input 0; input 0 is never scaled.
      """

      def __init__(self, n, weight=False):
          super().__init__()
          self.weight = weight  # whether to apply learnable per-input weights
          self.iter = range(n - 1)  # indices of the inputs added onto x[0] (offset by 1)
          if weight:
              # initial values -0.5, -1.0, -1.5, ... (one per extra input)
              self.w = nn.Parameter(-torch.arange(1.0, n) / 2, requires_grad=True)  # layer weights

      def forward(self, x):
          out = x[0]
          if not self.weight:
              for j in self.iter:
                  out = out + x[j + 1]
              return out
          gains = torch.sigmoid(self.w) * 2  # each gain lies in (0, 2)
          for j in self.iter:
              out = out + x[j + 1] * gains[j]
          return out
  class MixConv2d(nn.Module):
      """Mixed depth-wise convolution (https://arxiv.org/abs/1907.09595).

      One Conv2d per kernel size in ``k`` runs on the same input. Their outputs are concatenated
      along the channel dim, then passed through BatchNorm2d and SiLU.

      Args:
          c1: input channels.
          c2: total output channels, split across the kernels in ``k``.
          k: kernel sizes, one convolution per entry.
          s: stride shared by all convolutions.
          equal_ch: if True, split ``c2`` into near-equal channel counts per kernel.
              Otherwise choose counts so that channels * kernel**2 (the weight count)
              is approximately equal across groups.
      """

      def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):  # ch_in, ch_out, kernel, stride, ch_strategy
          super().__init__()
          n = len(k)  # number of convolutions
          if equal_ch:  # equal c_ per group
              i = torch.linspace(0, n - 1E-6, c2).floor()  # c2 indices
              c_ = [int((i == g).sum()) for g in range(n)]  # intermediate channels
          else:  # equal weight.numel() per group
              b = [c2] + [0] * n
              a = np.eye(n + 1, n, k=-1)
              a -= np.roll(a, 1, axis=1)
              a *= np.array(k) ** 2
              a[0] = 1
              counts = np.linalg.lstsq(a, b, rcond=None)[0].round().astype(int)  # solve for equal weight indices, ax = b
              # Rounding each group independently can make the total drift away from c2
              # (e.g. c2=16, k=(1, 3, 5): 13.9 + 1.54 + 0.56 -> 14 + 2 + 1 = 17). That would
              # break BatchNorm2d(c2) below, so absorb the drift in the largest group.
              counts[np.argmax(counts)] += c2 - counts.sum()
              c_ = [int(v) for v in counts]
          self.m = nn.ModuleList([
              nn.Conv2d(c1, ch, ks, s, ks // 2, groups=math.gcd(c1, ch), bias=False) for ks, ch in zip(k, c_)])
          self.bn = nn.BatchNorm2d(c2)
          self.act = nn.SiLU()

      def forward(self, x):
          return self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
  class Ensemble(nn.ModuleList):
      """List of models run on the same input whose inference outputs are merged.

      Each member is called as ``member(x, augment, profile, visualize)`` and the first element
      of its return value is kept. The kept outputs are concatenated along dim 1 ("nms
      ensemble"; presumably merged by a later NMS step — confirm against callers).
      """

      def __init__(self):
          super().__init__()

      def forward(self, x, augment=False, profile=False, visualize=False):
          preds = [member(x, augment, profile, visualize)[0] for member in self]
          # Alternatives: torch.stack(preds).max(0)[0] (max ensemble) or torch.stack(preds).mean(0) (mean ensemble).
          merged = torch.cat(preds, 1)  # nms ensemble
          return merged, None  # inference, train output
  def attempt_load(weights, map_location=None, inplace=True, fuse=True):
      """Load a single checkpoint as a model, or several checkpoints as an ``Ensemble``.

      Args:
          weights: one weights path/name, or a list/tuple of them (weights=[a, b, c], [a] or a).
          map_location: forwarded to ``torch.load``.
          inplace: value assigned to ``inplace`` on activation, Detect and Model modules.
          fuse: if True, call ``.fuse()`` on each loaded model before ``.eval()``.

      Returns:
          The model itself when exactly one checkpoint is loaded. Otherwise an ``Ensemble``
          that copies ``names``, ``nc`` and ``yaml`` from its first member and takes the
          ``stride`` of the member with the largest maximum stride.

      Raises:
          AssertionError: if ensemble members have different class counts.
      """
      from models.yolo import Detect, Model

      model = Ensemble()
      for w in weights if isinstance(weights, (list, tuple)) else [weights]:
          # NOTE(security): torch.load unpickles the file and can execute arbitrary code;
          # only load weights from trusted sources.
          ckpt = torch.load(attempt_download(w), map_location=map_location)  # load checkpoint
          ckpt = (ckpt.get('ema') or ckpt['model']).float()  # FP32 model, preferring 'ema' when present
          model.append(ckpt.fuse().eval() if fuse else ckpt.eval())  # fused or un-fused model in eval mode

      # Compatibility updates: patch module attributes so checkpoints saved under other
      # torch/YOLOv5 versions run under the installed one.
      for m in model.modules():
          t = type(m)
          if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model):
              m.inplace = inplace  # torch 1.7.0 compatibility
              if t is Detect:
                  if not isinstance(m.anchor_grid, list):  # new Detect Layer compatibility
                      delattr(m, 'anchor_grid')
                      setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
          elif t is Conv:
              m._non_persistent_buffers_set = set()  # torch 1.6.0 compatibility
          elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
              m.recompute_scale_factor = None  # torch 1.11.0 compatibility

      if len(model) == 1:
          return model[-1]  # return the single model

      print(f'Ensemble created with {weights}\n')
      for k in 'names', 'nc', 'yaml':
          setattr(model, k, getattr(model[0], k))
      model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride  # max stride
      # Explicit raise (same exception type as before) so the check still runs under `python -O`.
      if not all(model[0].nc == m.nc for m in model):
          raise AssertionError(f'Models have different class counts: {[m.nc for m in model]}')
      return model  # return ensemble