diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 67c32cef4..535f323da 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -1026,13 +1026,10 @@ static inline void init_module(CUdevice device) {{ void _{kernel_name}(int gridX, int gridY, int gridZ, CUstream stream, {arg_decls}) {{ - CUcontext ctx; - CUdevice device; - CUDA_CHECK(cuStreamGetCtx(stream, &ctx)); - CUDA_CHECK(cuCtxGetDevice(&device)); - // TODO: machine may have heterogeneous devices if(function == 0){{ + CUdevice device; + CUDA_CHECK(cuCtxGetDevice(&device)); init_module(device); }} void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};