[PYTHON] Added option to show PTX source code in Python
This commit is contained in:
committed by
Philippe Tillet
parent
cf80ccc798
commit
8f3ee53f24
@@ -27,6 +27,10 @@ int64_t cdiv_sum(torch::Tensor x, int64_t div){
|
||||
return ret;
|
||||
}
|
||||
|
||||
CUstream torch_get_cuda_stream(int64_t dev_id) {
|
||||
return (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
|
||||
}
|
||||
|
||||
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){
|
||||
if(dev_id == -1){
|
||||
if(!host_stream){
|
||||
@@ -37,8 +41,7 @@ void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){
|
||||
(*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream);
|
||||
}
|
||||
else{
|
||||
CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
|
||||
triton::driver::cu_stream stream(custream, false);
|
||||
triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
|
||||
triton::driver::context* ctx = stream.context();
|
||||
(*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream);
|
||||
}
|
||||
|
Reference in New Issue
Block a user