#include <torch/extension.h>
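
// Forward declarations for the fused multi-tensor CUDA kernels bound at the
// bottom of this file. Every multi_tensor_* entry point shares the same
// leading arguments:
//   chunk_size   - number of elements handed to a CUDA block per chunk
//   noop_flag    - 1-element int tensor; kernels skip their work once it is
//                  set, which is how overflow (inf/nan) aborts a fused step
//   tensor_lists - a list of equally sized tensor lists (e.g. grads, params,
//                  moments) traversed in lockstep
// The definitions are expected to live in accompanying CUDA sources
// (e.g. multi_tensor_scale_kernel.cu; that file name is an assumption).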
void multi_tensor_scale_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    float scale);
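
// Illustrative Python usage (an assumption, not taken from this file): if
// the extension is built under the name amp_C, the scale kernel is
// typically invoked as
//   import torch, amp_C
//   overflow_buf = torch.zeros(1, dtype=torch.int, device="cuda")
//   amp_C.multi_tensor_scale(
//       2048 * 32,                # chunk_size; a common choice, not required
//       overflow_buf,             # noop_flag, doubles as the overflow report
//       [src_grads, dst_grads],   # [source list, destination list]
//       1.0 / loss_scale)         # elementwise multiplier
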
void multi_tensor_sgd_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    float wd,
    float momentum,
    float dampening,
    float lr,
    bool nesterov,
    bool first_run,
    bool wd_after_momentum,
    float scale);
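
// Reading the SGD-specific flags from their names (the kernel source is
// authoritative): first_run indicates the momentum buffers are not yet
// initialized, and wd_after_momentum chooses whether weight decay is applied
// after, rather than before, the momentum update.
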
void multi_tensor_axpby_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    float a,
    float b,
    int arg_to_check);
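
// Per the binding docstring below, axpby computes out = a*x + b*y across the
// tensor lists; arg_to_check presumably selects which input list (or both)
// is screened for inf/nan when deciding whether to set noop_flag.
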
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::optional<bool> per_tensor_python);

std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_mp_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::optional<bool> per_tensor_python);

std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_scale_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    float scale,
    at::optional<bool> per_tensor_python);

std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::Tensor inv_scale,
    at::optional<bool> per_tensor_python);
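
// The l2norm family returns a pair of tensors: the first carries the overall
// L2 norm across all inputs, and the second carries one norm per tensor when
// per_tensor_python is true. The _scale variant also scales the tensors, and
// the unscale variant divides by inv_scale for the norm computation only
// (both per the docstrings below).
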
void multi_tensor_lamb_stage1_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::Tensor per_tensor_decay,
    const int step,
    const float beta1,
    const float beta2,
    const float epsilon,
    at::Tensor global_grad_norm,
    const float max_global_grad_norm);

void multi_tensor_lamb_stage2_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::Tensor per_tensor_param_norm,
    at::Tensor per_tensor_update_norm,
    const float lr,
    const float weight_decay,
    at::optional<bool> use_nvlamb_python);
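
// The stage1/stage2 pair splits LAMB across two kernels: stage 1 forms the
// Adam-style update from the moment estimates (with global_grad_norm and
// max_global_grad_norm available for gradient clipping), and stage 2 applies
// it after rescaling by the trust ratio implied by per_tensor_param_norm and
// per_tensor_update_norm. This reading follows the binding docstrings below
// and the parameter names.
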
void multi_tensor_adam_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    const float lr,
    const float beta1,
    const float beta2,
    const float epsilon,
    const int step,
    const int mode,
    const int bias_correction,
    const float weight_decay);
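
// For the Adam variants, mode conventionally selects between L2-regularized
// Adam and decoupled weight decay (AdamW-style), and bias_correction toggles
// the 1/(1 - beta^step) correction terms; the exact integer encodings are
// defined in the kernel source, not here.
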
void multi_tensor_adam_capturable_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::Tensor lr,
    const float beta1,
    const float beta2,
    const float epsilon,
    at::Tensor step,
    const int mode,
    const int bias_correction,
    const float weight_decay,
    at::Tensor inv_scale);

void multi_tensor_adam_capturable_master_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::Tensor lr,
    const float beta1,
    const float beta2,
    const float epsilon,
    at::Tensor step,
    const int mode,
    const int bias_correction,
    const float weight_decay,
    at::Tensor inv_scale);
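
// Unlike multi_tensor_adam_cuda, the capturable variants take lr, step, and
// inv_scale as device tensors rather than host scalars, which is what allows
// a captured CUDA graph to be replayed while the learning-rate schedule and
// gradient scale keep changing; the _master variant additionally maintains
// FP32 master weights (both points per the docstrings below).
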
void multi_tensor_adagrad_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    const float lr,
    const float epsilon,
    const int mode,
    const float weight_decay);

void multi_tensor_novograd_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::Tensor grad_norms,
    const float lr,
    const float beta1,
    const float beta2,
    const float epsilon,
    const int step,
    const int bias_correction,
    const float weight_decay,
    const int grad_averaging,
    const int mode,
    const int norm_type);

void multi_tensor_lamb_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    const float lr,
    const float beta1,
    const float beta2,
    const float epsilon,
    const int step,
    const int bias_correction,
    const float weight_decay,
    const int grad_averaging,
    const int mode,
    at::Tensor global_grad_norm,
    const float max_grad_norm,
    at::optional<bool> use_nvlamb_python);

void multi_tensor_lamb_mp_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::Tensor lr,
    const float beta1,
    const float beta2,
    const float epsilon,
    at::Tensor step,
    const int bias_correction,
    const float weight_decay,
    const int grad_averaging,
    const int mode,
    at::Tensor global_grad_norm,
    at::Tensor max_grad_norm,
    at::optional<bool> use_nvlamb_python,
    at::Tensor found_inf,
    at::Tensor inv_scale);
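
// multi_tensor_lamb_mp_cuda mirrors multi_tensor_lamb_cuda but takes lr,
// step, and max_grad_norm as tensors and adds found_inf / inv_scale,
// suggesting it is meant to slot into a GradScaler-style mixed-precision
// loop; that integration detail is an inference from the signature.
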
at::Tensor update_scale_hysteresis_cuda(
    at::Tensor current_scale,
    at::Tensor growth_tracker,
    at::Tensor hysteresis_tracker,
    at::Tensor found_inf,
    const double growth_factor,
    const double backoff_factor,
    const int64_t growth_interval,
    const int hysteresis);
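
// A sketch of the usual hysteresis scheme, inferred from the parameter names
// (the CUDA source is authoritative): an overflow decrements the hysteresis
// counter and only backs the scale off once the counter is exhausted, while
// growth_interval consecutive clean steps grow the scale and reset both
// trackers. Roughly:
//   if found_inf:
//       growth_tracker = 0
//       hysteresis_tracker -= 1
//       if hysteresis_tracker <= 0:
//           current_scale *= backoff_factor
//   else:
//       growth_tracker += 1
//       if growth_tracker == growth_interval:
//           current_scale *= growth_factor
//           growth_tracker = 0
//           hysteresis_tracker = hysteresis
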
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
        "Fused overflow check + scale for a list of contiguous tensors");
  m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
        "Fused SGD optimizer for list of contiguous tensors");
  m.def("multi_tensor_axpby", &multi_tensor_axpby_cuda,
        "out = a*x + b*y for a list of contiguous tensors");
  m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
        "Computes L2 norm for a list of contiguous tensors");
  m.def("multi_tensor_l2norm_mp", &multi_tensor_l2norm_mp_cuda,
        "Computes L2 norm for a list of contiguous tensors");
  m.def("multi_tensor_l2norm_scale", &multi_tensor_l2norm_scale_cuda,
        "Computes L2 norm for a list of contiguous tensors and does scaling");
  m.def("multi_tensor_unscale_l2norm", &multi_tensor_unscale_l2norm_cuda,
        "Computes L2 norm for a list of contiguous tensors after unscaling (unscaling is only performed for L2 norm computation, and tensors are not updated)");
  m.def("multi_tensor_lamb_stage1_cuda", &multi_tensor_lamb_stage1_cuda,
        "Computes update part of LAMB optimizer");
  m.def("multi_tensor_lamb_stage2_cuda", &multi_tensor_lamb_stage2_cuda,
        "Completes application of gradient to parameters for LAMB optimizer");
  m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
        "Compute and apply gradient update to parameters for Adam optimizer");
  m.def("multi_tensor_adam_capturable", &multi_tensor_adam_capturable_cuda,
        "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph support and LR scheduling");
  m.def("multi_tensor_adam_capturable_master", &multi_tensor_adam_capturable_master_cuda,
        "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph support, LR scheduling and FP32 master weights");
- m.def("multi_tensor_adagrad", &multi_tensor_adagrad_cuda,
- "Compute and apply gradient update to parameters for Adam optimizer");
- m.def("multi_tensor_novograd", &multi_tensor_novograd_cuda,
- "Compute and apply gradient update to parameters for Adam optimizer");
- m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda,
- "Computes and apply update for LAMB optimizer");
- m.def("multi_tensor_lamb_mp", &multi_tensor_lamb_mp_cuda,
- "Computes and apply update for LAMB optimizer");
- m.def("update_scale_hysteresis", &update_scale_hysteresis_cuda,
- "Updates scale while accounting for hysteresis");
- }
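
// Build sketch (assumptions: the module name amp_C and the source file names
// are illustrative; TORCH_EXTENSION_NAME is supplied by the build):
//   # setup.py
//   from setuptools import setup
//   from torch.utils.cpp_extension import BuildExtension, CUDAExtension
//   setup(
//       name="amp_C",
//       ext_modules=[CUDAExtension(
//           "amp_C",
//           ["amp_C_frontend.cpp", "multi_tensor_scale_kernel.cu"])],
//       cmdclass={"build_ext": BuildExtension})
// After `pip install .`, `import amp_C` exposes every function registered
// above under the names given in the m.def(...) calls.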