[FRONTEND] Added the ability for users to customize the current stream query (#492)
@@ -22,6 +22,8 @@ import triton
 import triton._C.libtriton.triton as _triton
 from .tools.disasm import extract
 
+current_stream = lambda device: torch.cuda.current_stream(device).cuda_stream
+
 
 def mangle_ty(ty):
     if ty.is_ptr():
@@ -787,6 +789,7 @@ class OutOfResources(Exception):
 
+
 class Kernel:
 
     @staticmethod
     def _type_name(obj):
         type_names = {
@@ -915,28 +918,24 @@ class Kernel:
             raise TypeError(f"Function takes {len(self.fn.arg_names)} positional arguments but {len(wargs)} were given")
         # handle annotations
         for pos, _type in self.fn.annotations.items():
             assert _type == triton.language.constexpr, "only constexpr annotations are supported for now"
             wargs[pos] = _type(wargs[pos])
         # check that tensors are on GPU.
         for arg in wargs:
             if hasattr(arg, 'data_ptr'):
                 assert arg.is_cuda, "All tensors must be on GPU!"
-        # query device index and cuda stream
+        # set device (i.e., make sure torch has the context initialized)
         device = torch.cuda.current_device()
         torch.cuda.set_device(device)
+        # query compute capability
         cc = torch.cuda.get_device_capability(device)
         cc = str(cc[0]) + '-' + str(cc[1])
-        # # query stream
-        # # this is hacky but much faster than `torch.cuda.current_stream(device).cuda_stream`
-        # # https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L154
-        # # building a C wrapper to re-use the unpack function would add a build-time torch dependency
-        # # and require different wheels for different torch versions -- undesirable!
-        # bits = torch._C._cuda_getCurrentStream(device)
-        # mask = 1 << 47
-        # stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask
-        stream = torch.cuda.current_stream(device).cuda_stream
-        # make key for cache
-        return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names, device, stream,
-                                      self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid)
+        cache_key = self.fn.cache_key + cc
+        # query current stream
+        stream = current_stream(device)
+        return _triton.runtime.launch(wargs, self.fn.do_not_specialize, cache_key, self.fn.arg_names,
+                                      device, stream, self.fn.bin_cache, num_warps, num_stages, self.add_to_cache,
+                                      grid)
 
 
 class Launcher:
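With this change the stream query is a module-level hook (`current_stream` in triton/code_gen.py) rather than a `torch.cuda.current_stream(device).cuda_stream` call hard-coded inside `Kernel.__call__`, so users can substitute their own query. A minimal sketch of an override that caches the raw stream handle per device; the caching helper and its invalidation policy are illustrative assumptions, not part of this commit:

import torch
import triton
import triton.code_gen  # module that defines the current_stream hook

_stream_cache = {}

def cached_current_stream(device):
    # Skip re-querying torch on every kernel launch. This is only safe while
    # all launches for `device` stay on the same CUDA stream; clear the cache
    # (or restore the default hook) whenever the active stream changes.
    if device not in _stream_cache:
        _stream_cache[device] = torch.cuda.current_stream(device).cuda_stream
    return _stream_cache[device]

# Rebind the hook. Kernel.__call__ looks up `current_stream` in the module
# globals at launch time, so the override takes effect immediately.
triton.code_gen.current_stream = cached_current_stream

Restoring the stock behavior is just re-assigning the default lambda from the diff above.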