[FRONTEND] rename current stream monkey patch (#495)

This commit is contained in:
Philippe Tillet
2022-04-13 11:45:55 -07:00
committed by GitHub
parent 76bfac9f15
commit 25f6689508

View File

@@ -22,7 +22,12 @@ import triton
import triton._C.libtriton.triton as _triton
from .tools.disasm import extract
current_stream = lambda device: torch.cuda.current_stream(device).cuda_stream
def current_cuda_stream(device_idx=0):
# Torch's torch.cuda.current_stream() is slow. We provide this
# function to give the user an opportunity to monkey-patch their
# own faster current stream lookup.
return torch.cuda.current_stream().cuda_stream
def mangle_ty(ty):
@@ -947,7 +952,7 @@ class Kernel:
cc = str(cc[0]) + '-' + str(cc[1])
cache_key = self.fn.cache_key + cc
# query current stream
stream = current_stream(device)
stream = current_cuda_stream(device)
return _triton.runtime.launch(wargs, self.fn.do_not_specialize, cache_key, self.fn.arg_names,
device, stream, self.fn.bin_cache, num_warps, num_stages, self.add_to_cache,
grid)