diff --git a/python/src/launch.cc b/python/src/launch.cc index 5bb0bf669..1f6a8988f 100644 --- a/python/src/launch.cc +++ b/python/src/launch.cc @@ -7,6 +7,9 @@ #include "triton/tools/bench.hpp" #include "torch/script.h" #include "ATen/cuda/CUDAContext.h" +#include +#include + namespace rt = triton::runtime; namespace drv = triton::driver; @@ -69,6 +72,7 @@ void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args, (*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream, &*host_device); } else{ + C10_CUDA_CHECK(cudaSetDevice(dev_id)); triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false); triton::driver::cu_device device(torch_get_cuda_device(dev_id), false); (*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream, &device);