From 5693b582eac1002c19039cefffa2f70ec747bb77 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Sun, 21 Nov 2021 02:30:22 -0800
Subject: [PATCH] [RUNTIME] Now using pybind11 to avoid memory leaks (#377)

---
 python/src/triton.cc      |  85 ++++-----
 python/triton/code_gen.py | 362 +++++++-------------------------------
 2 files changed, 110 insertions(+), 337 deletions(-)

diff --git a/python/src/triton.cc b/python/src/triton.cc
index 5fb6bb8f2..2c165dd06 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -110,18 +110,19 @@ std::string pow2_divisor(long N){
 }
 
 // Launch
-void parse_args(py::handle& args, py::handle do_not_specialize, const std::string& func_key, py::handle& arg_names,
-                std::string& cache_key, std::string& params, size_t& params_size, PyObject* constants,
+void parse_args(py::list& args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
+                std::string& cache_key, std::string& params, size_t& params_size, py::dict constants,
                 int num_warps, int num_stages) {
   size_t len = PyList_Size(args.ptr());
   params.reserve(8*len); // 8 max bytes by argument
   char* params_ptr = &params[0];
   cache_key = func_key;
   for(int i = 0; i < len; i++){
-    PyObject* py_i = PyLong_FromLong(i);
-    bool specialize = !PySequence_Contains(do_not_specialize.ptr(), py_i);
-    auto arg_ptr = PyList_GetItem(args.ptr(), i);
-    auto arg = py::handle(arg_ptr);
+    py::int_ py_i = py::int_(i);
+    bool specialize = std::find(do_not_specialize.begin(), do_not_specialize.end(), py_i) == do_not_specialize.end();
+    py::object arg = args[i];
+    auto arg_ptr = arg.ptr();
+
     // argument is `long`
     if(PyLong_Check(arg_ptr)){
       int overflow;
@@ -172,28 +173,28 @@ void parse_args(py::handle& args, py::handle do_not_specialize, const std::strin
       continue;
     }
     // argument is tensor
-    PyObject* data_ptr = PyObject_CallMethod(arg_ptr, "data_ptr", nullptr);
-    if(data_ptr){
+    if(py::hasattr(arg, "data_ptr")){
+      py::object data_ptr = arg.attr("data_ptr")();
       cache_key += "P";
-      long value = PyLong_AsLong(data_ptr);
+      long value = data_ptr.cast<long>();
       params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
       std::memcpy(params_ptr, &value, 8);
       params_ptr += 8;
-      PyObject* dtype = PyObject_GetAttrString(arg_ptr, "dtype");
-      PyObject* repr = PyObject_Repr(dtype);
-      const char* start = (const char*)PyUnicode_1BYTE_DATA(repr) + 6; // remove 'torch.'
-      size_t len = PyUnicode_GET_LENGTH(repr) - 6;
+      py::object dtype = arg.attr("dtype");
+      py::object repr = py::repr(dtype);
+      const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()) + 6; // remove 'torch.'
+      size_t len = PyUnicode_GET_LENGTH(repr.ptr()) - 6;
       cache_key += std::string(start, len);
       continue;
     }
     // argument is `constexpr`
-    PyObject* value = PyObject_GetAttrString(arg_ptr, "value");
+    py::object value = arg.attr("value");
     if(value){
-      PyObject* name = PyList_GetItem(arg_names.ptr(), i);
-      PyDict_SetItem(constants, name, value);
-      PyObject* repr = PyObject_Repr(value);
-      const char* start = (const char*)PyUnicode_1BYTE_DATA(repr);
-      size_t len = PyUnicode_GET_LENGTH(repr);
+      py::object name = arg_names[i];
+      constants[name] = value;
+      py::object repr = py::repr(value);
+      const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr());
+      size_t len = PyUnicode_GET_LENGTH(repr.ptr());
       cache_key += std::string(start, len);
       continue;
     }
@@ -228,37 +229,39 @@ void init_triton_runtime(py::module &&m) {
   );
 
   // cache key
-  m.def("launch", [](py::handle args, py::handle do_not_specialize, const std::string& func_key, py::list& arg_names,
-                     py::handle device, py::handle stream, py::handle bin_cache, py::handle num_warps, py::handle num_stages,
-                     py::handle add_to_cache, py::handle grid){
+  m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
+                     py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages,
+                     py::function add_to_cache, py::object grid){
     // parse arguments to compute cache key, compile-time constants and packed kernel arguments
    long _num_warps = PyLong_AsLong(num_warps.ptr());
    long _num_stages = PyLong_AsLong(num_stages.ptr());
    std::string cache_key;
    std::string params;
    size_t params_size;
-    PyObject* constants = PyDict_New();
+    py::dict constants;
    parse_args(args, do_not_specialize, func_key, arg_names, cache_key, params, params_size, constants, _num_warps, _num_stages);
+
    // get cached binary
-    PyObject* key = PyUnicode_FromString(cache_key.c_str());
-    PyObject* bin = nullptr;
-    if(!PyDict_Contains(bin_cache.ptr(), key)){
-      add_to_cache(py::handle(key), args, device, num_warps, num_stages);
-    }
-    bin = PyDict_GetItem(bin_cache.ptr(), key);
+    py::str key(cache_key);
+    if(!bin_cache.contains(key))
+      add_to_cache(key, args, device, num_warps, num_stages);
+    py::object bin = bin_cache[key];
+
    // get grid
-    PyObject* grid_ptr = grid.ptr();
-    if(!PySequence_Check(grid_ptr)){
-      PyObject* grid_call = PyObject_GetAttrString(grid_ptr, "__call__");
-      grid_ptr = PyObject_Call(grid_call, PyTuple_Pack(1, constants), nullptr);
-    }
-    int size = PySequence_Size(grid_ptr);
-    int grid_0 = PyLong_AsLong(PySequence_GetItem(grid_ptr, 0));
-    int grid_1 = size < 2 ? 1 : PyLong_AsLong(PySequence_GetItem(grid_ptr, 1));
-    int grid_2 = size < 3 ? 1 : PyLong_AsLong(PySequence_GetItem(grid_ptr, 2));
+    py::sequence seq;
+    if(!PySequence_Check(grid.ptr()))
+      seq = grid(constants);
+    else
+      seq = grid;
+    int size = seq.size();
+    int grid_0 = py::cast<int>(seq[0]);
+    int grid_1 = size < 2 ? 1 : py::cast<int>(seq[1]);
+    int grid_2 = size < 3 ?
1 : py::cast(seq[2]); + // enqueue - uint64_t kernel = PyLong_AsLong(PyObject_GetAttrString(bin, "kernel")); - uint64_t shared_mem = PyLong_AsLong(PyObject_GetAttrString(bin, "shared_mem")); + uint64_t kernel = py::cast(bin.attr("kernel")); + uint64_t shared_mem = py::cast(bin.attr("shared_mem")); + // actually launch void *config[] = { CU_LAUNCH_PARAM_BUFFER_POINTER, params.data(), @@ -269,7 +272,7 @@ void init_triton_runtime(py::module &&m) { drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2, _num_warps*32, 1, 1, shared_mem, (CUstream)_stream, nullptr, config); - return py::handle(bin); + return bin; }); // query maximum shared memory diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 1e514d3e2..efcc2701f 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -493,184 +493,6 @@ class OutOfResources(Exception): super().__init__(self.message) -# class Kernel: -# @staticmethod -# def _type_name(obj): -# type_names = { -# triton.language.float8: 'f8', -# torch.bfloat16: 'bf16', -# torch.float16: 'f16', -# torch.float32: 'f32', -# torch.float64: 'f64', -# torch.bool: 'i1', -# torch.int8: 'i8', -# torch.int16: 'i16', -# torch.int32: 'i32', -# torch.int64: 'i64', -# } -# if hasattr(obj, 'data_ptr'): -# return type_names[obj.dtype] -# if isinstance(obj, triton.language.core.constexpr): -# obj = obj.value -# if isinstance(obj, int): -# if abs(obj) <= 0xffffffff: -# return 'I' -# return 'L' -# if isinstance(obj, float): -# return 'f' -# if isinstance(obj, bool): -# return 'B' -# if isinstance(obj, str): -# return 'str' -# assert False - - - -# @staticmethod -# def _to_triton_ir(context, obj): -# type_map = { -# 'I': _triton.ir.type.get_int32, -# 'L': _triton.ir.type.get_int64, -# 'f': _triton.ir.type.get_fp32, -# 'B': _triton.ir.type.get_int1, -# 'f8': _triton.ir.type.get_fp8, -# 'f16': _triton.ir.type.get_fp16, -# 'bf16': _triton.ir.type.get_bf16, -# 'f32': _triton.ir.type.get_fp32, -# 'f64': _triton.ir.type.get_fp64, -# 'i1': _triton.ir.type.get_int1, -# 'i8': _triton.ir.type.get_int8, -# 'i16': _triton.ir.type.get_int16, -# 'i32': _triton.ir.type.get_int32, -# 'i64': _triton.ir.type.get_int64, -# } -# # convert torch.Tensor to Triton IR pointers -# if hasattr(obj, 'data_ptr'): -# name = Kernel._type_name(obj) -# elt_ty = type_map[name](context) -# return _triton.ir.type.make_ptr(elt_ty, 1) -# # default path returns triton.ir.type directly -# name = Kernel._type_name(obj) -# return type_map[name](context) - -# @staticmethod -# def pow2_divisor(N): -# if N % 16 == 0: return 16 -# if N % 8 == 0: return 8 -# if N % 4 == 0: return 4 -# if N % 2 == 0: return 2 -# return 1 - -# def __init__(self, fn): -# self.fn = fn - -# def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages): -# wargs = [arg for arg in wargs if not isinstance(arg, triton.language.constexpr)] -# # create IR module -# context = _triton.ir.context() -# # get just-in-time proto-type of kernel -# arg_types = [Kernel._to_triton_ir(context, arg) for arg in wargs] -# ret_type = _triton.ir.type.get_void(context) -# prototype = _triton.ir.type.make_function(ret_type, arg_types) -# # generate Triton-IR -# # export symbols visible from self.fn into code-generator object -# gscope = sys.modules[self.fn.module].__dict__ -# generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict()) -# try: -# generator.visit(self.fn.parse()) -# except Exception as e: -# node = generator.last_node -# if node is None 
or isinstance(e, (NotImplementedError, CompilationError)): -# raise e -# raise CompilationError(self.fn.src, node, e) -# # Compile to machine code -# if torch.version.hip is None: -# backend = _triton.runtime.backend.CUDA -# else: -# backend = _triton.runtime.backend.ROCM -# name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages) -# max_shared_memory = _triton.runtime.max_shared_memory(backend, device) -# if shared_mem > max_shared_memory: -# raise OutOfResources(shared_mem, max_shared_memory, "shared memory") -# return Binary(backend, name, asm, shared_mem, num_warps) - -# def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages): -# tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] -# # attributes -# args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)] -# attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \ -# if isinstance(a, int) and i not in self.fn.do_not_specialize} - -# # transforms ints whose value is one into constants for just-in-time compilation -# constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1} -# constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) -# hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() - -# # create cache directory -# cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/') -# if cache_dir and not os.path.exists(cache_dir): -# os.makedirs(cache_dir, exist_ok=True) - -# if cache_dir: -# bin_cache_path = os.path.join(cache_dir, hashed_key) -# bin_lock_path = bin_cache_path + ".lock" -# else: -# bin_cache_path = None -# bin_lock_path = None - -# binary = None -# if bin_cache_path and os.path.exists(bin_cache_path): -# assert bin_lock_path is not None -# with FileLock(bin_lock_path): -# with open(bin_cache_path, 'rb') as f: -# binary = pickle.load(f)["binary"] -# if binary is None: -# binary = self._compile( -# *wargs, device=device_idx, attributes=attributes, -# num_warps=num_warps, num_stages=num_stages, -# constants=constants, -# ) -# if bin_cache_path: -# assert bin_lock_path is not None -# with FileLock(bin_lock_path): -# with open(bin_cache_path + ".tmp", "wb") as f: -# pickle.dump({"binary": binary, "key": key}, f) -# os.rename(bin_cache_path + ".tmp", bin_cache_path) -# if JITFunction.cache_hook is not None: -# JITFunction.cache_hook(key=key, binary=binary) - -# self.fn.bin_cache[key] = LoadedBinary(device_idx, binary) - -# def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs): -# # handle arguments passed by name -# kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()} -# wargs = list(wargs) -# for i, pos in enumerate(sorted(kwargs)): -# wargs.insert(pos + i, kwargs[pos]) -# if len(wargs) != len(self.fn.arg_names): -# raise TypeError(f"Function takes {len(self.fn.arg_names)} positional arguments but {len(wargs)} were given") -# # handle annotations -# for pos, _type in self.fn.annotations.items(): -# wargs[pos] = _type(wargs[pos]) -# # query device index and cuda stream -# device = torch.cuda.current_device() -# torch.cuda.set_device(device) -# cc = torch.cuda.get_device_capability(device) -# cc = str(cc[0]) + '-' + str(cc[1]) -# # # query stream -# # # this is hacky but much faster than `torch.cuda.current_stream(device).cuda_stream` -# # # https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L154 -# # # building a C wrapper to re-use the unpack 
function would add a build-time torch dependency -# # # and require different wheels for different torch versions -- undesirable! -# # bits = torch._C._cuda_getCurrentStream(device) -# # mask = 1 << 47 -# # stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask -# stream = torch.cuda.current_stream(device).cuda_stream -# # make key for cache -# return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names, device, stream, -# self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid) - - class Kernel: @staticmethod def _type_name(obj): @@ -700,8 +522,6 @@ class Kernel: return 'B' if isinstance(obj, str): return 'str' - if isinstance(obj, JITFunction): - return '' assert False @@ -733,16 +553,6 @@ class Kernel: name = Kernel._type_name(obj) return type_map[name](context) - @staticmethod - def _types_key(*wargs, tensor_idxs): - # type inference - types_key = [None] * len(wargs) - for i, arg in enumerate(wargs): - prefix = 'P' if i in tensor_idxs else '' - suffix = Kernel._type_name(arg) if i in tensor_idxs else Kernel._type_name(arg) - types_key[i] = prefix + suffix - return tuple(types_key) - @staticmethod def pow2_divisor(N): if N % 16 == 0: return 16 @@ -784,6 +594,53 @@ class Kernel: raise OutOfResources(shared_mem, max_shared_memory, "shared memory") return Binary(backend, name, asm, shared_mem, num_warps) + def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages): + tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] + # attributes + args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)] + attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \ + if isinstance(a, int) and i not in self.fn.do_not_specialize} + + # transforms ints whose value is one into constants for just-in-time compilation + constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1} + constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) + hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() + + # create cache directory + cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/') + if cache_dir and not os.path.exists(cache_dir): + os.makedirs(cache_dir, exist_ok=True) + + if cache_dir: + bin_cache_path = os.path.join(cache_dir, hashed_key) + bin_lock_path = bin_cache_path + ".lock" + else: + bin_cache_path = None + bin_lock_path = None + + binary = None + if bin_cache_path and os.path.exists(bin_cache_path): + assert bin_lock_path is not None + with FileLock(bin_lock_path): + with open(bin_cache_path, 'rb') as f: + binary = pickle.load(f)["binary"] + if binary is None: + binary = self._compile( + *wargs, device=device_idx, attributes=attributes, + num_warps=num_warps, num_stages=num_stages, + constants=constants, + ) + if bin_cache_path: + assert bin_lock_path is not None + with FileLock(bin_lock_path): + with open(bin_cache_path + ".tmp", "wb") as f: + pickle.dump({"binary": binary, "key": key}, f) + os.rename(bin_cache_path + ".tmp", bin_cache_path) + if JITFunction.cache_hook is not None: + JITFunction.cache_hook(key=key, binary=binary) + + self.fn.bin_cache[key] = LoadedBinary(device_idx, binary) + def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs): # handle arguments passed by name kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()} @@ -793,112 +650,25 @@ class Kernel: if len(wargs) != len(self.fn.arg_names): raise TypeError(f"Function takes 
{len(self.fn.arg_names)} positional arguments but {len(wargs)} were given") # handle annotations - for name, type in self.fn.__annotations__.items(): - pos = self.fn.arg_names.index(name) - assert type == triton.language.core.constexpr - wargs[pos] = type(wargs[pos]) - # device inference - tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] - if len(tensor_idxs) == 0: - raise ValueError("No Tensor argument found.") - invalid_args = [] - device_ids = [] - for idx in tensor_idxs: - curr = wargs[idx] - if not curr.is_cuda: - invalid_args.append(idx) - else: - device_ids.append(curr.device.index) - if invalid_args: - raise ValueError("Arguments at index {invalid_args} are on the wrong device.".format(invalid_args=invalid_args) + - " Only CUDA is supported at the moment") - - device = torch.device('cuda', torch.cuda.current_device()) - device_idx = device.index - # if len(set(device_ids)) != 1 or device_ids[0] != device_idx: - # # try to enable P2P communication - # for arg_idx, dst_idx in zip(tensor_idxs, device_ids): - # if dst_idx != device_idx: - # try: - # _triton.runtime.enable_peer_access(self.backend, wargs[arg_idx].data_ptr()) - # except RuntimeError as e: - # raise RuntimeError("Cannot enable P2P access from device {} to device {}: {}" - # .format(device_idx, dst_idx, str(e))) - - # enqueue kernel on the current device - torch.cuda.set_device(device_idx) - # attributes - args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)] - attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \ - if isinstance(a, int) and i not in self.fn.do_not_specialize} - - # transforms ints whose value is one into constants for just-in-time compilation - constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1} - constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) - - # compute hash for caching this kernel - types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs) - attr_key = tuple(attributes.items()) - const_key = tuple(constants.items()) - compute_capability = torch.cuda.get_device_capability(device) - key = ( - self.fn.cache_key, version_key(), compute_capability, - types_key, attr_key, num_warps, num_stages, const_key - ) - key = repr(key) - - # get cached binary - bin_cache = self.fn.bin_cache - - if key not in bin_cache: - hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() - - # create cache directory - cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/') - if cache_dir and not os.path.exists(cache_dir): - os.makedirs(cache_dir, exist_ok=True) - - if cache_dir: - bin_cache_path = os.path.join(cache_dir, hashed_key) - bin_lock_path = bin_cache_path + ".lock" - else: - bin_cache_path = None - bin_lock_path = None - - binary = None - if bin_cache_path and os.path.exists(bin_cache_path): - assert bin_lock_path is not None - with FileLock(bin_lock_path): - with open(bin_cache_path, 'rb') as f: - binary = pickle.load(f)["binary"] - if binary is None: - binary = self._compile( - *wargs, device=device_idx, attributes=attributes, - num_warps=num_warps, num_stages=num_stages, - constants=constants, - ) - if bin_cache_path: - assert bin_lock_path is not None - with FileLock(bin_lock_path): - with open(bin_cache_path + ".tmp", "wb") as f: - pickle.dump({"binary": binary, "key": key}, f) - os.rename(bin_cache_path + ".tmp", bin_cache_path) - if JITFunction.cache_hook is not None: - JITFunction.cache_hook(key=key, binary=binary) - - 
bin_cache[key] = LoadedBinary(device_idx, binary)
-        # pack arguments
-        fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg) for i, arg in enumerate(wargs) if not isinstance(arg, triton.language.core.constexpr)])
-        params = struct.pack(fmt, *[arg for arg in args if not isinstance(arg, triton.language.core.constexpr)])
-        # enqueue cached function into stream
-        callable = bin_cache[key]
-        stream = torch.cuda.current_stream(device_idx).cuda_stream
-        csts = {self.fn.arg_names[i]: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.core.constexpr)}
-        grid = grid(csts) if hasattr(grid, '__call__') else grid
-        if isinstance(grid, int):
-            grid = tuple(grid)
-        callable(stream, params, *grid)
-        return callable
+        for pos, _type in self.fn.annotations.items():
+            wargs[pos] = _type(wargs[pos])
+        # query device index and cuda stream
+        device = torch.cuda.current_device()
+        torch.cuda.set_device(device)
+        cc = torch.cuda.get_device_capability(device)
+        cc = str(cc[0]) + '-' + str(cc[1])
+        # # query stream
+        # # this is hacky but much faster than `torch.cuda.current_stream(device).cuda_stream`
+        # # https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L154
+        # # building a C wrapper to re-use the unpack function would add a build-time torch dependency
+        # # and require different wheels for different torch versions -- undesirable!
+        # bits = torch._C._cuda_getCurrentStream(device)
+        # mask = 1 << 47
+        # stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask
+        stream = torch.cuda.current_stream(device).cuda_stream
+        # make key for cache
+        return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names, device, stream,
+                                      self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid)
 
 
 class Launcher:
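Why this removes the leaks: the old launch path created new references through calls such as PyObject_CallMethod, PyObject_GetAttrString, PyObject_Repr, PyDict_New and PyUnicode_FromString, then hit `continue;` or returned without a matching Py_DECREF anywhere in the code shown above, so each kernel launch dropped a few references on the floor. The pybind11 wrappers (py::object, py::dict, py::str, py::int_) release their reference in the destructor, so every exit path is leak-free. The following is a minimal standalone sketch of that difference, assuming only that pybind11 is available; the helper names dtype_repr_capi and dtype_repr_pybind are illustrative and not part of this patch.

#include <pybind11/pybind11.h>
#include <string>

namespace py = pybind11;

// Pre-patch pattern: each C-API call below returns a *new* reference; the
// caller must remember a Py_DECREF on every exit path or the object leaks.
std::string dtype_repr_capi(PyObject* tensor) {
  PyObject* dtype = PyObject_GetAttrString(tensor, "dtype"); // new reference
  PyObject* repr  = PyObject_Repr(dtype);                    // new reference
  std::string out = repr ? PyUnicode_AsUTF8(repr) : "";
  Py_XDECREF(repr);   // easy to forget, especially before `continue;`/`return`
  Py_XDECREF(dtype);
  return out;
}

// Post-patch pattern: py::object owns the reference and releases it in its
// destructor, so early returns cannot leak.
std::string dtype_repr_pybind(py::handle tensor) {
  py::object dtype = tensor.attr("dtype");  // reference managed by RAII
  py::object repr  = py::repr(dtype);
  return repr.cast<std::string>();
}

The same reasoning applies to `py::dict constants` and `py::object bin` inside the new "launch" binding: the temporaries that previously had to be released by hand are now released automatically when the lambda returns.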