[FRONTEND] faster jit function launch (#523)
With the fast (~200 ns) get_stream function soon to be available in PyTorch, this shaves approximately 25-30 us off function launch; even without that function, caching device properties saves ~15-20 us.
commit 96bff90471 (parent d5eaa8dfa0), committed via GitHub
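A rough way to see where the savings come from is to time the two stream lookups against each other. The following microbenchmark is a sketch only, not part of the commit: it assumes a CUDA-enabled PyTorch build, and _cuda_getCurrentRawStream may not exist in older PyTorch releases (which is exactly why the diff below keeps a fallback).

import timeit

import torch

torch.cuda.init()  # make sure the CUDA context exists before timing anything

def slow_lookup():
    # Goes through torch.cuda.current_stream(), which constructs a Stream object.
    return torch.cuda.current_stream(0).cuda_stream

try:
    from torch._C import _cuda_getCurrentRawStream

    def fast_lookup():
        # Returns the raw stream handle directly from the C++ side.
        return _cuda_getCurrentRawStream(0)
except ImportError:
    fast_lookup = slow_lookup  # older PyTorch: no raw-stream entry point

n = 10_000
print("torch.cuda.current_stream :", timeit.timeit(slow_lookup, number=n) / n, "s/call")
print("_cuda_getCurrentRawStream :", timeit.timeit(fast_lookup, number=n) / n, "s/call")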
@@ -23,12 +23,17 @@ import triton
 import triton._C.libtriton.triton as _triton
 from .tools.disasm import extract

+try:
+    from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
+except ImportError:
+    get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream
+

 def current_cuda_stream(device_idx=0):
     # Torch's torch.cuda.current_stream() is slow. We provide this
     # function to give the user an opportunity to monkey-patch their
     # own faster current stream lookup.
-    return torch.cuda.current_stream().cuda_stream
+    return get_cuda_stream(device_idx)


 def mangle_ty(ty):
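Because the launch path now resolves the stream through a module-level current_cuda_stream function, a caller can substitute their own lookup. A minimal sketch of such monkey-patching, assuming the function is exposed from triton's code_gen module as in this diff (my_current_cuda_stream is a hypothetical user-supplied replacement):

import torch
import triton

def my_current_cuda_stream(device_idx=0):
    # Stand-in for a faster, application-specific stream lookup,
    # e.g. a handle cached by the surrounding framework.
    return torch.cuda.current_stream(device_idx).cuda_stream

# Kernel launches resolve current_cuda_stream as a global of code_gen,
# so replacing the module attribute redirects every subsequent lookup.
triton.code_gen.current_cuda_stream = my_current_cuda_stream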
@@ -910,6 +915,7 @@ class Kernel:

     def __init__(self, fn):
         self.fn = fn
+        self.cache_key = {}

     def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages):
         tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
@@ -951,12 +957,11 @@ class Kernel:
         # assert arg.is_cuda, "All tensors must be on GPU!"
         # set device (i.e., make sure torch has the context initialized)
         device = torch.cuda.current_device()
-        torch.cuda.set_device(device)
-        # query compute capability
-        cc = torch.cuda.get_device_capability(device)
-        cc = str(cc[0]) + '-' + str(cc[1])
-        cache_key = self.fn.cache_key + cc
-        # query current stream
+        if device not in self.cache_key:
+            cc = torch.cuda.get_device_capability(device)
+            cc = str(cc[0]) + '-' + str(cc[1])
+            self.cache_key[device] = self.fn.cache_key + cc
+        cache_key = self.cache_key[device]
         stream = current_cuda_stream(device)
         return _triton.runtime.launch(wargs, self.fn.do_not_specialize, cache_key, self.fn.arg_names,
                                       device, stream, self.fn.bin_cache, num_warps, num_stages, self.add_to_cache,
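The last hunk carries the caching half of the optimization: torch.cuda.set_device is dropped from the hot path, and torch.cuda.get_device_capability is queried only the first time a kernel is launched on a given device, with the composed cache key memoized per device index in self.cache_key. A standalone sketch of that pattern, using hypothetical names rather than triton's actual classes:

import torch

class LaunchKeyCache:
    # Illustrative only: memoizes the per-device part of a kernel's cache key.
    def __init__(self, base_key):
        self.base_key = base_key   # e.g. a hash of the kernel source
        self.per_device = {}       # device index -> full cache key

    def key_for(self, device):
        if device not in self.per_device:
            # The driver query costs microseconds; do it once per device, not per launch.
            major, minor = torch.cuda.get_device_capability(device)
            self.per_device[device] = self.base_key + str(major) + '-' + str(minor)
        return self.per_device[device]

# usage sketch: cache = LaunchKeyCache(fn_hash); key = cache.key_for(torch.cuda.current_device())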