[FRONTEND] fix for unpickleable keys (#307)
In #306, I added the key to the cache data, so we can introspect to investigate cache misses. Unfortunately, the key isn't pickleable, so just add the str version instead. Co-authored-by: hauntsaninja <>
This commit is contained in:
@@ -602,13 +602,13 @@ class Kernel:
|
||||
self.fn.cache_key, version_key(), compute_capability,
|
||||
types_key, attr_key, num_warps, num_stages, meta_key, const_key
|
||||
)
|
||||
key_str = repr(key)
|
||||
key = repr(key)
|
||||
|
||||
# get cached binary
|
||||
drv_cache = self.fn.drv_cache
|
||||
|
||||
if key_str not in drv_cache:
|
||||
hashed_key = hashlib.md5(key_str.encode("utf-8")).hexdigest()
|
||||
if key not in drv_cache:
|
||||
hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()
|
||||
|
||||
# create cache directory
|
||||
cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
|
||||
@@ -641,12 +641,12 @@ class Kernel:
|
||||
pickle.dump({"binary": binary, "key": key}, f)
|
||||
os.rename(bin_cache_path + ".tmp", bin_cache_path)
|
||||
|
||||
drv_cache[key_str] = LoadedBinary(device_idx, binary)
|
||||
drv_cache[key] = LoadedBinary(device_idx, binary)
|
||||
# pack arguments
|
||||
fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
|
||||
params = struct.pack(fmt, *args)
|
||||
# enqueue cached function into stream
|
||||
callable = drv_cache[key_str]
|
||||
callable = drv_cache[key]
|
||||
stream = torch.cuda.current_stream(device_idx).cuda_stream
|
||||
grid = grid(meta) if hasattr(grid, '__call__') else grid
|
||||
callable(stream, params, *grid)
|
||||
|
Reference in New Issue
Block a user