[GENERAL] Various bugfixes

This commit is contained in:
Philippe Tillet
2020-11-11 14:44:56 -05:00
committed by Philippe Tillet
parent 50587bbf4b
commit 8f8d36c7a4
11 changed files with 103 additions and 59 deletions

View File

@@ -51,11 +51,6 @@ std::string get_fn_ptx(const map_key_t& key, const rt::function::options_t& opt)
return id_fn_map[key]->ptx(&stream, opt);
}
void register_cst(const map_key_t& key, const std::string& name, pybind11::buffer& data) {
pybind11::buffer_info info = data.request();
id_fn_map[key]->set_cst(name, info.ptr, info.size*info.itemsize);
}
void cleanup() {
id_grid_map.clear();
id_fn_map.clear();
@@ -134,7 +129,6 @@ PYBIND11_MODULE(libtriton, m) {
m.def("register_grid", &register_grid);
m.def("delete_grid", &delete_grid);
m.def("register_fn", &register_fn);
m.def("register_cst", &register_cst);
m.def("delete_fn", &delete_fn);
m.def("make_op_id", &make_op_id);
m.def("cleanup", &cleanup);

View File

@@ -31,19 +31,25 @@ CUstream torch_get_cuda_stream(int64_t dev_id) {
return (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
}
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
const std::vector<std::string>& constant_names, const std::vector<torch::Tensor>& constant_vals){
rt::function* fn = id_fn_map.at({op_id, dev_id}).get();
for(size_t n = 0; n < constant_names.size(); n++){
const torch::Tensor& x = constant_vals[n];
fn->set_cst(constant_names[n], (char*)x.data_ptr(), x.numel()*x.element_size());
}
if(dev_id == -1){
if(!host_stream){
host_device.reset(new drv::host_device());
host_context.reset(drv::context::create(&*host_device));
host_stream.reset(drv::stream::create(&*host_context));
}
(*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream);
(*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream);
}
else{
triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
triton::driver::context* ctx = stream.context();
(*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream);
(*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream);
}
}