[FRONTEND] rename current stream monkey patch (#495)
@@ -22,7 +22,12 @@ import triton
 import triton._C.libtriton.triton as _triton
 from .tools.disasm import extract
 
-current_stream = lambda device: torch.cuda.current_stream(device).cuda_stream
+
+def current_cuda_stream(device_idx=0):
+    # Torch's torch.cuda.current_stream() is slow. We provide this
+    # function to give the user an opportunity to monkey-patch their
+    # own faster current stream lookup.
+    return torch.cuda.current_stream().cuda_stream
 
 
 def mangle_ty(ty):
@@ -947,7 +952,7 @@ class Kernel:
         cc = str(cc[0]) + '-' + str(cc[1])
         cache_key = self.fn.cache_key + cc
         # query current stream
-        stream = current_stream(device)
+        stream = current_cuda_stream(device)
         return _triton.runtime.launch(wargs, self.fn.do_not_specialize, cache_key, self.fn.arg_names,
                                       device, stream, self.fn.bin_cache, num_warps, num_stages, self.add_to_cache,
                                       grid)
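The new comment spells out the rationale: torch.cuda.current_stream() is slow, so the lookup is kept as a plain module-level function that users can monkey-patch with a faster one. As a minimal sketch of such a patch (not part of this commit; it assumes the helper is reachable as triton.code_gen.current_cuda_stream and that the application never switches CUDA streams), a user could cache the raw stream handle per device:

import torch
import triton

# Hypothetical user-side patch: cache the raw CUDA stream handle per device,
# assuming the current stream never changes for the lifetime of the process.
_cached_streams = {}

def cached_current_cuda_stream(device_idx=0):
    if device_idx not in _cached_streams:
        # Pay the cost of the slow torch lookup only once per device.
        _cached_streams[device_idx] = torch.cuda.current_stream(device_idx).cuda_stream
    return _cached_streams[device_idx]

# Assumed attribute path: the hunks above (mangle_ty, class Kernel) suggest the
# helper lives in triton/code_gen.py, hence triton.code_gen here.
triton.code_gen.current_cuda_stream = cached_current_cuda_stream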