From 6abe813d1c5690d79cbbac6ca9375627c19dd3ee Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Wed, 21 Sep 2022 10:20:48 -0700 Subject: [PATCH] Fix issue breaking cudagraphs (#685) @ngimel figured this one out. The errors we were seeing from cudagraphs capture were coming from `cuStreamGetCtx` which is not allowed while a stream is capturing. It appears the result of `cuStreamGetCtx()` isn't even used, so I believe it can just be removed. --- python/triton/compiler.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 67c32cef4..535f323da 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -1026,13 +1026,10 @@ static inline void init_module(CUdevice device) {{ void _{kernel_name}(int gridX, int gridY, int gridZ, CUstream stream, {arg_decls}) {{ - CUcontext ctx; - CUdevice device; - CUDA_CHECK(cuStreamGetCtx(stream, &ctx)); - CUDA_CHECK(cuCtxGetDevice(&device)); - // TODO: machine may have heterogeneous devices if(function == 0){{ + CUdevice device; + CUDA_CHECK(cuCtxGetDevice(&device)); init_module(device); }} void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};