diff --git a/python/triton/__init__.py b/python/triton/__init__.py index ea407e834..3abe85b0c 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -5,7 +5,7 @@ __version__ = '1.0.1' # or pybind11 shows `munmap_chunk(): invalid pointer` import torch # submodules -from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, Config, Autotuner, reinterpret +from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, JITFunction, Config, Autotuner, reinterpret from . import language from . import code_gen diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 7c7cda673..aedb2051f 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -640,6 +640,8 @@ class Kernel: with open(bin_cache_path + ".tmp", "wb") as f: pickle.dump({"binary": binary, "key": key}, f) os.rename(bin_cache_path + ".tmp", bin_cache_path) + if JITFunction.cache_hook is not None: + JITFunction.cache_hook(key=key, binary=binary) drv_cache[key] = LoadedBinary(device_idx, binary) # pack arguments @@ -732,6 +734,9 @@ def version_key(): ) class JITFunction: + + cache_hook = None + def _set_cache_key(self): self.cache_key = (hashlib.md5(self.src.encode("utf-8")).hexdigest(), self.version) @@ -744,8 +749,8 @@ class JITFunction: self.src = textwrap.dedent(inspect.getsource(fn)) # cache for callable driver objects (e.g. CUkernel) self.drv_cache = dict() + # cache for binaries (on-disk) self._set_cache_key() - # JITFunction can be instantiated as kernel # when called with a grid using __getitem__ self.kernel_decorators = []