#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"

| TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { |
| ops.def("ms_deform_attn_forward(Tensor value, Tensor spatial_shapes," |
| " Tensor level_start_index, Tensor sampling_loc," |
| " Tensor attn_weight, int im2col_step) -> Tensor"); |
| ops.impl("ms_deform_attn_forward", torch::kCUDA, &ms_deform_attn_cuda_forward); |
|
|
| ops.def("ms_deform_attn_backward(Tensor value, Tensor spatial_shapes," |
| " Tensor level_start_index, Tensor sampling_loc," |
| " Tensor attn_weight, Tensor grad_output," |
| " int im2col_step) -> Tensor[]"); |
| ops.impl("ms_deform_attn_backward", torch::kCUDA, &ms_deform_attn_cuda_backward); |
| } |
|
|
| REGISTER_EXTENSION(TORCH_EXTENSION_NAME) |
|
|