From 0735061fce5bac3ec276cee55efd07c747d47a80 Mon Sep 17 00:00:00 2001 From: Shantanu <12621235+hauntsaninja@users.noreply.github.com> Date: Thu, 23 Sep 2021 21:23:59 -0700 Subject: [PATCH] [FRONTEND] fix for unpickleable keys (#307) In #306, I added the key to the cache data, so we can introspect to investigate cache misses. Unfortunately, the key isn't pickleable, so just add the str version instead. Co-authored-by: hauntsaninja <> --- python/triton/code_gen.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 93b6d4681..7c7cda673 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -602,13 +602,13 @@ class Kernel: self.fn.cache_key, version_key(), compute_capability, types_key, attr_key, num_warps, num_stages, meta_key, const_key ) - key_str = repr(key) + key = repr(key) # get cached binary drv_cache = self.fn.drv_cache - if key_str not in drv_cache: - hashed_key = hashlib.md5(key_str.encode("utf-8")).hexdigest() + if key not in drv_cache: + hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() # create cache directory cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/') @@ -641,12 +641,12 @@ class Kernel: pickle.dump({"binary": binary, "key": key}, f) os.rename(bin_cache_path + ".tmp", bin_cache_path) - drv_cache[key_str] = LoadedBinary(device_idx, binary) + drv_cache[key] = LoadedBinary(device_idx, binary) # pack arguments fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)]) params = struct.pack(fmt, *args) # enqueue cached function into stream - callable = drv_cache[key_str] + callable = drv_cache[key] stream = torch.cuda.current_stream(device_idx).cuda_stream grid = grid(meta) if hasattr(grid, '__call__') else grid callable(stream, params, *grid)