[GENERAL] Various bugfixes
This commit is contained in:
committed by
Philippe Tillet
parent
50587bbf4b
commit
8f8d36c7a4
@@ -31,19 +31,25 @@ CUstream torch_get_cuda_stream(int64_t dev_id) {
|
||||
return (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
|
||||
}
|
||||
|
||||
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){
|
||||
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
|
||||
const std::vector<std::string>& constant_names, const std::vector<torch::Tensor>& constant_vals){
|
||||
rt::function* fn = id_fn_map.at({op_id, dev_id}).get();
|
||||
for(size_t n = 0; n < constant_names.size(); n++){
|
||||
const torch::Tensor& x = constant_vals[n];
|
||||
fn->set_cst(constant_names[n], (char*)x.data_ptr(), x.numel()*x.element_size());
|
||||
}
|
||||
if(dev_id == -1){
|
||||
if(!host_stream){
|
||||
host_device.reset(new drv::host_device());
|
||||
host_context.reset(drv::context::create(&*host_device));
|
||||
host_stream.reset(drv::stream::create(&*host_context));
|
||||
}
|
||||
(*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream);
|
||||
(*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream);
|
||||
}
|
||||
else{
|
||||
triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
|
||||
triton::driver::context* ctx = stream.context();
|
||||
(*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream);
|
||||
(*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream);
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user