From d4d8eaf6c08d824b0e098c6438f3cc2230279477 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 15 Mar 2022 12:20:51 -0700 Subject: [PATCH] [FRONTEND] improved caching mechanism (#474) Co-authored-by: Greg Brockman Co-authored-by: Christopher Hesse --- python/src/triton.cc | 9 +- python/test/unit/runtime/test_cache.py | 4 +- python/triton/code_gen.py | 177 ++++++++++++++----------- 3 files changed, 106 insertions(+), 84 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index c5c5b196f..9e53cc341 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -299,8 +299,12 @@ void init_triton_runtime(py::module &&m) { // get cached binary py::str key(cache_key); - if(!bin_cache.contains(key)) - add_to_cache(key, args, device, num_warps, num_stages); + py::bool_ noop = false; + if(!bin_cache.contains(key)) { + noop = add_to_cache(key, args, device, num_warps, num_stages); + } + if (noop) + return (py::object)py::none(); py::object bin = bin_cache[key]; // get grid @@ -529,6 +533,7 @@ void init_triton_codegen(py::module &&m) { return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map); }, py::return_value_policy::take_ownership); m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){ + py::gil_scoped_release allow_threads; if(backend == CUDA) return cu_load_binary(name, asm_map, n_shared_bytes, dev); if(backend == ROCM) diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index 48797b51a..8ac01bcc8 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -76,7 +76,7 @@ def reset_tmp_dir(): def test_reuse(): counter = 0 - def inc_counter(key, binary, repr): + def inc_counter(*args, **kwargs): nonlocal counter counter += 1 JITFunction.cache_hook = inc_counter @@ -91,7 +91,7 @@ def test_reuse(): def test_specialize(mode): counter = 0 - def inc_counter(key, binary, repr): + def inc_counter(*args, **kwargs): nonlocal counter counter += 1 JITFunction.cache_hook = inc_counter diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 894b3f1e3..3f170098b 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -602,8 +602,19 @@ class Kernel: return 'str' raise NotImplementedError(f'could not compute type name for {obj}') + @staticmethod + def _to_python_ir(obj): + # convert torch.Tensor to Triton IR pointers + if hasattr(obj, 'data_ptr'): + name = Kernel._type_name(obj) + return 'ptr', name + # default path returns triton.ir.type directly + name = Kernel._type_name(obj) + return 'scalar', name + @staticmethod def _to_triton_ir(context, obj): + which, name = obj type_map = { 'I': _triton.ir.type.get_int32, 'L': _triton.ir.type.get_int64, @@ -625,12 +636,10 @@ class Kernel: 'u64': _triton.ir.type.get_uint64, } # convert torch.Tensor to Triton IR pointers - if hasattr(obj, 'data_ptr'): - name = Kernel._type_name(obj) + if which == 'ptr': elt_ty = type_map[name](context) return _triton.ir.type.make_ptr(elt_ty, 1) # default path returns triton.ir.type directly - name = Kernel._type_name(obj) return type_map[name](context) @staticmethod @@ -648,36 +657,6 @@ class Kernel: def __init__(self, fn): self.fn = fn - def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages): - # create IR module - context = _triton.ir.context() - # get just-in-time proto-type of kernel - fn_args = [arg for i, arg in enumerate(wargs) if i not in constants] - arg_types = 
[Kernel._to_triton_ir(context, arg) for arg in fn_args] - ret_type = _triton.ir.type.get_void(context) - prototype = _triton.ir.type.make_function(ret_type, arg_types) - # generate Triton-IR - # export symbols visible from self.fn into code-generator object - gscope = self.fn.__globals__ - generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict()) - try: - generator.visit(self.fn.parse()) - except Exception as e: - node = generator.last_node - if node is None or isinstance(e, (NotImplementedError, CompilationError)): - raise e - raise CompilationError(self.fn.src, node) from e - # Compile to machine code - if torch.version.hip is None: - backend = _triton.runtime.backend.CUDA - else: - backend = _triton.runtime.backend.ROCM - name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages) - max_shared_memory = _triton.runtime.max_shared_memory(backend, device) - if shared_mem > max_shared_memory: - raise OutOfResources(shared_mem, max_shared_memory, "shared memory") - return Binary(backend, name, asm, shared_mem, num_warps) - def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages): tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] # attributes @@ -692,57 +671,12 @@ class Kernel: range_size = _triton.runtime.get_pointer_range_size(addr) attributes[i] = min(Kernel.pow2_divisor(addr), Kernel.pow2_divisor(range_size)) - # transforms ints whose value is one into constants for just-in-time compilation constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize} constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) constants.update({i: None for i, arg in enumerate(wargs) if arg is None}) - hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() - - # create cache directory - cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/') - if cache_dir and not os.path.exists(cache_dir): - os.makedirs(cache_dir, exist_ok=True) - - if cache_dir: - bin_cache_path = os.path.join(cache_dir, hashed_key) - bin_lock_path = bin_cache_path + ".lock" - else: - bin_cache_path = None - bin_lock_path = None - - binary = None - if bin_cache_path and os.path.exists(bin_cache_path): - assert bin_lock_path is not None - with FileLock(bin_lock_path): - with open(bin_cache_path, 'rb') as f: - binary = pickle.load(f)["binary"] - if binary is None: - binary = self._compile( - *wargs, device=device_idx, attributes=attributes, - num_warps=num_warps, num_stages=num_stages, - constants=constants, - ) - if bin_cache_path: - assert bin_lock_path is not None - with FileLock(bin_lock_path): - with open(bin_cache_path + ".tmp", "wb") as f: - pickle.dump({"binary": binary, "key": key}, f) - os.rename(bin_cache_path + ".tmp", bin_cache_path) - if JITFunction.cache_hook is not None: - name = self.fn.__name__ - info = key.split('-')[-3:] - num_warps, num_stages, sig = info[0], info[1], info[2].split('_')[1:] - # make signature human-readable - arg_reprs = [] - for arg_name, arg_sig in zip(self.fn.arg_names, sig): - arg_reprs.append(f'{arg_name}: {arg_sig}') - # assemble the repr - arg_reprs = ", ".join(arg_reprs) - repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})" - JITFunction.cache_hook(key=key, binary=binary, repr=repr) - - self.fn.bin_cache[key] = LoadedBinary(device_idx, binary) + arg_types = [Kernel._to_python_ir(arg) for i, arg 
in enumerate(wargs) if i not in constants] + return self.fn._warmup(key, arg_types=arg_types, device=device_idx, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages, is_manual_warmup=False) def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs): # handle arguments passed by name @@ -1027,6 +961,89 @@ class JITFunction: self.kernel = decorator(self.kernel) return self.kernel + def warmup(self, compile): + return self._warmup(**compile, is_manual_warmup=True) + + def _warmup(self, key, arg_types, device, attributes, constants, num_warps, num_stages, is_manual_warmup): + hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() + + # create cache directory + cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/') + if cache_dir: + os.makedirs(cache_dir, exist_ok=True) + + if cache_dir: + bin_cache_path = os.path.join(cache_dir, hashed_key) + bin_lock_path = bin_cache_path + ".lock" + else: + bin_cache_path = None + bin_lock_path = None + + binary = None + if bin_cache_path and os.path.exists(bin_cache_path): + assert bin_lock_path is not None + with FileLock(bin_lock_path): + with open(bin_cache_path, 'rb') as f: + binary = pickle.load(f)["binary"] + + compile = dict(arg_types=arg_types, device=device, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages) + if JITFunction.cache_hook is not None: + name = self.__name__ + info = key.split('-')[-3:] + num_warps, num_stages, sig = info[0], info[1], info[2].split('_')[1:] + # make signature human-readable + arg_reprs = [] + for arg_name, arg_sig in zip(self.arg_names, sig): + arg_reprs.append(f'{arg_name}: {arg_sig}') + # assemble the repr + arg_reprs = ", ".join(arg_reprs) + repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})" + noop = JITFunction.cache_hook(key=key, repr=repr, fn=self, compile={"key": key, **compile}, is_manual_warmup=is_manual_warmup, already_compiled=binary is not None) + if noop: + return True + + if binary is None: + binary = self._compile(**compile) + + if bin_cache_path: + assert bin_lock_path is not None + with FileLock(bin_lock_path): + with open(bin_cache_path + ".tmp", "wb") as f: + pickle.dump({"binary": binary, "key": key}, f) + os.rename(bin_cache_path + ".tmp", bin_cache_path) + + self.bin_cache[key] = LoadedBinary(device, binary) + return False + + def _compile(self, arg_types, device, attributes, constants, num_warps, num_stages): + # create IR module + context = _triton.ir.context() + # get just-in-time proto-type of kernel + arg_types = [Kernel._to_triton_ir(context, arg) for arg in arg_types] + ret_type = _triton.ir.type.get_void(context) + prototype = _triton.ir.type.make_function(ret_type, arg_types) + # generate Triton-IR + # export symbols visible from self into code-generator object + gscope = self.__globals__ + generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict()) + try: + generator.visit(self.parse()) + except Exception as e: + node = generator.last_node + if node is None or isinstance(e, (NotImplementedError, CompilationError)): + raise e + raise CompilationError(self.src, node) from e + # Compile to machine code + if torch.version.hip is None: + backend = _triton.runtime.backend.CUDA + else: + backend = _triton.runtime.backend.ROCM + name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages) + max_shared_memory = _triton.runtime.max_shared_memory(backend, device) + if 
shared_mem > max_shared_memory: + raise OutOfResources(shared_mem, max_shared_memory, "shared memory") + return Binary(backend, name, asm, shared_mem, num_warps) + def __getitem__(self, grid): return Launcher(self._init_kernel(), grid)
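
Usage sketch (not part of the diff above): a minimal, hypothetical example of registering a hook against the new `JITFunction.cache_hook` interface introduced by this patch. The hook now receives `fn`, `compile`, `is_manual_warmup`, and `already_compiled` keyword arguments, and a truthy return value turns the launch into a no-op. The `add_kernel` kernel, grid, tensor sizes, and the `seen` list below are illustrative assumptions, not code from this patch.

    import torch
    import triton
    import triton.language as tl
    from triton.code_gen import JITFunction

    seen = []

    def cache_hook(key, repr, fn, compile, is_manual_warmup, already_compiled, **kwargs):
        # Record what Triton is about to compile (or load from the on-disk cache).
        seen.append((key, repr, already_compiled))
        # Returning a truthy value skips compilation and makes this launch a no-op;
        # returning False lets compilation and execution proceed as usual.
        return False

    JITFunction.cache_hook = cache_hook

    @triton.jit
    def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        x = tl.load(x_ptr + offs, mask=mask)
        y = tl.load(y_ptr + offs, mask=mask)
        tl.store(out_ptr + offs, x + y, mask=mask)

    x = torch.randn(1024, device='cuda')
    y = torch.randn(1024, device='cuda')
    out = torch.empty_like(x)
    add_kernel[(4,)](x, y, out, x.numel(), BLOCK=256)
    # The first launch invokes cache_hook with already_compiled=False (or True if a
    # pickled binary was found under TRITON_CACHE_DIR); later launches with the same
    # specialization key hit fn.bin_cache and bypass the hook entirely, matching the
    # counter == 1 assertion pattern in test_reuse above.
    assert len(seen) == 1

Design note: the `compile` dict handed to the hook carries everything `_warmup` needs to rebuild the binary later (`key`, `arg_types`, `device`, `attributes`, `constants`, `num_warps`, `num_stages`), so a hook that returns True to defer compilation can pass that same dict back to `fn.warmup(compile)` to compile the kernel manually at a time of its choosing.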