[PYTHON] Context switching logic moved to PyTorch
This commit is contained in:
@@ -7,6 +7,9 @@
|
||||
#include "triton/tools/bench.hpp"
|
||||
#include "torch/script.h"
|
||||
#include "ATen/cuda/CUDAContext.h"
|
||||
#include <c10/cuda/CUDAException.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
|
||||
namespace rt = triton::runtime;
|
||||
namespace drv = triton::driver;
|
||||
@@ -69,6 +72,7 @@ void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
|
||||
(*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream, &*host_device);
|
||||
}
|
||||
else{
|
||||
C10_CUDA_CHECK(cudaSetDevice(dev_id));
|
||||
triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
|
||||
triton::driver::cu_device device(torch_get_cuda_device(dev_id), false);
|
||||
(*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream, &device);
|
||||
|
Reference in New Issue
Block a user