diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index b44fc244e..51e3577ae 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -22,6 +22,8 @@ import triton
 import triton._C.libtriton.triton as _triton
 from .tools.disasm import extract
 
+current_stream = lambda device: torch.cuda.current_stream(device).cuda_stream
+
 
 def mangle_ty(ty):
     if ty.is_ptr():
@@ -787,6 +789,7 @@ class OutOfResources(Exception):
 
 
 class Kernel:
+
     @staticmethod
     def _type_name(obj):
         type_names = {
@@ -915,28 +918,24 @@ class Kernel:
             raise TypeError(f"Function takes {len(self.fn.arg_names)} positional arguments but {len(wargs)} were given")
         # handle annotations
         for pos, _type in self.fn.annotations.items():
+            assert _type == triton.language.constexpr, "only constexpr annotations are supported for now"
             wargs[pos] = _type(wargs[pos])
         # check that tensors are on GPU.
         for arg in wargs:
             if hasattr(arg, 'data_ptr'):
                 assert arg.is_cuda, "All tensors must be on GPU!"
-        # query device index and cuda stream
+        # set device (i.e., make sure torch has the context initialized)
         device = torch.cuda.current_device()
         torch.cuda.set_device(device)
+        # query compute capability
        cc = torch.cuda.get_device_capability(device)
         cc = str(cc[0]) + '-' + str(cc[1])
-        # # query stream
-        # # this is hacky but much faster than `torch.cuda.current_stream(device).cuda_stream`
-        # # https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L154
-        # # building a C wrapper to re-use the unpack function would add a build-time torch dependency
-        # # and require different wheels for different torch versions -- undesirable!
-        # bits = torch._C._cuda_getCurrentStream(device)
-        # mask = 1 << 47
-        # stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask
-        stream = torch.cuda.current_stream(device).cuda_stream
-        # make key for cache
-        return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names, device, stream,
-                                      self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid)
+        cache_key = self.fn.cache_key + cc
+        # query current stream
+        stream = current_stream(device)
+        return _triton.runtime.launch(wargs, self.fn.do_not_specialize, cache_key, self.fn.arg_names,
+                                      device, stream, self.fn.bin_cache, num_warps, num_stages, self.add_to_cache,
+                                      grid)
 
 
 class Launcher:
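
Note on the new current_stream indirection: hoisting the stream query into a module-level lambda keeps the launch path on the plain torch.cuda.current_stream(device).cuda_stream call while leaving a single point that external code can patch. The snippet below is a hypothetical usage sketch, not part of this diff; the override function, its per-device caching, and the assumption that the active stream stays fixed between launches are illustrative assumptions.

    import torch
    import triton

    _stream_cache = {}  # hypothetical per-device cache, illustrative only

    def cached_current_stream(device):
        # Assumes the active CUDA stream for `device` does not change between
        # kernel launches; caches the raw handle to skip repeated queries.
        if device not in _stream_cache:
            _stream_cache[device] = torch.cuda.current_stream(device).cuda_stream
        return _stream_cache[device]

    # Swap in the override via the module-level hook introduced in this diff.
    triton.code_gen.current_stream = cached_current_stream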