# dht.py (1.0 KB)
  1. import torch
  2. import torch.nn as nn
  3. import numpy as np
  4. from model._cdht.dht_func import C_dht
  class DHT_Layer(nn.Module):
      """Project a feature map, pass it through ``DHT``, then refine the result.

      Stages applied by ``forward``, in order:

      1. ``fist_conv``: 1x1 Conv2d (``input_dim`` -> ``dim``) + BatchNorm2d + ReLU.
      2. ``dht``: a ``DHT`` module built from ``numAngle`` / ``numRho``
         (presumably a Deep Hough Transform op -- TODO confirm against ``C_dht``).
      3. ``convs``: two stages of 3x3 Conv2d (stride 1, padding 1) + BatchNorm2d
         + ReLU. Each stage keeps ``dim`` channels and the spatial size of its
         input.

      Args:
          input_dim: Channel count of the incoming feature map.
          dim: Channel count used from the 1x1 projection onwards.
          numAngle: Forwarded to ``DHT`` (presumably the number of angle bins).
          numRho: Forwarded to ``DHT`` (presumably the number of distance bins).
      """

      @staticmethod
      def _conv_bn_relu(in_channels, out_channels, kernel_size, padding=0):
          """Return ``[Conv2d, BatchNorm2d, ReLU]`` for one stride-1 conv stage."""
          return [
              nn.Conv2d(in_channels, out_channels, kernel_size, 1, padding),
              nn.BatchNorm2d(out_channels),
              nn.ReLU(),
          ]

      def __init__(self, input_dim, dim, numAngle, numRho):
          super().__init__()
          # Submodules are built in a fixed order: projection, DHT, refinement.
          # That order fixes the parameter/state_dict order and the order in
          # which random weight initialisation consumes the RNG.
          # NOTE: "fist_conv" (sic) becomes the state_dict key prefix. Renaming
          # it would break loading of existing checkpoints, so it is kept.
          self.fist_conv = nn.Sequential(*self._conv_bn_relu(input_dim, dim, 1))
          self.dht = DHT(numAngle=numAngle, numRho=numRho)
          self.convs = nn.Sequential(
              *self._conv_bn_relu(dim, dim, 3, padding=1),
              *self._conv_bn_relu(dim, dim, 3, padding=1),
          )

      def forward(self, x):
          """Apply ``fist_conv``, then ``dht``, then ``convs`` to ``x``."""
          projected = self.fist_conv(x)
          accumulated = self.dht(projected)
          return self.convs(accumulated)
  class DHT(nn.Module):
      """Thin ``nn.Module`` wrapper around the ``C_dht`` line-aggregation op.

      ``C_dht`` is imported from the project extension ``model._cdht.dht_func``,
      which is not visible here. Presumably it accumulates features into a
      Hough-space map of shape (N, C, numAngle, numRho) -- TODO confirm against
      the extension.

      Args:
          numAngle: Passed as the first positional argument of ``C_dht``.
          numRho: Passed as the second positional argument of ``C_dht``.
      """

      def __init__(self, numAngle, numRho):
          super().__init__()
          # If C_dht is an nn.Module, "line_agg" becomes its state_dict prefix,
          # so keep this attribute name stable.
          self.line_agg = C_dht(numAngle, numRho)

      def forward(self, x):
          """Return the output of ``line_agg`` applied to ``x``."""
          return self.line_agg(x)