12345678910111213141516171819202122232425262728293031323334353637 |
- import torch
- import torch.nn as nn
- import numpy as np
- from model._cdht.dht_func import C_dht
class DHT_Layer(nn.Module):
    """Pointwise conv -> DHT -> two 3x3 conv stages.

    The input is first projected from ``input_dim`` to ``dim`` channels by a
    1x1 convolution (with batch norm and ReLU), then passed through the
    ``DHT`` module, and finally through two 3x3 conv / batch-norm / ReLU
    stages that keep ``dim`` channels.

    Args:
        input_dim: Number of channels of the incoming feature map.
        dim: Number of channels used by every convolution in this layer.
        numAngle: Forwarded to ``DHT``.
        numRho: Forwarded to ``DHT``.

    NOTE(review): "DHT" presumably stands for Deep Hough Transform, with
    ``numAngle`` / ``numRho`` being the size of the parameter-space grid --
    confirm against ``C_dht``.
    """

    def __init__(self, input_dim, dim, numAngle, numRho):
        super().__init__()

        # The attribute name "fist_conv" (sic) is kept on purpose: it is the
        # prefix of this sub-module's state_dict keys, so renaming it would
        # break loading of existing checkpoints.
        channel_proj = [
            nn.Conv2d(input_dim, dim, kernel_size=1),
            nn.BatchNorm2d(dim),
            nn.ReLU(),
        ]
        self.fist_conv = nn.Sequential(*channel_proj)

        self.dht = DHT(numAngle=numAngle, numRho=numRho)

        # Two identical conv stages; kernel 3 / stride 1 / padding 1 leaves
        # the spatial size of the DHT output unchanged.
        refine_stages = []
        for _ in range(2):
            refine_stages.extend(
                [
                    nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1),
                    nn.BatchNorm2d(dim),
                    nn.ReLU(),
                ]
            )
        self.convs = nn.Sequential(*refine_stages)

    def forward(self, x):
        """Run ``x`` through the 1x1 projection, the DHT and the 3x3 convs."""
        return self.convs(self.dht(self.fist_conv(x)))
class DHT(nn.Module):
    """``nn.Module`` wrapper around the ``C_dht`` operator.

    Args:
        numAngle: First argument passed to ``C_dht``.
        numRho: Second argument passed to ``C_dht``.

    NOTE(review): presumably these are the number of angle and rho bins of a
    Hough accumulator -- confirm against ``model._cdht.dht_func``.
    """

    def __init__(self, numAngle, numRho):
        super().__init__()
        self.line_agg = C_dht(numAngle, numRho)

    def forward(self, x):
        """Return the result of applying ``self.line_agg`` to ``x``."""
        return self.line_agg(x)
|