import torch
import deep_hough as dh
import numpy as np
import matplotlib.pyplot as plt
import time
class C_dht_Function(torch.autograd.Function):
    """Autograd bridge to the `deep_hough` CUDA extension.

    forward maps a (N, C, H, W) feature map onto a (N, C, numangle, numrho)
    Hough accumulator; backward scatters the accumulator gradient back to
    the feature map via the extension's backward kernel.
    """

    @staticmethod
    def forward(ctx, feat, numangle, numrho):
        """Run the deep Hough transform.

        Args:
            ctx: autograd context used to stash ``feat`` and the bin counts.
            feat: input feature map of shape (N, C, H, W).
                NOTE(review): the extension appears CUDA-only, so ``feat`` is
                presumably already on the GPU — confirm against callers.
            numangle: number of angle bins in the accumulator.
            numrho: number of rho bins in the accumulator.

        Returns:
            Accumulator tensor of shape (N, C, numangle, numrho).
        """
        n, c, _, _ = feat.size()
        # new_zeros matches feat's dtype AND device in one call — the original
        # torch.zeros(...).type_as(feat).cuda() chain was redundant and
        # hard-coded .cuda(), breaking if feat ever lives elsewhere.
        out = feat.new_zeros(n, c, numangle, numrho)
        out = dh.forward(feat, out, numangle, numrho)
        outputs = out[0]  # extension returns a list/tuple; first entry is the accumulator
        ctx.save_for_backward(feat)
        ctx.numangle = numangle
        ctx.numrho = numrho
        return outputs

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate gradients from the accumulator back to the feature map."""
        feat, = ctx.saved_tensors
        # zeros_like already copies dtype/device; no .type_as()/.cuda() needed.
        grad_buf = torch.zeros_like(feat)
        out = dh.backward(
            grad_output.contiguous(), grad_buf, feat, ctx.numangle, ctx.numrho
        )
        grad_in = out[0]
        # numangle and numrho are integer hyperparameters — no gradient.
        return grad_in, None, None
class C_dht(torch.nn.Module):
    """Module wrapper around :class:`C_dht_Function`.

    Stores the accumulator resolution (``numAngle`` x ``numRho``) and applies
    the deep Hough transform to the incoming feature map.
    """

    def __init__(self, numAngle, numRho):
        super(C_dht, self).__init__()
        # Accumulator resolution, fixed at construction time.
        self.numAngle = numAngle
        self.numRho = numRho

    def forward(self, feat):
        """Apply the deep Hough transform to ``feat``."""
        return C_dht_Function.apply(feat, self.numAngle, self.numRho)
|