flatten_unflatten.cpp 584 B

123456789101112131415161718
  1. #include <torch/extension.h>
  2. #include <torch/csrc/utils/tensor_flatten.h>
  3. // https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_flatten.h
  4. at::Tensor flatten(std::vector<at::Tensor> tensors)
  5. {
  6. return torch::utils::flatten_dense_tensors(tensors);
  7. }
  8. std::vector<at::Tensor> unflatten(at::Tensor flat, std::vector<at::Tensor> tensors)
  9. {
  10. return torch::utils::unflatten_dense_tensors(flat, tensors);
  11. }
  12. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  13. m.def("flatten", &flatten, "Flatten dense tensors");
  14. m.def("unflatten", &unflatten, "Unflatten dense tensors");
  15. }