Fix issue breaking cudagraphs (#685)

@ngimel figured this one out. 

The errors we were seeing from cudagraphs capture were coming from
`cuStreamGetCtx` which is not allowed while a stream is capturing.

It appears the result of `cuStreamGetCtx()` isn't even used, so I
believe it can just be removed.
This commit is contained in:
Jason Ansel
2022-09-21 10:20:48 -07:00
committed by GitHub
parent e318185eb4
commit 6abe813d1c

View File

@@ -1026,13 +1026,10 @@ static inline void init_module(CUdevice device) {{
void _{kernel_name}(int gridX, int gridY, int gridZ, CUstream stream, {arg_decls}) {{
CUcontext ctx;
CUdevice device;
CUDA_CHECK(cuStreamGetCtx(stream, &ctx));
CUDA_CHECK(cuCtxGetDevice(&device));
// TODO: machine may have heterogeneous devices
if(function == 0){{
CUdevice device;
CUDA_CHECK(cuCtxGetDevice(&device));
init_module(device);
}}
void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};