experimental.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Experimental modules
  4. """
  5. import math
  6. import numpy as np
  7. import torch
  8. import torch.nn as nn
  9. from dependence.yolov5.models.common import Conv
  10. from dependence.yolov5.utils.downloads import attempt_download
  11. class Sum(nn.Module):
  12. # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
  13. def __init__(self, n, weight=False): # n: number of inputs
  14. super().__init__()
  15. self.weight = weight # apply weights boolean
  16. self.iter = range(n - 1) # iter object
  17. if weight:
  18. self.w = nn.Parameter(-torch.arange(1.0, n) / 2, requires_grad=True) # layer weights
  19. def forward(self, x):
  20. y = x[0] # no weight
  21. if self.weight:
  22. w = torch.sigmoid(self.w) * 2
  23. for i in self.iter:
  24. y = y + x[i + 1] * w[i]
  25. else:
  26. for i in self.iter:
  27. y = y + x[i + 1]
  28. return y
  29. class MixConv2d(nn.Module):
  30. # Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595
  31. def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): # ch_in, ch_out, kernel, stride, ch_strategy
  32. super().__init__()
  33. n = len(k) # number of convolutions
  34. if equal_ch: # equal c_ per group
  35. i = torch.linspace(0, n - 1E-6, c2).floor() # c2 indices
  36. c_ = [(i == g).sum() for g in range(n)] # intermediate channels
  37. else: # equal weight.numel() per group
  38. b = [c2] + [0] * n
  39. a = np.eye(n + 1, n, k=-1)
  40. a -= np.roll(a, 1, axis=1)
  41. a *= np.array(k) ** 2
  42. a[0] = 1
  43. c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b
  44. self.m = nn.ModuleList([
  45. nn.Conv2d(c1, int(c_), k, s, k // 2, groups=math.gcd(c1, int(c_)), bias=False) for k, c_ in zip(k, c_)])
  46. self.bn = nn.BatchNorm2d(c2)
  47. self.act = nn.SiLU()
  48. def forward(self, x):
  49. return self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
  50. class Ensemble(nn.ModuleList):
  51. # Ensemble of models
  52. def __init__(self):
  53. super().__init__()
  54. def forward(self, x, augment=False, profile=False, visualize=False):
  55. y = [module(x, augment, profile, visualize)[0] for module in self]
  56. # y = torch.stack(y).max(0)[0] # max ensemble
  57. # y = torch.stack(y).mean(0) # mean ensemble
  58. y = torch.cat(y, 1) # nms ensemble
  59. return y, None # inference, train output
  60. def attempt_load(weights, device=None, inplace=True, fuse=True):
  61. # Loads an ensemble of models weights=[a,b,c] or a single save_models weights=[a] or weights=a
  62. from dependence.yolov5.models.yolo import Detect, Model
  63. model = Ensemble()
  64. for w in weights if isinstance(weights, list) else [weights]:
  65. ckpt = torch.load(attempt_download(w), map_location='cpu') # load
  66. ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 save_models
  67. model.append(ckpt.fuse().eval() if fuse else ckpt.eval()) # fused or un-fused save_models in eval mode
  68. # Compatibility updates
  69. for m in model.modules():
  70. t = type(m)
  71. if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model):
  72. m.inplace = inplace # torch 1.7.0 compatibility
  73. if t is Detect and not isinstance(m.anchor_grid, list):
  74. delattr(m, 'anchor_grid')
  75. setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
  76. elif t is Conv:
  77. m._non_persistent_buffers_set = set() # torch 1.6.0 compatibility
  78. elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
  79. m.recompute_scale_factor = None # torch 1.11.0 compatibility
  80. if len(model) == 1:
  81. return model[-1] # return save_models
  82. print(f'Ensemble created with {weights}\n')
  83. for k in 'names', 'nc', 'yaml':
  84. setattr(model, k, getattr(model[0], k))
  85. model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride # max stride
  86. assert all(model[0].nc == m.nc for m in model), f'Models have different class counts: {[m.nc for m in model]}'
  87. return model # return ensemble