[FRONTEND] faster jit function launch (#523)
With fast (200 ns) get_stream function soon to be available from pytorch this shaves off approx 25-30 us from function launch, but even without that function due to caching device properties we are saving ~15-20us.
This commit is contained in:
committed by
GitHub
parent
d5eaa8dfa0
commit
96bff90471
@@ -23,12 +23,17 @@ import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from .tools.disasm import extract
|
||||
|
||||
try:
|
||||
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
|
||||
except ImportError:
|
||||
get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream
|
||||
|
||||
|
||||
def current_cuda_stream(device_idx=0):
|
||||
# Torch's torch.cuda.current_stream() is slow. We provide this
|
||||
# function to give the user an opportunity to monkey-patch their
|
||||
# own faster current stream lookup.
|
||||
return torch.cuda.current_stream().cuda_stream
|
||||
return get_cuda_stream(device_idx)
|
||||
|
||||
|
||||
def mangle_ty(ty):
|
||||
@@ -910,6 +915,7 @@ class Kernel:
|
||||
|
||||
def __init__(self, fn):
|
||||
self.fn = fn
|
||||
self.cache_key = {}
|
||||
|
||||
def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages):
|
||||
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
|
||||
@@ -951,12 +957,11 @@ class Kernel:
|
||||
# assert arg.is_cuda, "All tensors must be on GPU!"
|
||||
# set device (i.e., make sure torch has the context initialized)
|
||||
device = torch.cuda.current_device()
|
||||
torch.cuda.set_device(device)
|
||||
# query compute capability
|
||||
if device not in self.cache_key:
|
||||
cc = torch.cuda.get_device_capability(device)
|
||||
cc = str(cc[0]) + '-' + str(cc[1])
|
||||
cache_key = self.fn.cache_key + cc
|
||||
# query current stream
|
||||
self.cache_key[device] = self.fn.cache_key + cc
|
||||
cache_key = self.cache_key[device]
|
||||
stream = current_cuda_stream(device)
|
||||
return _triton.runtime.launch(wargs, self.fn.do_not_specialize, cache_key, self.fn.arg_names,
|
||||
device, stream, self.fn.bin_cache, num_warps, num_stages, self.add_to_cache,
|
||||
|
Reference in New Issue
Block a user