#include <torch/extension.h>
#include <torch/torch.h>
#include <vector>

#include <stdio.h>

size_t get_mlp_reserved_space(int64_t batch_size, int num_layers, const int* output_features);

template <typename T>
size_t get_mlp_bp_workspace_in_bytes(int batch_size, int num_layers, const int* output_features);

template <typename T>
int mlp_fp(
    T* X,
    int input_features,
    int batch_size,
    T** WPtr,
    int num_layers,
    int* output_features,
    T** BPtr,
    T* Y,
    T* reserved_space,
    int use_bias,
    int activation,
    void* lt_workspace);

template <typename T>
int mlp_bp(
    T* X,
    T* Y,
    int input_features,
    int batch_size,
    T** WPtr,
    int num_layers,
    int* output_features,
    T* dY,
    T* reserved_space,
    T* work_space,
    T* dX,
    T** dwPtr,
    T** dbPtr,
    bool requires_grad,
    int use_bias,
    int activation);

// Forward pass. `inputs` is (input, W_0..W_{L-1}) or, with use_bias,
// (input, W_0..W_{L-1}, b_0..b_{L-1}). Returns {output, reserved_space}.
std::vector<at::Tensor> mlp_forward(int use_bias, int activation, std::vector<at::Tensor> inputs) {
  auto num_layers = inputs.size() - 1;
  if (use_bias) {
    // inputs contains (input, weights, biases)
    num_layers /= 2;
  }
  auto batch_size = inputs[0].size(0);
  auto input_features = inputs[0].size(1);

  std::vector<int> output_features;
  for (int i = 0; i < num_layers; i++) {
    output_features.push_back(inputs[i + 1].size(0));
  }

  auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());

  // create output/workspace tensors
  auto out = at::empty({batch_size, output_features.back()}, inputs[0].type());
  auto reserved_space = at::empty({static_cast<int64_t>(reserved_size)}, inputs[0].type());
  // allocate a fixed workspace for cublasLt for now; this gets at least 4 MB
  auto lt_workspace = at::empty({1 << 22}, inputs[0].type());

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].type(), "mlp_forward", [&] {
    std::vector<scalar_t*> w_ptr;
    std::vector<scalar_t*> b_ptr;
    for (int i = 0; i < num_layers; i++) {
      w_ptr.push_back(inputs[i + 1].data_ptr<scalar_t>());
      if (use_bias) {
        b_ptr.push_back(inputs[i + 1 + num_layers].data_ptr<scalar_t>());
      }
    }
    auto result = mlp_fp<scalar_t>(
        inputs[0].data_ptr<scalar_t>(),
        input_features,
        batch_size,
        w_ptr.data(),
        num_layers,
        output_features.data(),
        b_ptr.data(),
        out.data_ptr<scalar_t>(),
        reserved_space.data_ptr<scalar_t>(),
        use_bias,
        activation,
        (void*) (lt_workspace.data_ptr<scalar_t>()));
  });

  return {out, reserved_space};
}

// Backward pass. `fprop_outputs` is the {output, reserved_space} pair returned
// by mlp_forward; returns gradient tensors with the same shapes as `inputs`.
std::vector<at::Tensor> mlp_backward(
    int use_bias,
    int activation,
    at::Tensor grad_o,
    std::vector<at::Tensor> fprop_outputs,
    std::vector<at::Tensor> inputs) {
  auto num_layers = inputs.size() - 1;
  if (use_bias) {
    // inputs contains (input, weights, biases)
    num_layers /= 2;
  }
  auto batch_size = inputs[0].size(0);
  auto input_features = inputs[0].size(1);

  bool requires_grad = inputs[0].requires_grad();

  std::vector<int> output_features;
  for (int i = 0; i < num_layers; i++) {
    output_features.push_back(inputs[i + 1].size(0));
  }
  // create outputs: one gradient buffer per input tensor, same shape as the input
  std::vector<at::Tensor> outputs;
  for (int i = 0; i < inputs.size(); i++) {
    outputs.push_back(at::empty(inputs[i].sizes(), inputs[i].type()));
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].type(), "mlp_backward", [&] {
    std::vector<scalar_t*> w_ptr;
    for (int i = 0; i < num_layers; i++) {
      w_ptr.push_back(inputs[i + 1].data_ptr<scalar_t>());
    }
    std::vector<scalar_t*> outputs_ptr;
    for (int i = 0; i < inputs.size(); i++) {
      outputs_ptr.push_back(outputs[i].data_ptr<scalar_t>());
    }

    auto work_size =
        get_mlp_bp_workspace_in_bytes<scalar_t>(batch_size, num_layers, output_features.data());

    // auto work_space = at::empty({work_size*4}, at::kByte);
    auto work_space =
        at::empty({static_cast<int64_t>(work_size / sizeof(scalar_t))}, inputs[0].type());

    auto result = mlp_bp<scalar_t>(
        inputs[0].data_ptr<scalar_t>(),
        fprop_outputs[0].data_ptr<scalar_t>(),
        input_features,
        batch_size,
        w_ptr.data(),
        num_layers,
        output_features.data(),
        grad_o.contiguous().data_ptr<scalar_t>(),
        fprop_outputs[1].data_ptr<scalar_t>(),
        work_space.data_ptr<scalar_t>(),
        outputs_ptr[0],
        outputs_ptr.data() + 1,
        outputs_ptr.data() + 1 + num_layers,
        requires_grad,
        use_bias,
        activation);
  });

  return outputs;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &mlp_forward, "MLP forward");
  m.def("backward", &mlp_backward, "MLP backward");
}
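
// ---------------------------------------------------------------------------
// Illustrative sketch only, not part of the extension interface: one way the
// two entry points above could be exercised directly from C++ with libtorch.
// The 2-layer sizes, the CUDA device, and the assumption that activation
// code 1 selects ReLU are example choices for this sketch; in practice the
// functions are called from Python through the "forward"/"backward" bindings
// registered above.
void mlp_example_roundtrip() {
  const int64_t batch = 32, in_features = 64;
  const std::vector<int64_t> hidden = {128, 16};

  auto opts = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
  // inputs is (input, W_0, W_1, b_0, b_1); weight i is [out_i, in_i].
  // requires_grad on the input makes the backward pass also fill dX.
  std::vector<at::Tensor> inputs = {
      at::randn({batch, in_features}, opts).requires_grad_(true),
      at::randn({hidden[0], in_features}, opts),
      at::randn({hidden[1], hidden[0]}, opts),
      at::randn({hidden[0]}, opts),
      at::randn({hidden[1]}, opts)};

  // Forward returns {output, reserved_space}; both are needed for backward.
  auto fprop = mlp_forward(/*use_bias=*/1, /*activation=*/1, inputs);
  at::Tensor out = fprop[0];  // [batch, hidden.back()]

  // Backward consumes dL/d(output) plus the saved forward tensors and returns
  // gradient buffers in the same order and shapes as `inputs`.
  auto grads = mlp_backward(/*use_bias=*/1, /*activation=*/1,
                            at::ones_like(out), fprop, inputs);
  (void)grads;  // grads = {dX, dW_0, dW_1, db_0, db_1}
}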