[FRONTEND] Added simple hook for when something is written to the cache (#308)

This commit is contained in:
Philippe Tillet
2021-09-23 22:23:17 -07:00
committed by GitHub
parent 0735061fce
commit 83da3febf2
2 changed files with 7 additions and 2 deletions

View File

@@ -5,7 +5,7 @@ __version__ = '1.0.1'
# or pybind11 shows `munmap_chunk(): invalid pointer`
import torch
# submodules
from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, Config, Autotuner, reinterpret
from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, JITFunction, Config, Autotuner, reinterpret
from . import language
from . import code_gen

View File

@@ -640,6 +640,8 @@ class Kernel:
with open(bin_cache_path + ".tmp", "wb") as f:
pickle.dump({"binary": binary, "key": key}, f)
os.rename(bin_cache_path + ".tmp", bin_cache_path)
if JITFunction.cache_hook is not None:
JITFunction.cache_hook(key=key, binary=binary)
drv_cache[key] = LoadedBinary(device_idx, binary)
# pack arguments
@@ -732,6 +734,9 @@ def version_key():
)
class JITFunction:
cache_hook = None
def _set_cache_key(self):
self.cache_key = (hashlib.md5(self.src.encode("utf-8")).hexdigest(), self.version)
@@ -744,8 +749,8 @@ class JITFunction:
self.src = textwrap.dedent(inspect.getsource(fn))
# cache for callable driver objects (e.g. CUkernel)
self.drv_cache = dict()
# cache for binaries (on-disk)
self._set_cache_key()
# JITFunction can be instantiated as kernel
# when called with a grid using __getitem__
self.kernel_decorators = []