diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index efcc2701f..cc0a103c9 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -493,6 +493,184 @@ class OutOfResources(Exception): super().__init__(self.message) +# class Kernel: +# @staticmethod +# def _type_name(obj): +# type_names = { +# triton.language.float8: 'f8', +# torch.bfloat16: 'bf16', +# torch.float16: 'f16', +# torch.float32: 'f32', +# torch.float64: 'f64', +# torch.bool: 'i1', +# torch.int8: 'i8', +# torch.int16: 'i16', +# torch.int32: 'i32', +# torch.int64: 'i64', +# } +# if hasattr(obj, 'data_ptr'): +# return type_names[obj.dtype] +# if isinstance(obj, triton.language.core.constexpr): +# obj = obj.value +# if isinstance(obj, int): +# if abs(obj) <= 0xffffffff: +# return 'I' +# return 'L' +# if isinstance(obj, float): +# return 'f' +# if isinstance(obj, bool): +# return 'B' +# if isinstance(obj, str): +# return 'str' +# assert False + + + +# @staticmethod +# def _to_triton_ir(context, obj): +# type_map = { +# 'I': _triton.ir.type.get_int32, +# 'L': _triton.ir.type.get_int64, +# 'f': _triton.ir.type.get_fp32, +# 'B': _triton.ir.type.get_int1, +# 'f8': _triton.ir.type.get_fp8, +# 'f16': _triton.ir.type.get_fp16, +# 'bf16': _triton.ir.type.get_bf16, +# 'f32': _triton.ir.type.get_fp32, +# 'f64': _triton.ir.type.get_fp64, +# 'i1': _triton.ir.type.get_int1, +# 'i8': _triton.ir.type.get_int8, +# 'i16': _triton.ir.type.get_int16, +# 'i32': _triton.ir.type.get_int32, +# 'i64': _triton.ir.type.get_int64, +# } +# # convert torch.Tensor to Triton IR pointers +# if hasattr(obj, 'data_ptr'): +# name = Kernel._type_name(obj) +# elt_ty = type_map[name](context) +# return _triton.ir.type.make_ptr(elt_ty, 1) +# # default path returns triton.ir.type directly +# name = Kernel._type_name(obj) +# return type_map[name](context) + +# @staticmethod +# def pow2_divisor(N): +# if N % 16 == 0: return 16 +# if N % 8 == 0: return 8 +# if N % 4 == 0: return 4 +# if N % 2 == 0: return 2 +# return 1 + +# def __init__(self, fn): +# self.fn = fn + +# def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages): +# wargs = [arg for arg in wargs if not isinstance(arg, triton.language.constexpr)] +# # create IR module +# context = _triton.ir.context() +# # get just-in-time proto-type of kernel +# arg_types = [Kernel._to_triton_ir(context, arg) for arg in wargs] +# ret_type = _triton.ir.type.get_void(context) +# prototype = _triton.ir.type.make_function(ret_type, arg_types) +# # generate Triton-IR +# # export symbols visible from self.fn into code-generator object +# gscope = sys.modules[self.fn.module].__dict__ +# generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict()) +# try: +# generator.visit(self.fn.parse()) +# except Exception as e: +# node = generator.last_node +# if node is None or isinstance(e, (NotImplementedError, CompilationError)): +# raise e +# raise CompilationError(self.fn.src, node, e) +# # Compile to machine code +# if torch.version.hip is None: +# backend = _triton.runtime.backend.CUDA +# else: +# backend = _triton.runtime.backend.ROCM +# name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages) +# max_shared_memory = _triton.runtime.max_shared_memory(backend, device) +# if shared_mem > max_shared_memory: +# raise OutOfResources(shared_mem, max_shared_memory, "shared memory") +# return Binary(backend, name, asm, shared_mem, num_warps) + +# def add_to_cache(self, key, 
wargs, device_idx, num_warps, num_stages): +# tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] +# # attributes +# args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)] +# attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \ +# if isinstance(a, int) and i not in self.fn.do_not_specialize} + +# # transforms ints whose value is one into constants for just-in-time compilation +# constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1} +# constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) +# hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() + +# # create cache directory +# cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/') +# if cache_dir and not os.path.exists(cache_dir): +# os.makedirs(cache_dir, exist_ok=True) + +# if cache_dir: +# bin_cache_path = os.path.join(cache_dir, hashed_key) +# bin_lock_path = bin_cache_path + ".lock" +# else: +# bin_cache_path = None +# bin_lock_path = None + +# binary = None +# if bin_cache_path and os.path.exists(bin_cache_path): +# assert bin_lock_path is not None +# with FileLock(bin_lock_path): +# with open(bin_cache_path, 'rb') as f: +# binary = pickle.load(f)["binary"] +# if binary is None: +# binary = self._compile( +# *wargs, device=device_idx, attributes=attributes, +# num_warps=num_warps, num_stages=num_stages, +# constants=constants, +# ) +# if bin_cache_path: +# assert bin_lock_path is not None +# with FileLock(bin_lock_path): +# with open(bin_cache_path + ".tmp", "wb") as f: +# pickle.dump({"binary": binary, "key": key}, f) +# os.rename(bin_cache_path + ".tmp", bin_cache_path) +# if JITFunction.cache_hook is not None: +# JITFunction.cache_hook(key=key, binary=binary) + +# self.fn.bin_cache[key] = LoadedBinary(device_idx, binary) + +# def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs): +# # handle arguments passed by name +# kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()} +# wargs = list(wargs) +# for i, pos in enumerate(sorted(kwargs)): +# wargs.insert(pos + i, kwargs[pos]) +# if len(wargs) != len(self.fn.arg_names): +# raise TypeError(f"Function takes {len(self.fn.arg_names)} positional arguments but {len(wargs)} were given") +# # handle annotations +# for pos, _type in self.fn.annotations.items(): +# wargs[pos] = _type(wargs[pos]) +# # query device index and cuda stream +# device = torch.cuda.current_device() +# torch.cuda.set_device(device) +# cc = torch.cuda.get_device_capability(device) +# cc = str(cc[0]) + '-' + str(cc[1]) +# # # query stream +# # # this is hacky but much faster than `torch.cuda.current_stream(device).cuda_stream` +# # # https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L154 +# # # building a C wrapper to re-use the unpack function would add a build-time torch dependency +# # # and require different wheels for different torch versions -- undesirable! 
+# # bits = torch._C._cuda_getCurrentStream(device) +# # mask = 1 << 47 +# # stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask +# stream = torch.cuda.current_stream(device).cuda_stream +# # make key for cache +# return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names, device, stream, +# self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid) + + class Kernel: @staticmethod def _type_name(obj): @@ -553,6 +731,16 @@ class Kernel: name = Kernel._type_name(obj) return type_map[name](context) + @staticmethod + def _types_key(*wargs, tensor_idxs): + # type inference + types_key = [None] * len(wargs) + for i, arg in enumerate(wargs): + prefix = 'P' if i in tensor_idxs else '' + suffix = Kernel._type_name(arg) if i in tensor_idxs else Kernel._type_name(arg) + types_key[i] = prefix + suffix + return tuple(types_key) + @staticmethod def pow2_divisor(N): if N % 16 == 0: return 16 @@ -594,53 +782,6 @@ class Kernel: raise OutOfResources(shared_mem, max_shared_memory, "shared memory") return Binary(backend, name, asm, shared_mem, num_warps) - def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages): - tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] - # attributes - args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)] - attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \ - if isinstance(a, int) and i not in self.fn.do_not_specialize} - - # transforms ints whose value is one into constants for just-in-time compilation - constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1} - constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) - hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() - - # create cache directory - cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/') - if cache_dir and not os.path.exists(cache_dir): - os.makedirs(cache_dir, exist_ok=True) - - if cache_dir: - bin_cache_path = os.path.join(cache_dir, hashed_key) - bin_lock_path = bin_cache_path + ".lock" - else: - bin_cache_path = None - bin_lock_path = None - - binary = None - if bin_cache_path and os.path.exists(bin_cache_path): - assert bin_lock_path is not None - with FileLock(bin_lock_path): - with open(bin_cache_path, 'rb') as f: - binary = pickle.load(f)["binary"] - if binary is None: - binary = self._compile( - *wargs, device=device_idx, attributes=attributes, - num_warps=num_warps, num_stages=num_stages, - constants=constants, - ) - if bin_cache_path: - assert bin_lock_path is not None - with FileLock(bin_lock_path): - with open(bin_cache_path + ".tmp", "wb") as f: - pickle.dump({"binary": binary, "key": key}, f) - os.rename(bin_cache_path + ".tmp", bin_cache_path) - if JITFunction.cache_hook is not None: - JITFunction.cache_hook(key=key, binary=binary) - - self.fn.bin_cache[key] = LoadedBinary(device_idx, binary) - def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs): # handle arguments passed by name kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()} @@ -650,25 +791,112 @@ class Kernel: if len(wargs) != len(self.fn.arg_names): raise TypeError(f"Function takes {len(self.fn.arg_names)} positional arguments but {len(wargs)} were given") # handle annotations - for pos, _type in self.fn.annotations.items(): - wargs[pos] = _type(wargs[pos]) - # query device index and cuda stream - device = torch.cuda.current_device() - 
torch.cuda.set_device(device) - cc = torch.cuda.get_device_capability(device) - cc = str(cc[0]) + '-' + str(cc[1]) - # # query stream - # # this is hacky but much faster than `torch.cuda.current_stream(device).cuda_stream` - # # https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L154 - # # building a C wrapper to re-use the unpack function would add a build-time torch dependency - # # and require different wheels for different torch versions -- undesirable! - # bits = torch._C._cuda_getCurrentStream(device) - # mask = 1 << 47 - # stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask - stream = torch.cuda.current_stream(device).cuda_stream - # make key for cache - return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names, device, stream, - self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid) + for name, type in self.fn.__annotations__.items(): + pos = self.fn.arg_names.index(name) + assert type == triton.language.core.constexpr + wargs[pos] = type(wargs[pos]) + # device inference + tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] + if len(tensor_idxs) == 0: + raise ValueError("No Tensor argument found.") + invalid_args = [] + device_ids = [] + for idx in tensor_idxs: + curr = wargs[idx] + if not curr.is_cuda: + invalid_args.append(idx) + else: + device_ids.append(curr.device.index) + if invalid_args: + raise ValueError("Arguments at index {invalid_args} are on the wrong device.".format(invalid_args=invalid_args) + + " Only CUDA is supported at the moment") + + device = torch.device('cuda', torch.cuda.current_device()) + device_idx = device.index + # if len(set(device_ids)) != 1 or device_ids[0] != device_idx: + # # try to enable P2P communication + # for arg_idx, dst_idx in zip(tensor_idxs, device_ids): + # if dst_idx != device_idx: + # try: + # _triton.runtime.enable_peer_access(self.backend, wargs[arg_idx].data_ptr()) + # except RuntimeError as e: + # raise RuntimeError("Cannot enable P2P access from device {} to device {}: {}" + # .format(device_idx, dst_idx, str(e))) + + # enqueue kernel on the current device + torch.cuda.set_device(device_idx) + # attributes + args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)] + attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \ + if isinstance(a, int) and i not in self.fn.do_not_specialize} + + # transforms ints whose value is one into constants for just-in-time compilation + constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1} + constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) + + # compute hash for caching this kernel + types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs) + attr_key = tuple(attributes.items()) + const_key = tuple(constants.items()) + compute_capability = torch.cuda.get_device_capability(device) + key = ( + self.fn.cache_key, version_key(), compute_capability, + types_key, attr_key, num_warps, num_stages, const_key + ) + key = repr(key) + + # get cached binary + bin_cache = self.fn.bin_cache + + if key not in bin_cache: + hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() + + # create cache directory + cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/') + if cache_dir and not os.path.exists(cache_dir): + os.makedirs(cache_dir, exist_ok=True) + + if cache_dir: + bin_cache_path = os.path.join(cache_dir, hashed_key) + bin_lock_path = bin_cache_path + ".lock" + else: 
+            bin_cache_path = None
+            bin_lock_path = None
+
+        binary = None
+        if bin_cache_path and os.path.exists(bin_cache_path):
+            assert bin_lock_path is not None
+            with FileLock(bin_lock_path):
+                with open(bin_cache_path, 'rb') as f:
+                    binary = pickle.load(f)["binary"]
+        if binary is None:
+            binary = self._compile(
+                *wargs, device=device_idx, attributes=attributes,
+                num_warps=num_warps, num_stages=num_stages,
+                constants=constants,
+            )
+            if bin_cache_path:
+                assert bin_lock_path is not None
+                with FileLock(bin_lock_path):
+                    with open(bin_cache_path + ".tmp", "wb") as f:
+                        pickle.dump({"binary": binary, "key": key}, f)
+                    os.rename(bin_cache_path + ".tmp", bin_cache_path)
+            if JITFunction.cache_hook is not None:
+                JITFunction.cache_hook(key=key, binary=binary)
+
+        # register the binary whether it was freshly compiled or loaded from disk
+        bin_cache[key] = LoadedBinary(device_idx, binary)
+        # pack arguments
+        fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg) for i, arg in enumerate(wargs) if not isinstance(arg, triton.language.core.constexpr)])
+        params = struct.pack(fmt, *[arg for arg in args if not isinstance(arg, triton.language.core.constexpr)])
+        # enqueue cached function into stream
+        callable = bin_cache[key]
+        stream = torch.cuda.current_stream(device_idx).cuda_stream
+        csts = {self.fn.arg_names[i]: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.core.constexpr)}
+        grid = grid(csts) if hasattr(grid, '__call__') else grid
+        if isinstance(grid, int):
+            grid = (grid,)
+        callable(stream, params, *grid)
+        return callable
+
+
 class Launcher:
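
Note on the integer specialization that `__call__` now performs inline: every integer argument not listed in `do_not_specialize` is keyed by its largest power-of-two divisor up to 16, and integers equal to 1 are additionally promoted to compile-time constants; both dictionaries feed the cache key, so differently aligned sizes compile to different binaries. The snippet below is a simplified stand-alone sketch of that keying (the helper name `specialization_key` is hypothetical, and tensors and `do_not_specialize` are ignored), not the code in this patch.

def pow2_divisor(n):
    # largest power-of-two divisor of n, capped at 16 (mirrors Kernel.pow2_divisor)
    for d in (16, 8, 4, 2):
        if n % d == 0:
            return d
    return 1

def specialization_key(args):
    # hypothetical helper: integer args -> (alignment attributes, value-1 constants)
    attributes = {i: pow2_divisor(a) for i, a in enumerate(args) if isinstance(a, int)}
    constants = {i: a for i, a in enumerate(args) if isinstance(a, int) and a == 1}
    return tuple(attributes.items()), tuple(constants.items())

# e.g. specialization_key([4096, 33, 1])
#   -> (((0, 16), (1, 1), (2, 1)), ((2, 1),))
# a kernel called with a 16-divisible size therefore hits a different,
# more specialized binary than one called with an odd size.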
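
The `struct.pack` call near the end of `__call__` flattens the runtime arguments into the byte buffer handed to the loaded binary: tensors contribute their `data_ptr()` under format code 'P', while scalars use the codes produced by `Kernel._type_name` ('I'/'L'/'f'/'B'). Below is a minimal stand-alone sketch of that marshalling; `pack_args` is a hypothetical helper, not part of the patch.

import struct
import torch

def pack_args(args):
    # hypothetical helper mirroring the fmt/params construction in __call__
    fmt, values = '', []
    for a in args:
        if hasattr(a, 'data_ptr'):          # torch.Tensor -> raw device pointer
            fmt += 'P'; values.append(a.data_ptr())
        elif isinstance(a, bool):           # checked before int: bool is an int subclass
            fmt += 'B'; values.append(a)
        elif isinstance(a, int):            # 'I'/'L' as in _type_name; negative ints would need 'i'/'l'
            fmt += 'I' if abs(a) <= 0xffffffff else 'L'; values.append(a)
        elif isinstance(a, float):
            fmt += 'f'; values.append(a)
        else:
            raise TypeError(f"unsupported argument type: {type(a)}")
    return struct.pack(fmt, *values)

x = torch.empty(1024)                       # device='cuda' in the real launch path
params = pack_args([x, x.numel()])          # fmt == 'PI'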
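
For context, this is the path exercised by an ordinary `@triton.jit` launch. The example below is illustrative only (kernel name, sizes, and BLOCK_SIZE are made up); it shows how a callable `grid` receives the constexpr dictionary (`csts` above) and how `tl.constexpr` arguments become compile-time constants rather than packed parameters.

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

x = torch.rand(4096, device='cuda')
y = torch.rand(4096, device='cuda')
out = torch.empty_like(x)
# the grid callable is invoked with the constexpr values ({'BLOCK_SIZE': 1024} here)
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']),)
add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024)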