[FRONTEND] Added simple hook for when something is written to the cache (#308)
This commit is contained in:
@@ -5,7 +5,7 @@ __version__ = '1.0.1'
|
|||||||
# or pybind11 shows `munmap_chunk(): invalid pointer`
|
# or pybind11 shows `munmap_chunk(): invalid pointer`
|
||||||
import torch
|
import torch
|
||||||
# submodules
|
# submodules
|
||||||
from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, Config, Autotuner, reinterpret
|
from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, JITFunction, Config, Autotuner, reinterpret
|
||||||
|
|
||||||
from . import language
|
from . import language
|
||||||
from . import code_gen
|
from . import code_gen
|
||||||
|
@@ -640,6 +640,8 @@ class Kernel:
|
|||||||
with open(bin_cache_path + ".tmp", "wb") as f:
|
with open(bin_cache_path + ".tmp", "wb") as f:
|
||||||
pickle.dump({"binary": binary, "key": key}, f)
|
pickle.dump({"binary": binary, "key": key}, f)
|
||||||
os.rename(bin_cache_path + ".tmp", bin_cache_path)
|
os.rename(bin_cache_path + ".tmp", bin_cache_path)
|
||||||
|
if JITFunction.cache_hook is not None:
|
||||||
|
JITFunction.cache_hook(key=key, binary=binary)
|
||||||
|
|
||||||
drv_cache[key] = LoadedBinary(device_idx, binary)
|
drv_cache[key] = LoadedBinary(device_idx, binary)
|
||||||
# pack arguments
|
# pack arguments
|
||||||
@@ -732,6 +734,9 @@ def version_key():
|
|||||||
)
|
)
|
||||||
|
|
||||||
class JITFunction:
|
class JITFunction:
|
||||||
|
|
||||||
|
cache_hook = None
|
||||||
|
|
||||||
def _set_cache_key(self):
|
def _set_cache_key(self):
|
||||||
self.cache_key = (hashlib.md5(self.src.encode("utf-8")).hexdigest(), self.version)
|
self.cache_key = (hashlib.md5(self.src.encode("utf-8")).hexdigest(), self.version)
|
||||||
|
|
||||||
@@ -744,8 +749,8 @@ class JITFunction:
|
|||||||
self.src = textwrap.dedent(inspect.getsource(fn))
|
self.src = textwrap.dedent(inspect.getsource(fn))
|
||||||
# cache for callable driver objects (e.g. CUkernel)
|
# cache for callable driver objects (e.g. CUkernel)
|
||||||
self.drv_cache = dict()
|
self.drv_cache = dict()
|
||||||
|
# cache for binaries (on-disk)
|
||||||
self._set_cache_key()
|
self._set_cache_key()
|
||||||
|
|
||||||
# JITFunction can be instantiated as kernel
|
# JITFunction can be instantiated as kernel
|
||||||
# when called with a grid using __getitem__
|
# when called with a grid using __getitem__
|
||||||
self.kernel_decorators = []
|
self.kernel_decorators = []
|
||||||
|
Reference in New Issue
Block a user