From 25f66895083982aa7c9a2ccf6600ebc1d9199d2b Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Wed, 13 Apr 2022 11:45:55 -0700
Subject: [PATCH] [FRONTEND] rename current stream monkey patch (#495)

---
 python/triton/code_gen.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index 311fb85f0..7553bee55 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -22,7 +22,12 @@ import triton
 import triton._C.libtriton.triton as _triton
 from .tools.disasm import extract
 
-current_stream = lambda device: torch.cuda.current_stream(device).cuda_stream
+
+def current_cuda_stream(device_idx=0):
+    # Return the `cuda_stream` attribute of the stream that is current on
+    # device `device_idx`. torch.cuda.current_stream() is slow, so users
+    # may monkey-patch this function with their own faster stream lookup.
+    return torch.cuda.current_stream(device_idx).cuda_stream
 
 
 def mangle_ty(ty):
@@ -947,7 +952,7 @@ class Kernel:
         cc = str(cc[0]) + '-' + str(cc[1])
         cache_key = self.fn.cache_key + cc
         # query current stream
-        stream = current_stream(device)
+        stream = current_cuda_stream(device)
         return _triton.runtime.launch(wargs, self.fn.do_not_specialize, cache_key, self.fn.arg_names, device, stream, self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid)
 
 