From 96bff90471343ed01dd94effe91ecdaaa7a3f36a Mon Sep 17 00:00:00 2001
From: Natalia Gimelshein
Date: Tue, 24 May 2022 12:08:49 -0700
Subject: [PATCH] [FRONTEND] faster jit function launch (#523)

With the fast (~200 ns) get_stream function that will soon be available
from PyTorch, this shaves approximately 25-30 us off each function launch;
even without that function, caching device properties alone saves ~15-20 us.
---
 python/triton/code_gen.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index 82ace0105..619e3109e 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -23,12 +23,17 @@ import triton
 import triton._C.libtriton.triton as _triton
 from .tools.disasm import extract
 
+try:
+    from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
+except ImportError:
+    get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream
+
 
 def current_cuda_stream(device_idx=0):
     # Torch's torch.cuda.current_stream() is slow. We provide this
     # function to give the user an opportunity to monkey-patch their
     # own faster current stream lookup.
-    return torch.cuda.current_stream().cuda_stream
+    return get_cuda_stream(device_idx)
 
 
 def mangle_ty(ty):
@@ -910,6 +915,7 @@ class Kernel:
 
     def __init__(self, fn):
         self.fn = fn
+        self.cache_key = {}
 
     def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages):
         tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
@@ -951,12 +957,11 @@ class Kernel:
         # assert arg.is_cuda, "All tensors must be on GPU!"
         # set device (i.e., make sure torch has the context initialized)
         device = torch.cuda.current_device()
-        torch.cuda.set_device(device)
-        # query compute capability
-        cc = torch.cuda.get_device_capability(device)
-        cc = str(cc[0]) + '-' + str(cc[1])
-        cache_key = self.fn.cache_key + cc
-        # query current stream
+        if device not in self.cache_key:
+            cc = torch.cuda.get_device_capability(device)
+            cc = str(cc[0]) + '-' + str(cc[1])
+            self.cache_key[device] = self.fn.cache_key + cc
+        cache_key = self.cache_key[device]
         stream = current_cuda_stream(device)
         return _triton.runtime.launch(wargs, self.fn.do_not_specialize, cache_key, self.fn.arg_names,
                                       device, stream, self.fn.bin_cache, num_warps, num_stages, self.add_to_cache,
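
Note for reviewers: to sanity-check the claimed savings, a rough
micro-benchmark along these lines (not part of the patch; absolute numbers
will vary by GPU, driver, and torch version) compares the two stream
lookups that the patch switches between:

# Hypothetical micro-benchmark, not part of the patch: times the slow
# torch.cuda.current_stream() path against the raw-stream lookup that
# the patch prefers when torch exposes it.
import timeit

import torch

assert torch.cuda.is_available()
torch.cuda.init()  # make sure the CUDA context exists before timing

def slow_lookup(dev_idx=0):
    # Builds a torch.cuda.Stream wrapper object on every call.
    return torch.cuda.current_stream(dev_idx).cuda_stream

try:
    from torch._C import _cuda_getCurrentRawStream as fast_lookup
except ImportError:
    fast_lookup = slow_lookup  # older torch: no raw-stream entry point

n = 10_000
for name, fn in [("current_stream", slow_lookup), ("raw stream", fast_lookup)]:
    per_call_us = timeit.timeit(lambda: fn(0), number=n) / n * 1e6
    print(f"{name}: {per_call_us:.2f} us/call")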
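
The comment kept in current_cuda_stream still advertises monkey-patching.
A sketch of what that looks like from user code (hypothetical override;
the name my_stream_lookup is made up, and it assumes the raw stream handle
is an integer, 0 being the default stream):

# Hypothetical user-side override of the stream lookup. Kernel.__call__
# resolves current_cuda_stream from the module globals at call time, so
# rebinding the module attribute takes effect on the next launch.
import triton.code_gen

def my_stream_lookup(device_idx=0):
    # Return a raw CUDA stream handle (an int); 0 is the default stream.
    return 0

triton.code_gen.current_cuda_stream = my_stream_lookup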