diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index 73d114ed1..efcc2701f 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -657,15 +657,15 @@ class Kernel:
         torch.cuda.set_device(device)
         cc = torch.cuda.get_device_capability(device)
         cc = str(cc[0]) + '-' + str(cc[1])
-        # query stream
-        # this is hacky but much faster than `torch.cuda.current_stream(device).cuda_stream`
-        # https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L154
-        # building a C wrapper to re-use the unpack function would add a build-time torch dependency
-        # and require different wheels for different torch versions -- undesirable!
-        bits = torch._C._cuda_getCurrentStream(device)
-        mask = 1 << 47
-        stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask
-        # stream = torch.cuda.current_stream(device).cuda_stream
+        # # query stream
+        # # this is hacky but much faster than `torch.cuda.current_stream(device).cuda_stream`
+        # # https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L154
+        # # building a C wrapper to re-use the unpack function would add a build-time torch dependency
+        # # and require different wheels for different torch versions -- undesirable!
+        # bits = torch._C._cuda_getCurrentStream(device)
+        # mask = 1 << 47
+        # stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask
+        stream = torch.cuda.current_stream(device).cuda_stream
         # make key for cache
         return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names,
                                       device, stream, self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid)
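
For reference, a minimal sketch of the bit trick the commented-out lines performed, under the assumption (stated in the original comments) that `torch._C._cuda_getCurrentStream` returned the raw packed `c10::Stream` bits with the stream handle in the low 48 bits. `sign_extend_48` is a hypothetical helper name, not part of Triton or PyTorch:

    def sign_extend_48(bits: int) -> int:
        # Interpret the low 48 bits of the packed representation as a
        # signed 48-bit integer (the stream handle may be a pointer-like
        # value whose sign bit sits at bit 47).
        mask = 1 << 47                  # sign bit of the 48-bit field
        low = bits & 0xFFFFFFFFFFFF     # mask off the upper 16 bits
        return (low ^ mask) - mask      # XOR then subtract sign-extends

    # e.g. sign_extend_48(0xFFFFFFFFFFFF) == -1
    # e.g. sign_extend_48(0xDEAD) == 0xDEAD

The replacement line falls back to the documented public API, `torch.cuda.current_stream(device).cuda_stream`, which is slower per the original comments but does not depend on PyTorch's internal `Stream` bit layout.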