[PYTHON] Cleaning C++ bindings

This commit is contained in:
Philippe Tillet
2020-11-02 15:06:08 -05:00
committed by Philippe Tillet
parent 34f1d5e565
commit 02a6e81b88
4 changed files with 26 additions and 27 deletions

View File

@@ -18,6 +18,15 @@ std::shared_ptr<drv::device> host_device;
std::shared_ptr<drv::context> host_context;
std::shared_ptr<drv::stream> host_stream;
int64_t cdiv_sum(torch::Tensor& x, int64_t div){
TORCH_CHECK(x.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::CPU), "Argument of cdiv_sum must be a CPU tensor")
auto _x = x.accessor<int, 1>();
int64_t ret = 0;
for(size_t i = 0; i < x.size(0); i++)
ret += (_x[i] + div - 1) / div;
return ret;
}
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){
if(dev_id == -1){
if(!host_stream){
@@ -36,4 +45,6 @@ void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){
}
static auto registry = torch::RegisterOperators("triton::launch_kernel", &launch_kernel);
static auto registry = torch::RegisterOperators()
.op("triton::launch_kernel", &launch_kernel)
.op("triton::cdiv_sum", &cdiv_sum);