diff --git a/python/setup.py b/python/setup.py
index 21cabe182..17db76093 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -18,7 +18,7 @@ import tarfile
 
 def get_llvm():
     # tries to find system LLVM
-    versions = ['-11.0', '-11', '-11-64']
+    versions = ['-11.0', '-11', '-11-64']
     supported = ['llvm-config{v}'.format(v=v) for v in versions]
     paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
     paths = [p for p in paths if p is not None]
@@ -127,7 +127,7 @@ setup(
     description="A language and compiler for custom Deep Learning operations",
     long_description="",
     packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/ops", "triton/ops/blocksparse"],
-    install_requires=["torch", "filelock"],
+    install_requires=["cmake", "torch", "filelock"],
     package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]},
     include_package_data=True,
     ext_modules=[CMakeExtension("triton", "triton/_C/")],
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 9298f9db4..abe441b3a 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -13,6 +13,7 @@
 #include
 #include
 #include
+#include "Python.h"
 #include
 #include
 #include "llvm/IR/Module.h"
@@ -23,6 +24,7 @@
 namespace py = pybind11;
 namespace ir = triton::ir;
 namespace drv = triton::driver;
+
 /*****************************************************************************/
 /* Python bindings for triton::driver                                        */
 /*****************************************************************************/
@@ -99,8 +101,113 @@ void hip_enqueue(uint64_t stream, uint64_t kernel,
 
 }
 
+std::string pow2_divisor(long N){
+  if(N % 16 == 0) return "16";
+  if(N % 8 == 0) return "8";
+  if(N % 4 == 0) return "4";
+  if(N % 2 == 0) return "2";
+  return "1";
+}
+
+// Launch
+void parse_args(py::handle& args, const std::string& func_key, py::handle& arg_names,
+                std::string& cache_key, std::string& params, size_t& params_size, PyObject* constants,
+                int num_warps, int num_stages) {
+  size_t len = PyList_Size(args.ptr());
+  params.reserve(8*len); // 8 max bytes by argument
+  char* params_ptr = &params[0];
+  cache_key = func_key;
+  for(int i = 0; i < len; i++){
+    auto arg_ptr = PyList_GetItem(args.ptr(), i);
+    auto arg = py::handle(arg_ptr);
+    // argument is `long`
+    if(PyLong_Check(arg_ptr)){
+      int overflow;
+      long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow);
+      // long and int have different kernels
+      if(!overflow & (std::abs(value) <= 0xffffffff)){
+        cache_key += 'I';
+        params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4));
+        std::memcpy(params_ptr, &value, 4);
+        params_ptr += 4;
+      }
+      else{
+        cache_key += 'L';
+        params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
+        if(overflow){
+          unsigned long long uvalue = PyLong_AsUnsignedLongLong(arg_ptr);
+          std::memcpy(&value, &uvalue, 8);
+        }
+        std::memcpy(params_ptr, &value, 8);
+        params_ptr += 8;
+      }
+      // values equal to 1 are specialized
+      if(value == 1)
+        cache_key += '1';
+      else
+        cache_key += 'x';
+      // values divisible by small powers of 2 are specialized
+      cache_key += pow2_divisor(value);
+      continue;
+    }
+    // argument is `float`
+    if(PyFloat_Check(arg_ptr)){
+      cache_key += "f";
+      float value = PyFloat_AsDouble(arg_ptr);
+      params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4));
+      std::memcpy(params_ptr, &value, 4);
+      params_ptr += 4;
+      continue;
+    }
+    // argument is `bool`
+    if(PyBool_Check(arg_ptr)){
+      cache_key += "B";
+      bool value = arg_ptr == Py_True ? true : false;
+      std::memcpy(params_ptr, &value, 1);
+      params_ptr += 1;
+      continue;
+    }
+    // argument is tensor
+    PyObject* data_ptr = PyObject_CallMethod(arg_ptr, "data_ptr", nullptr);
+    if(data_ptr){
+      cache_key += "P";
+      long value = PyLong_AsLong(data_ptr);
+      params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
+      std::memcpy(params_ptr, &value, 8);
+      params_ptr += 8;
+      PyObject* dtype = PyObject_GetAttrString(arg_ptr, "dtype");
+      PyObject* repr = PyObject_Repr(dtype);
+      const char* start = (const char*)PyUnicode_1BYTE_DATA(repr) + 6; // remove 'torch.'
+      size_t len = PyUnicode_GET_LENGTH(repr) - 6;
+      cache_key += std::string(start, len);
+      continue;
+    }
+    // argument is `constexpr`
+    PyObject* value = PyObject_GetAttrString(arg_ptr, "value");
+    if(value){
+      PyObject* name = PyList_GetItem(arg_names.ptr(), i);
+      PyDict_SetItem(constants, name, value);
+      PyObject* repr = PyObject_Repr(value);
+      const char* start = (const char*)PyUnicode_1BYTE_DATA(repr);
+      size_t len = PyUnicode_GET_LENGTH(repr);
+      cache_key += std::string(start, len);
+      continue;
+    }
+    assert(false);
+  }
+  cache_key += std::to_string(num_warps);
+  cache_key += std::to_string(num_stages);
+  params_size = (std::ptrdiff_t)(params_ptr - &params[0]);
+}
+
+//
+
 void init_triton_runtime(py::module &&m) {
+  // m.def("current_stream", [](uint64_t device){
+  //   return (uint64_t)(c10::cuda::getCurrentCUDAStream(device).stream());
+  // });
+
   // wrap backend_t
   py::enum_<backend_t>(m, "backend")
       .value("HOST", HOST)
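
Note (not part of the patch): the cache key built by `parse_args` above encodes, per argument, a type tag plus value specializations. A rough Python model of that encoding, with illustrative names only:

```python
# Rough Python model of the key produced by parse_args (illustrative only;
# the real logic lives in the C++ above and also packs the argument bytes).
def pow2_divisor(n: int) -> str:
    # largest of 16/8/4/2 dividing n, else "1" -- mirrors the C++ helper
    for d in (16, 8, 4, 2):
        if n % d == 0:
            return str(d)
    return "1"

def model_cache_key(func_key: str, args, num_warps: int = 4, num_stages: int = 2) -> str:
    key = func_key
    for arg in args:
        if isinstance(arg, bool):
            key += "B"
        elif isinstance(arg, int):
            key += "I" if abs(arg) <= 0xffffffff else "L"   # 32- vs 64-bit kernels
            key += "1" if arg == 1 else "x"                  # value-1 specialization
            key += pow2_divisor(arg)                         # alignment specialization
        elif isinstance(arg, float):
            key += "f"
        elif hasattr(arg, "data_ptr"):                       # torch.Tensor-like
            key += "P" + repr(arg.dtype).replace("torch.", "")
        else:                                                # constexpr-like: repr of value
            key += repr(getattr(arg, "value", arg))
    return key + str(num_warps) + str(num_stages)

# e.g. a float32 CUDA tensor, n=1024, stride=1 gives roughly
#   func_key + "Pfloat32" + "Ix16" + "I11" + "42"
```

The `1`/`x` flag and the power-of-two divisor are what let kernels specialize on stride-1 and well-aligned sizes without a distinct binary for every value.
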
@@ -116,6 +223,51 @@ void init_triton_runtime(py::module &&m) {
       }
     );
 
+  // cache key
+  m.def("launch", [](py::handle args, const std::string& func_key, py::list& arg_names,
+                     py::handle device, py::handle stream, py::handle bin_cache, py::handle num_warps, py::handle num_stages,
+                     py::handle add_to_cache, py::handle grid){
+    // parse arguments to compute cache key, compile-time constants and packed kernel arguments
+    long _num_warps = PyLong_AsLong(num_warps.ptr());
+    long _num_stages = PyLong_AsLong(num_stages.ptr());
+    std::string cache_key;
+    std::string params;
+    size_t params_size;
+    PyObject* constants = PyDict_New();
+    parse_args(args, func_key, arg_names, cache_key, params, params_size, constants, _num_warps, _num_stages);
+    // get cached binary
+    PyObject* key = PyUnicode_FromString(cache_key.c_str());
+    PyObject* bin = nullptr;
+    if(!PyDict_Contains(bin_cache.ptr(), key)){
+      add_to_cache(py::handle(key), args, device, num_warps, num_stages);
+    }
+    bin = PyDict_GetItem(bin_cache.ptr(), key);
+    // get grid
+    PyObject* grid_ptr = grid.ptr();
+    if(!PySequence_Check(grid_ptr)){
+      PyObject* grid_call = PyObject_GetAttrString(grid_ptr, "__call__");
+      grid_ptr = PyObject_Call(grid_call, PyTuple_Pack(1, constants), nullptr);
+    }
+    int size = PySequence_Size(grid_ptr);
+    int grid_0 = PyLong_AsLong(PySequence_GetItem(grid_ptr, 0));
+    int grid_1 = size < 2 ? 1 : PyLong_AsLong(PySequence_GetItem(grid_ptr, 1));
+    int grid_2 = size < 3 ? 1 : PyLong_AsLong(PySequence_GetItem(grid_ptr, 2));
+    // enqueue
+    uint64_t kernel = PyLong_AsLong(PyObject_GetAttrString(bin, "kernel"));
+    uint64_t shared_mem = PyLong_AsLong(PyObject_GetAttrString(bin, "shared_mem"));
+    // actually launch
+    void *config[] = {
+        CU_LAUNCH_PARAM_BUFFER_POINTER, params.data(),
+        CU_LAUNCH_PARAM_BUFFER_SIZE, &params_size,
+        CU_LAUNCH_PARAM_END
+    };
+    uint64_t _stream = PyLong_AsLong(stream.ptr());
+    drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
+                                  _num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
+                                  nullptr, config);
+    return py::handle(bin);
+  });
+
   // query maximum shared memory
   m.def("max_shared_memory", [](backend_t backend, uint64_t device) {
     if (backend == HOST)
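
Note (not part of the patch): `launch` accepts the grid either as a sequence or as a callable taking the compile-time constants, and pads missing dimensions with 1. A small Python equivalent of that resolution step, assuming `constants` is the dict built by `parse_args`:

```python
def resolve_grid(grid, constants):
    # `grid` may be a sequence such as (1024,) or a callable such as
    # lambda meta: (n // meta['BLOCK'],); the callable receives the dict of
    # compile-time constants collected while parsing the arguments.
    if callable(grid):
        grid = grid(constants)
    grid = tuple(grid)
    # missing trailing dimensions default to 1, as grid_1/grid_2 do above
    return (grid + (1, 1))[:3]

assert resolve_grid((128,), {}) == (128, 1, 1)
assert resolve_grid(lambda meta: (meta["N"] // meta["BLOCK"],), {"N": 4096, "BLOCK": 64}) == (64, 1, 1)
```
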
diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index f14f3b135..b8b9f8129 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -464,6 +464,7 @@ class LoadedBinary:
         self.module = module
         self.kernel = kernel
         self.device = device
+        self.shared_mem = bin.shared_mem
 
     def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1):
         _triton.runtime.enqueue(self.bin.backend, stream, self.kernel,
@@ -548,16 +549,6 @@ class Kernel:
             name = Kernel._type_name(obj)
         return type_map[name](context)
 
-    @staticmethod
-    def _types_key(*wargs, tensor_idxs):
-        # type inference
-        types_key = [None] * len(wargs)
-        for i, arg in enumerate(wargs):
-            prefix = 'P' if i in tensor_idxs else ''
-            suffix = Kernel._type_name(arg) if i in tensor_idxs else Kernel._type_name(arg)
-            types_key[i] = prefix + suffix
-        return tuple(types_key)
-
     @staticmethod
     def pow2_divisor(N):
         if N % 16 == 0: return 16
@@ -599,6 +590,53 @@ class Kernel:
             raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
         return Binary(backend, name, asm, shared_mem, num_warps)
 
+    def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages):
+        tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
+        # attributes
+        args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
+        attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \
+                      if isinstance(a, int) and i not in self.fn.do_not_specialize}
+
+        # transforms ints whose value is one into constants for just-in-time compilation
+        constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
+        constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
+        hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()
+
+        # create cache directory
+        cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
+        if cache_dir and not os.path.exists(cache_dir):
+            os.makedirs(cache_dir, exist_ok=True)
+
+        if cache_dir:
+            bin_cache_path = os.path.join(cache_dir, hashed_key)
+            bin_lock_path = bin_cache_path + ".lock"
+        else:
+            bin_cache_path = None
+            bin_lock_path = None
+
+        binary = None
+        if bin_cache_path and os.path.exists(bin_cache_path):
+            assert bin_lock_path is not None
+            with FileLock(bin_lock_path):
+                with open(bin_cache_path, 'rb') as f:
+                    binary = pickle.load(f)["binary"]
+        if binary is None:
+            binary = self._compile(
+                *wargs, device=device_idx, attributes=attributes,
+                num_warps=num_warps, num_stages=num_stages,
+                constants=constants,
+            )
+        if bin_cache_path:
+            assert bin_lock_path is not None
+            with FileLock(bin_lock_path):
+                with open(bin_cache_path + ".tmp", "wb") as f:
+                    pickle.dump({"binary": binary, "key": key}, f)
+                os.rename(bin_cache_path + ".tmp", bin_cache_path)
+        if JITFunction.cache_hook is not None:
+            JITFunction.cache_hook(key=key, binary=binary)
+
+        self.fn.bin_cache[key] = LoadedBinary(device_idx, binary)
+
     def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
         # handle arguments passed by name
         kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()}
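
Note (not part of the patch): `add_to_cache` keeps the existing `JITFunction.cache_hook` behaviour, so a hook still fires once per fresh compilation, now with the string key. A minimal sketch, assuming `JITFunction` is reachable as `triton.code_gen.JITFunction` and that `log_compile` is a hypothetical user callback:

```python
import triton  # assumes triton is installed and exposes code_gen.JITFunction

compiled_keys = []

def log_compile(key, binary, **kwargs):
    # hypothetical callback: called once per freshly compiled kernel with the
    # string cache key and the binary object that also gets pickled into
    # $TRITON_CACHE_DIR (default /tmp/triton/) under the key's md5 digest
    compiled_keys.append(key)

triton.code_gen.JITFunction.cache_hook = log_compile
```
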
@@ -608,112 +646,21 @@ class Kernel:
         if len(wargs) != len(self.fn.arg_names):
             raise TypeError(f"Function takes {len(self.fn.arg_names)} positional arguments but {len(wargs)} were given")
         # handle annotations
-        for name, type in self.fn.__annotations__.items():
-            pos = self.fn.arg_names.index(name)
-            assert type == triton.language.core.constexpr
-            wargs[pos] = type(wargs[pos])
-        # device inference
-        tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
-        if len(tensor_idxs) == 0:
-            raise ValueError("No Tensor argument found.")
-        invalid_args = []
-        device_ids = []
-        for idx in tensor_idxs:
-            curr = wargs[idx]
-            if not curr.is_cuda:
-                invalid_args.append(idx)
-            else:
-                device_ids.append(curr.device.index)
-        if invalid_args:
-            raise ValueError("Arguments at index {invalid_args} are on the wrong device.".format(invalid_args=invalid_args) +
-                             " Only CUDA is supported at the moment")
-
-        device = torch.device('cuda', torch.cuda.current_device())
-        device_idx = device.index
-        # if len(set(device_ids)) != 1 or device_ids[0] != device_idx:
-        #   # try to enable P2P communication
-        #   for arg_idx, dst_idx in zip(tensor_idxs, device_ids):
-        #     if dst_idx != device_idx:
-        #       try:
-        #         _triton.runtime.enable_peer_access(self.backend, wargs[arg_idx].data_ptr())
-        #       except RuntimeError as e:
-        #         raise RuntimeError("Cannot enable P2P access from device {} to device {}: {}"
-        #                            .format(device_idx, dst_idx, str(e)))
-
-        # enqueue kernel on the current device
-        torch.cuda.set_device(device_idx)
-        # attributes
-        args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
-        attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \
-                      if isinstance(a, int) and i not in self.fn.do_not_specialize}
-
-        # transforms ints whose value is one into constants for just-in-time compilation
-        constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
-        constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
-
-        # compute hash for caching this kernel
-        types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs)
-        attr_key = tuple(attributes.items())
-        const_key = tuple(constants.items())
-        compute_capability = torch.cuda.get_device_capability(device)
-        key = (
-            self.fn.cache_key, version_key(), compute_capability,
-            types_key, attr_key, num_warps, num_stages, const_key
-        )
-        key = repr(key)
-
-        # get cached binary
-        drv_cache = self.fn.drv_cache
-
-        if key not in drv_cache:
-            hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()
-
-            # create cache directory
-            cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
-            if cache_dir and not os.path.exists(cache_dir):
-                os.makedirs(cache_dir, exist_ok=True)
-
-            if cache_dir:
-                bin_cache_path = os.path.join(cache_dir, hashed_key)
-                bin_lock_path = bin_cache_path + ".lock"
-            else:
-                bin_cache_path = None
-                bin_lock_path = None
-
-            binary = None
-            if bin_cache_path and os.path.exists(bin_cache_path):
-                assert bin_lock_path is not None
-                with FileLock(bin_lock_path):
-                    with open(bin_cache_path, 'rb') as f:
-                        binary = pickle.load(f)["binary"]
-            if binary is None:
-                binary = self._compile(
-                    *wargs, device=device_idx, attributes=attributes,
-                    num_warps=num_warps, num_stages=num_stages,
-                    constants=constants,
-                )
-            if bin_cache_path:
-                assert bin_lock_path is not None
-                with FileLock(bin_lock_path):
-                    with open(bin_cache_path + ".tmp", "wb") as f:
-                        pickle.dump({"binary": binary, "key": key}, f)
-                    os.rename(bin_cache_path + ".tmp", bin_cache_path)
-            if JITFunction.cache_hook is not None:
-                JITFunction.cache_hook(key=key, binary=binary)
-
-            drv_cache[key] = LoadedBinary(device_idx, binary)
-        # pack arguments
-        fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg) for i, arg in enumerate(wargs) if not isinstance(arg, triton.language.core.constexpr)])
-        params = struct.pack(fmt, *[arg for arg in args if not isinstance(arg, triton.language.core.constexpr)])
-        # enqueue cached function into stream
-        callable = drv_cache[key]
-        stream = torch.cuda.current_stream(device_idx).cuda_stream
-        csts = {self.fn.arg_names[i]: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.core.constexpr)}
-        grid = grid(csts) if hasattr(grid, '__call__') else grid
-        if isinstance(grid, int):
-            grid = tuple(grid)
-        callable(stream, params, *grid)
-        return callable
+        for pos, _type in self.fn.annotations.items():
+            wargs[pos] = _type(wargs[pos])
+        # query device index and cuda stream
+        device = torch.cuda.current_device()
+        # query stream
+        # this is hacky but much faster than `torch.cuda.current_stream(device).cuda_stream`
+        # https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L154
+        # building a C wrapper to re-use the unpack function would add a build-time torch dependency
+        # and require different wheels for different torch versions -- undesirable!
+        bits = torch._C._cuda_getCurrentStream(device)
+        mask = 1 << 47
+        stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask
+        # make key for cache
+        return _triton.runtime.launch(wargs, self.fn.cache_key, self.fn.arg_names, device, stream,
+                                      self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid)
 
 
 class Launcher:
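
Note (not part of the patch): the `mask`/xor/subtract line above sign-extends the low 48 bits returned by `torch._C._cuda_getCurrentStream` back into a signed Python integer (see the Stream.h link in the patch comment). A standalone illustration of just that arithmetic, with made-up bit patterns:

```python
def sign_extend_48(bits: int) -> int:
    # keep the low 48 bits, then sign-extend bit 47 into a (possibly negative)
    # Python int -- the same ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask trick
    # used in Kernel.__call__ above
    mask = 1 << 47
    return ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask

assert sign_extend_48(0x0000_7FFF_FFFF_FFFF) == 0x7FFF_FFFF_FFFF  # bit 47 clear: value passes through
assert sign_extend_48(0xFFFF_FFFF_FFFF) == -1                     # bit 47 set: negative value
assert sign_extend_48(0) == 0
```
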
@@ -723,6 +670,7 @@ class Launcher:
 
     def __call__(self, *wargs, **kwargs):
         return self.kernel(*wargs, **kwargs, grid=self.grid)
 
+
 class Autotuner:
 
@@ -773,6 +721,11 @@ class Autotuner:
         return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
 
 
+@functools.lru_cache()
+def compute_capability():
+    device = torch.device('cuda', 0)
+    return '-'.join(map(str, torch.cuda.get_device_capability(device)))
+
 @functools.lru_cache()
 def version_key():
     import pkgutil
@@ -784,22 +737,27 @@ def version_key():
     with open(triton._C.libtriton.__file__, "rb") as f:
         contents += [hashlib.md5(f.read()).hexdigest()]
     # language
-    for lib in pkgutil.iter_modules(triton.language.__path__):
+    language_path = os.path.join(*triton.__path__, 'language')
+    for lib in pkgutil.iter_modules([language_path]):
         with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
             contents += [hashlib.md5(f.read()).hexdigest()]
     # ptxas version
     try:
         ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest()
     except Exception:
-        ptxas_version = None
-    return (triton.__version__, ptxas_version) + tuple(contents)
+        ptxas_version = ''
+    return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
 
 class JITFunction:
 
     cache_hook = None
 
     def _set_cache_key(self):
-        self.cache_key = (hashlib.md5(self.src.encode("utf-8")).hexdigest(), self.version)
+        self.cache_key = hashlib.md5(self.src.encode("utf-8")).hexdigest()
+        self.cache_key += str(self.version)
+        self.cache_key += version_key()
+        self.cache_key += compute_capability()
+        self.cache_key = hashlib.md5(self.cache_key.encode("utf-8")).hexdigest()
 
     def __init__(self, fn, version=None, do_not_specialize=None):
         # information of wrapped function
@@ -811,7 +769,7 @@ class JITFunction:
         self.do_not_specialize = [] if do_not_specialize is None else\
             [self.arg_names.index(arg) for arg in do_not_specialize]
         # cache for callable driver objects (e.g. CUkernel)
-        self.drv_cache = dict()
+        self.bin_cache = dict()
        # cache for binaries (on-disk)
         self._set_cache_key()
         # JITFunction can be instantiated as kernel
@@ -819,6 +777,7 @@ class JITFunction:
         self.kernel_decorators = []
         self.kernel = None
         # annotations
+        self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()}
         self.__annotations__ = fn.__annotations__
         # forward docs
         self.__doc__ = fn.__doc__
@@ -834,7 +793,7 @@ class JITFunction:
         assert isinstance(tree.body[0], ast.FunctionDef)
         return tree
 
-    def __call__(self, *args, generator: CodeGenerator, **meta):
+    def __call__(self, *args, generator: CodeGenerator):
         try:
             gscope = generator.gscope.copy()
             lscope = generator.lscope.copy()
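
Note (not part of the patch): the function-level cache key is now a flat md5 string built from the source hash, the declared version, `version_key()` and `compute_capability()`, rather than a tuple. A compact model of that composition, with placeholder field values:

```python
import hashlib

def model_fn_cache_key(src: str, version, version_key: str, capability: str) -> str:
    # mirrors JITFunction._set_cache_key above: hash the source, append the
    # version metadata, then re-hash so the key stays a fixed-size string
    key = hashlib.md5(src.encode("utf-8")).hexdigest()
    key += str(version) + version_key + capability
    return hashlib.md5(key.encode("utf-8")).hexdigest()

# e.g. model_fn_cache_key(kernel_src, None, "<triton/ptxas/module hashes>", "8-0")
```
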
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index 5eed3b67f..7875a30f6 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -119,6 +119,9 @@ class block:
             self.shape = (1, )
         if self.handle.type.is_block():
             self.shape = self.handle.type.shape
+        self.numel = 1
+        for s in self.shape:
+            self.numel *= s
         # Data-type wrapper
         self.dtype = block._init_dtype(self.handle.type.scalar)
 
@@ -352,6 +355,13 @@ def program_id(axis, _builder=None):
     :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
     :type axis: int
     """
+    # if axis == -1:
+    #     pid0 = frontend.program_id(0, _builder)
+    #     pid1 = frontend.program_id(1, _builder)
+    #     pid2 = frontend.program_id(2, _builder)
+    #     npg0 = frontend.num_programs(0, _builder)
+    #     npg1 = frontend.num_programs(0, _builder)
+    #     return pid0 + pid1*npg0 + pid2*npg0*npg1
     return frontend.program_id(axis, _builder)
 
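
Note (not part of the patch): the commented-out branch in `program_id` sketches a flattened 1D program id over the 3D launch grid. The linearization it would compute is the usual folding shown below (function and argument names are illustrative):

```python
def flat_program_id(pid, num_programs):
    # pid = (pid0, pid1, pid2), num_programs = (npg0, npg1, npg2);
    # mirrors the commented-out formula pid0 + pid1*npg0 + pid2*npg0*npg1
    pid0, pid1, pid2 = pid
    npg0, npg1, _ = num_programs
    return pid0 + pid1 * npg0 + pid2 * npg0 * npg1

assert flat_program_id((3, 2, 1), (4, 5, 6)) == 3 + 2 * 4 + 1 * 4 * 5  # == 31
```
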