Fix issue breaking cudagraphs (#685)
@ngimel figured this one out. The errors we were seeing from cudagraphs capture were coming from `cuStreamGetCtx` which is not allowed while a stream is capturing. It appears the result of `cuStreamGetCtx()` isn't even used, so I believe it can just be removed.
This commit is contained in:
@@ -1026,13 +1026,10 @@ static inline void init_module(CUdevice device) {{
|
|||||||
|
|
||||||
|
|
||||||
void _{kernel_name}(int gridX, int gridY, int gridZ, CUstream stream, {arg_decls}) {{
|
void _{kernel_name}(int gridX, int gridY, int gridZ, CUstream stream, {arg_decls}) {{
|
||||||
CUcontext ctx;
|
|
||||||
CUdevice device;
|
|
||||||
CUDA_CHECK(cuStreamGetCtx(stream, &ctx));
|
|
||||||
CUDA_CHECK(cuCtxGetDevice(&device));
|
|
||||||
|
|
||||||
// TODO: machine may have heterogeneous devices
|
// TODO: machine may have heterogeneous devices
|
||||||
if(function == 0){{
|
if(function == 0){{
|
||||||
|
CUdevice device;
|
||||||
|
CUDA_CHECK(cuCtxGetDevice(&device));
|
||||||
init_module(device);
|
init_module(device);
|
||||||
}}
|
}}
|
||||||
void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
|
void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
|
||||||
|
Reference in New Issue
Block a user