#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { |
| ops.def("fwd(" |
| "Tensor! q, " |
| "Tensor k, " |
| "Tensor v, " |
| "Tensor(out_!)? out_, " |
| "Tensor? alibi_slopes_, " |
| "float p_dropout, " |
| "float softmax_scale, " |
| "bool is_causal," |
| "int window_size_left, " |
| "int window_size_right, " |
| "float softcap, " |
| "bool return_softmax, " |
| "Generator? gen_) -> Tensor[]"); |
| ops.impl("fwd", torch::kCUDA, &mha_fwd); |
|
|
| ops.def("varlen_fwd(" |
| "Tensor! q, " |
| "Tensor k, " |
| "Tensor v, " |
| "Tensor? out_, " |
| "Tensor cu_seqlens_q, " |
| "Tensor cu_seqlens_k, " |
| "Tensor? seqused_k_, " |
| "Tensor? leftpad_k_, " |
| "Tensor? block_table_, " |
| "Tensor? alibi_slopes_, " |
| "int max_seqlen_q, " |
| "int max_seqlen_k, " |
| "float p_dropout, " |
| "float softmax_scale, " |
| "bool zero_tensors, " |
| "bool is_causal, " |
| "int window_size_left, " |
| "int window_size_right, " |
| "float softcap, " |
| "bool return_softmax, " |
| "Generator? gen_) -> Tensor[]"); |
| ops.impl("varlen_fwd", torch::kCUDA, &mha_varlen_fwd); |
|
|
| ops.def("bwd(" |
| "Tensor! dout, " |
| "Tensor! q, " |
| "Tensor! k, " |
| "Tensor! v, " |
| "Tensor! out, " |
| "Tensor! " |
| "softmax_lse, " |
| "Tensor? dq_, " |
| "Tensor? dk_, " |
| "Tensor? dv_, " |
| "Tensor? alibi_slopes_, " |
| "float p_dropout, " |
| "float softmax_scale, " |
| "bool is_causal, " |
| "int window_size_left, " |
| "int window_size_right, " |
| "float softcap, " |
| "bool deterministic, " |
| "Generator? gen_, " |
| "Tensor? rng_state) -> Tensor[]"); |
| ops.impl("bwd", torch::kCUDA, &mha_bwd); |
|
|
| ops.def("varlen_bwd(" |
| "Tensor! dout, " |
| "Tensor! q, " |
| "Tensor! k, " |
| "Tensor! v, " |
| "Tensor! out, " |
| "Tensor! softmax_lse, " |
| "Tensor? dq_, " |
| "Tensor? dk_, " |
| "Tensor? dv_, " |
| "Tensor cu_seqlens_q, " |
| "Tensor cu_seqlens_k, " |
| "Tensor? alibi_slopes_, " |
| "int max_seqlen_q, " |
| "int max_seqlen_k, " |
| "float p_dropout, float softmax_scale, " |
| "bool zero_tensors, " |
| "bool is_causal, " |
| "int window_size_left, " |
| "int window_size_right, " |
| "float softcap, " |
| "bool deterministic, " |
| "Generator? gen_, " |
| "Tensor? rng_state) -> Tensor[]"); |
| ops.impl("varlen_bwd", torch::kCUDA, &mha_varlen_bwd); |
|
|
| ops.def("fwd_kvcache(" |
| "Tensor! q, " |
| "Tensor! kcache, " |
| "Tensor! vcache, " |
| "Tensor? k_, " |
| "Tensor? v_, " |
| "Tensor? seqlens_k_, " |
| "Tensor? rotary_cos_, " |
| "Tensor? rotary_sin_, " |
| "Tensor? cache_batch_idx_, " |
| "Tensor? leftpad_k_, " |
| "Tensor? block_table_, " |
| "Tensor? alibi_slopes_, " |
| "Tensor? out_, " |
| "float softmax_scale, " |
| "bool is_causal, " |
| "int window_size_left, " |
| "int window_size_right, " |
| "float softcap, " |
| "bool is_rotary_interleaved, " |
| "int num_splits) -> Tensor[]"); |
| ops.impl("fwd_kvcache", torch::kCUDA, &mha_fwd_kvcache); |
| } |
|
|
| REGISTER_EXTENSION(TORCH_EXTENSION_NAME) |
|
|