#include <ATen/ATen.h>

#include "compat.h"

// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
//   const at::Type& payload;
//   TypeShim(const at::Type& type) : payload(type) {}
//   // Enable trivial conversion to a const at::Type& for pre-3aeb78
//   operator const at::Type&(){ return payload; };
//   // Enable dispatch switch statements to take *this directly for post-3aeb78
//   //operator at::ScalarType(){ return payload.; };
// };

#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
  switch(TYPE) \
  { \
    case at::ScalarType::Float: \
    { \
      using scalar_t_##LEVEL = float; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::Half: \
    { \
      using scalar_t_##LEVEL = at::Half; \
      __VA_ARGS__; \
      break; \
    } \
    default: \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  }

#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...) \
  switch(TYPE) \
  { \
    case at::ScalarType::Float: \
    { \
      using scalar_t_##LEVEL = float; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::Half: \
    { \
      using scalar_t_##LEVEL = at::Half; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::BFloat16: \
    { \
      using scalar_t_##LEVEL = at::BFloat16; \
      __VA_ARGS__; \
      break; \
    } \
    default: \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  }

#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
  switch(TYPE) \
  { \
    case at::ScalarType::Float: \
    { \
      using scalar_t_##LEVEL = float; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::Half: \
    { \
      using scalar_t_##LEVEL = at::Half; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::Byte: \
    { \
      using scalar_t_##LEVEL = uint8_t; \
      __VA_ARGS__; \
      break; \
    } \
    default: \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  }

#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
  switch(TYPE) \
  { \
    case at::ScalarType::Double: \
    { \
      using scalar_t_##LEVEL = double; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::Float: \
    { \
      using scalar_t_##LEVEL = float; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::Half: \
    { \
      using scalar_t_##LEVEL = at::Half; \
      __VA_ARGS__; \
      break; \
    } \
    default: \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  }

#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...) \
  switch(TYPE) \
  { \
    case at::ScalarType::Double: \
    { \
      using scalar_t_##LEVEL = double; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::Float: \
    { \
      using scalar_t_##LEVEL = float; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::Half: \
    { \
      using scalar_t_##LEVEL = at::Half; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::BFloat16: \
    { \
      using scalar_t_##LEVEL = at::BFloat16; \
      __VA_ARGS__; \
      break; \
    } \
    default: \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  }

#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
  switch(TYPE) \
  { \
    case at::ScalarType::Double: \
    { \
      using scalar_t_##LEVEL = double; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::Float: \
    { \
      using scalar_t_##LEVEL = float; \
      __VA_ARGS__; \
      break; \
    } \
    default: \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  }
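// Usage sketch for the dispatch macros above (illustrative only: `scale_kernel`,
// `blocks`, `threads`, `stream`, and `factor` are hypothetical and not part of
// this header). LEVEL is appended to the `scalar_t_` alias so that nested
// dispatches over several tensors do not shadow one another:
//
//   DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "scale",
//     scale_kernel<scalar_t_0><<<blocks, threads, 0, stream>>>(
//       input.data_ptr<scalar_t_0>(), input.numel(), factor););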
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
  switch(TYPE) \
  { \
    case at::ScalarType::Half: \
    { \
      using scalar_t = at::Half; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::BFloat16: \
    { \
      using scalar_t = at::BFloat16; \
      __VA_ARGS__; \
      break; \
    } \
    default: \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  }

#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
  switch(TYPEIN) \
  { \
    case at::ScalarType::Float: \
    { \
      using scalar_t_in = float; \
      switch(TYPEOUT) \
      { \
        case at::ScalarType::Float: \
        { \
          using scalar_t_out = float; \
          __VA_ARGS__; \
          break; \
        } \
        case at::ScalarType::Half: \
        { \
          using scalar_t_out = at::Half; \
          __VA_ARGS__; \
          break; \
        } \
        case at::ScalarType::BFloat16: \
        { \
          using scalar_t_out = at::BFloat16; \
          __VA_ARGS__; \
          break; \
        } \
        default: \
          AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
      } \
      break; \
    } \
    case at::ScalarType::Half: \
    { \
      using scalar_t_in = at::Half; \
      using scalar_t_out = at::Half; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::BFloat16: \
    { \
      using scalar_t_in = at::BFloat16; \
      using scalar_t_out = at::BFloat16; \
      __VA_ARGS__; \
      break; \
    } \
    default: \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
  }

#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
  switch(TYPEIN) \
  { \
    case at::ScalarType::Double: \
    { \
      using scalar_t_in = double; \
      switch(TYPEOUT) \
      { \
        case at::ScalarType::Double: \
        { \
          using scalar_t_out = double; \
          __VA_ARGS__; \
          break; \
        } \
        case at::ScalarType::Float: \
        { \
          using scalar_t_out = float; \
          __VA_ARGS__; \
          break; \
        } \
        case at::ScalarType::Half: \
        { \
          using scalar_t_out = at::Half; \
          __VA_ARGS__; \
          break; \
        } \
        case at::ScalarType::BFloat16: \
        { \
          using scalar_t_out = at::BFloat16; \
          __VA_ARGS__; \
          break; \
        } \
        default: \
          AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
      } \
      break; \
    } \
    case at::ScalarType::Float: \
    { \
      using scalar_t_in = float; \
      switch(TYPEOUT) \
      { \
        case at::ScalarType::Float: \
        { \
          using scalar_t_out = float; \
          __VA_ARGS__; \
          break; \
        } \
        case at::ScalarType::Half: \
        { \
          using scalar_t_out = at::Half; \
          __VA_ARGS__; \
          break; \
        } \
        case at::ScalarType::BFloat16: \
        { \
          using scalar_t_out = at::BFloat16; \
          __VA_ARGS__; \
          break; \
        } \
        default: \
          AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
      } \
      break; \
    } \
    case at::ScalarType::Half: \
    { \
      using scalar_t_in = at::Half; \
      using scalar_t_out = at::Half; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::BFloat16: \
    { \
      using scalar_t_in = at::BFloat16; \
      using scalar_t_out = at::BFloat16; \
      __VA_ARGS__; \
      break; \
    } \
    default: \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
  }

template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes
  (T *x,
   T val,
   int lanes=1,
   bool share_result=false) // lanes is intended to be <= 32.
{
  int tid = threadIdx.x + threadIdx.y*blockDim.x;
  int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.

  if(blockSize >= 64)
  {
    x[tid] = val;
    __syncthreads();
  }

  #pragma unroll
  for(int i = (blockSize >> 1); i >= 64; i >>= 1)
  {
    if(tid < i)
      x[tid] = x[tid] + x[tid+i];
    __syncthreads();
  }

  T final;

  if(tid < 32)
  {
    if(blockSize >= 64)
      final = x[tid] + x[tid+32];
    else
      final = val;
    // __SYNCWARP();

    #pragma unroll
    for(int i = 16; i >= lanes; i >>= 1)
      final = final + __shfl_down_sync(0xffffffff, final, i);
  }

  if(share_result)
  {
    if(tid < lanes)
      x[tid] = final; // EpilogueOp
    // Make sure the smem result is visible to all warps.
    __syncthreads();
  }

  return final;
}
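// Usage sketch for reduce_block_into_lanes above (illustrative only; the values
// shown are hypothetical). `x` must point to shared memory holding at least
// blockDim.x*blockDim.y elements of T, and the block size should be a multiple
// of 32. With the defaults (lanes=1, share_result=false) the return value is the
// complete block-wide sum only for tid < lanes, i.e. thread 0:
//
//   __shared__ float s_buf[256];   // == blockDim.x*blockDim.y for this launch
//   float thread_sum = ...;        // per-thread partial sum
//   float block_sum = reduce_block_into_lanes(s_buf, thread_sum);
//   // Pass share_result=true to also publish the first `lanes` results to x[0..lanes-1].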
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op
  (T *x,
   T val,
   int lanes=1,
   bool share_result=false) // lanes is intended to be <= 32.
{
  int tid = threadIdx.x + threadIdx.y*blockDim.x;
  int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.

  if(blockSize >= 64)
  {
    x[tid] = val;
    __syncthreads();
  }

  #pragma unroll
  for(int i = (blockSize >> 1); i >= 64; i >>= 1)
  {
    if(tid < i)
      x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid+i]));
    __syncthreads();
  }

  T final;

  if(tid < 32)
  {
    if(blockSize >= 64)
      final = fmaxf(fabsf(x[tid]), fabsf(x[tid+32]));
    else
      final = val;
    // __SYNCWARP();

    #pragma unroll
    for(int i = 16; i >= lanes; i >>= 1)
      final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
  }

  if(share_result)
  {
    if(tid < lanes)
      x[tid] = final; // EpilogueOp
    // Make sure the smem result is visible to all warps.
    __syncthreads();
  }

  return final;
}
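// reduce_block_into_lanes_max_op is the max-of-absolute-values analogue of
// reduce_block_into_lanes: same shared-memory tree plus warp-shuffle pattern,
// but elements are combined with fmaxf(fabsf(a), fabsf(b)) rather than added,
// so the result is the block-wide maximum of |val|. Note that fmaxf/fabsf
// operate in float precision regardless of T. Shared-memory requirements and
// the lanes/share_result semantics match the sum variant above.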