[FRONTEND] faster jit function launch (#523)

With fast (200 ns) get_stream function soon to be available from pytorch this shaves off approx 25-30 us from function launch, but even without that function due to caching device properties we are saving ~15-20us.
This commit is contained in:
Natalia Gimelshein
2022-05-24 12:08:49 -07:00
committed by GitHub
parent d5eaa8dfa0
commit 96bff90471

View File

@@ -23,12 +23,17 @@ import triton
import triton._C.libtriton.triton as _triton
from .tools.disasm import extract
try:
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
except ImportError:
get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream
def current_cuda_stream(device_idx=0):
# Torch's torch.cuda.current_stream() is slow. We provide this
# function to give the user an opportunity to monkey-patch their
# own faster current stream lookup.
return torch.cuda.current_stream().cuda_stream
return get_cuda_stream(device_idx)
def mangle_ty(ty):
@@ -910,6 +915,7 @@ class Kernel:
def __init__(self, fn):
self.fn = fn
self.cache_key = {}
def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages):
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
@@ -951,12 +957,11 @@ class Kernel:
# assert arg.is_cuda, "All tensors must be on GPU!"
# set device (i.e., make sure torch has the context initialized)
device = torch.cuda.current_device()
torch.cuda.set_device(device)
# query compute capability
cc = torch.cuda.get_device_capability(device)
cc = str(cc[0]) + '-' + str(cc[1])
cache_key = self.fn.cache_key + cc
# query current stream
if device not in self.cache_key:
cc = torch.cuda.get_device_capability(device)
cc = str(cc[0]) + '-' + str(cc[1])
self.cache_key[device] = self.fn.cache_key + cc
cache_key = self.cache_key[device]
stream = current_cuda_stream(device)
return _triton.runtime.launch(wargs, self.fn.do_not_specialize, cache_key, self.fn.arg_names,
device, stream, self.fn.bin_cache, num_warps, num_stages, self.add_to_cache,