# CUDA source snippets held as Python strings: `cuda_src_forward` and
# `cuda_src_backward`. Each string contains a `__global__` kernel
# (`line_accum_forward_kernel` / `line_accum_backward_kernel`) followed by the
# host-side statement that launches it.
#
# What the visible kernel text establishes (identical in both kernels):
#   * batch   = blockIdx.y, channel = blockIdx.x
#   * x = threadIdx.x*threadW, y = threadIdx.y*threadH, k = threadIdx.z*threadK,
#     i.e. each thread starts at the corner of a threadW x threadH x threadK tile.
#   * imgStartIdx = batch*channelSize*imWidth*imHeight + channel*imWidth*imHeight
#                   + y*imWidth + x  (flat offset computed with `int` arithmetic).
#   * All work is guarded by
#     x < imWidth && y < imHeight && channel < channelSize &&
#     batch < batchSize && k < numangle.
#   * The outer loop walks idY over [0, threadH) and re-derives imgIndex per row;
#     rows with y+idY >= imHeight are not processed.
#
# Launch-argument mapping, read off the argument order of the calls:
#   forward : feat=in0_p, tabCos=in1_p, tabSin=in2_p, output=out0_p,
#             imWidth=in0_shape3, imHeight=in0_shape2,
#             channelSize=in0_shape1, batchSize=in0_shape0
#   backward: grad_in=out0_p, grad_out=in1_p, tabCos=in2_p, tabSin=in3_p,
#             imWidth=in1_shape3, imHeight=in1_shape2,
#             channelSize=in1_shape1, batchSize=in1_shape0
#   `#numangle` / `#numrho` are not valid C++ tokens as written; presumably they
#   are placeholders substituted before compilation -- TODO confirm at the use site.
#
# NOTE(review): the embedded source is damaged and cannot compile as it stands:
#   * In both strings the text breaks off inside `for (int idX=0; idX` and
#     resumes at `size, 0);`. The inner loops, the accumulation statement, the
#     kernel's closing braces and the host code that defines threadW/threadH/
#     threadK are all absent.
#   * The kernel launches read `<<>>` with no execution configuration.
#   * Both gaps look like everything between a `<` and the next `>` was stripped
#     (e.g. by an HTML tag filter); restore the strings from the original file
#     rather than reconstructing them by hand.
#   * The two assignments share one physical line, and the `// labelIndex ...`
#     comment in the forward kernel would swallow the rest of that line if the
#     string were compiled in this collapsed form. The original presumably had
#     line breaks -- verify when restoring.
cuda_src_forward = ''' __global__ void line_accum_forward_kernel( const float* __restrict__ feat, const float* tabCos, const float* tabSin, float* output, const int imWidth, const int imHeight, const int threadW, const int threadH, const int threadK, const int channelSize, const int batchSize, const int numangle, const int numrho) { int batch = blockIdx.y; int channel = blockIdx.x; int x = threadIdx.x*threadW; int y = threadIdx.y*threadH; int k = threadIdx.z*threadK; int imgStartIdx = batch*channelSize*imWidth*imHeight+ channel*imWidth*imHeight+ y*imWidth+ x; int angleStartIdx = k; if (x < imWidth && y < imHeight && channel < channelSize && batch < batchSize && k < numangle) { int imgIndex = imgStartIdx; int angleIndex; int outIndex; int r; for (int idY=0; idY < threadH; idY++) { imgIndex = imgStartIdx + idY*imWidth; // labelIndex = labelStartIdx + idY*imWidth; if (y+idY < imHeight) { for (int idX=0; idXsize, 0); line_accum_forward_kernel<<>>( in0_p, in1_p, in2_p, out0_p, in0_shape3, in0_shape2, threadW, threadH, threadK, in0_shape1, in0_shape0, #numangle, #numrho ); ''' cuda_src_backward = ''' __global__ void line_accum_backward_kernel( float* grad_in, const float* grad_out, const float* tabCos, const float* tabSin, const int imWidth, const int imHeight, const int threadW, const int threadH, const int threadK, const int channelSize, const int batchSize, const int numangle, const int numrho) { int batch = blockIdx.y; int channel = blockIdx.x; int x = threadIdx.x*threadW; int y = threadIdx.y*threadH; int k = threadIdx.z*threadK; int imgStartIdx = batch*channelSize*imWidth*imHeight+ channel*imWidth*imHeight+ y*imWidth+ x; int angleStartIdx = k; if (x < imWidth && y < imHeight && channel < channelSize && batch < batchSize && k < numangle) { int imgIndex = imgStartIdx; int angleIndex; int outIndex; int r; for (int idY=0; idY < threadH; idY++) { imgIndex = imgStartIdx + idY*imWidth; if (y+idY < imHeight) { for (int idX=0; idXsize, 0); line_accum_backward_kernel<<>>( 
out0_p, in1_p, in2_p, in3_p, in1_shape3, in1_shape2, threadW, threadH, threadK, in1_shape1, in1_shape0, #numangle, #numrho ); '''