123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051 |
- import jittor as jt
- import numpy as np
- from jittor import nn
- from model.backbone.fpn import FPN50
- from model.dht import DHT_Layer
- class Net(nn.Module):
- def __init__(self, numAngle, numRho, backbone):
- super(Net, self).__init__()
- if backbone == 'resnet50':
- print('using resnet50 backbone.')
- self.backbone = FPN50(pretrained=True, output_stride=16)
- output_stride = 16
- else:
- raise NotImplementedError
-
- self.dht_detector1 = DHT_Layer(256, 128, numAngle=numAngle, numRho=numRho)
- self.dht_detector2 = DHT_Layer(256, 128, numAngle=numAngle, numRho=numRho // 2)
- self.dht_detector3 = DHT_Layer(256, 128, numAngle=numAngle, numRho=numRho // 4)
- self.dht_detector4 = DHT_Layer(256, 128, numAngle=numAngle, numRho=numRho // (output_stride // 4))
-
- self.last_conv = nn.Sequential(
- nn.Conv2d(512, 1, 1)
- )
- self.numAngle = numAngle
- self.numRho = numRho
- def upsample_cat(self, p1, p2, p3, p4):
- p1 = nn.interpolate(p1, size=(self.numAngle, self.numRho), mode='bilinear', align_corners=True)
- p2 = nn.interpolate(p2, size=(self.numAngle, self.numRho), mode='bilinear', align_corners=True)
- p3 = nn.interpolate(p3, size=(self.numAngle, self.numRho), mode='bilinear', align_corners=True)
- p4 = nn.interpolate(p4, size=(self.numAngle, self.numRho), mode='bilinear', align_corners=True)
- return jt.concat([p1, p2, p3, p4], dim=1)
- def execute(self, x):
- p1, p2, p3, p4 = self.backbone(x)
- p4 = self.dht_detector4(p4)
- p3 = self.dht_detector3(p3)
- p2 = self.dht_detector2(p2)
- p1 = self.dht_detector1(p1)
-
- cat = self.upsample_cat(p1, p2, p3, p4)
- logist = self.last_conv(cat)
- return logist
|