[FRONTEND] fix for unpickleable keys (#307)

In #306, I added the key to the cache data so that we can inspect it when
investigating cache misses. Unfortunately, the key itself isn't pickleable,
so store its string form (`repr(key)`) instead.

Co-authored-by: hauntsaninja <>
This commit is contained in:
Shantanu
2021-09-23 21:23:59 -07:00
committed by GitHub
parent 2066ccd87e
commit 0735061fce

View File

@@ -602,13 +602,13 @@ class Kernel:
self.fn.cache_key, version_key(), compute_capability, self.fn.cache_key, version_key(), compute_capability,
types_key, attr_key, num_warps, num_stages, meta_key, const_key types_key, attr_key, num_warps, num_stages, meta_key, const_key
) )
key_str = repr(key) key = repr(key)
# get cached binary # get cached binary
drv_cache = self.fn.drv_cache drv_cache = self.fn.drv_cache
if key_str not in drv_cache: if key not in drv_cache:
hashed_key = hashlib.md5(key_str.encode("utf-8")).hexdigest() hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()
# create cache directory # create cache directory
cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/') cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
@@ -641,12 +641,12 @@ class Kernel:
pickle.dump({"binary": binary, "key": key}, f) pickle.dump({"binary": binary, "key": key}, f)
os.rename(bin_cache_path + ".tmp", bin_cache_path) os.rename(bin_cache_path + ".tmp", bin_cache_path)
drv_cache[key_str] = LoadedBinary(device_idx, binary) drv_cache[key] = LoadedBinary(device_idx, binary)
# pack arguments # pack arguments
fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)]) fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
params = struct.pack(fmt, *args) params = struct.pack(fmt, *args)
# enqueue cached function into stream # enqueue cached function into stream
callable = drv_cache[key_str] callable = drv_cache[key]
stream = torch.cuda.current_stream(device_idx).cuda_stream stream = torch.cuda.current_stream(device_idx).cuda_stream
grid = grid(meta) if hasattr(grid, '__call__') else grid grid = grid(meta) if hasattr(grid, '__call__') else grid
callable(stream, params, *grid) callable(stream, params, *grid)