[FRONTEND] fix for unpickleable keys (#307)
In #306, I added the key to the cached data so that we can inspect it when investigating cache misses. Unfortunately, the key itself isn't pickleable, so store its string representation instead. Co-authored-by: hauntsaninja <>
This commit is contained in:
@@ -602,13 +602,13 @@ class Kernel:
|
|||||||
self.fn.cache_key, version_key(), compute_capability,
|
self.fn.cache_key, version_key(), compute_capability,
|
||||||
types_key, attr_key, num_warps, num_stages, meta_key, const_key
|
types_key, attr_key, num_warps, num_stages, meta_key, const_key
|
||||||
)
|
)
|
||||||
key_str = repr(key)
|
key = repr(key)
|
||||||
|
|
||||||
# get cached binary
|
# get cached binary
|
||||||
drv_cache = self.fn.drv_cache
|
drv_cache = self.fn.drv_cache
|
||||||
|
|
||||||
if key_str not in drv_cache:
|
if key not in drv_cache:
|
||||||
hashed_key = hashlib.md5(key_str.encode("utf-8")).hexdigest()
|
hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
# create cache directory
|
# create cache directory
|
||||||
cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
|
cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
|
||||||
@@ -641,12 +641,12 @@ class Kernel:
|
|||||||
pickle.dump({"binary": binary, "key": key}, f)
|
pickle.dump({"binary": binary, "key": key}, f)
|
||||||
os.rename(bin_cache_path + ".tmp", bin_cache_path)
|
os.rename(bin_cache_path + ".tmp", bin_cache_path)
|
||||||
|
|
||||||
drv_cache[key_str] = LoadedBinary(device_idx, binary)
|
drv_cache[key] = LoadedBinary(device_idx, binary)
|
||||||
# pack arguments
|
# pack arguments
|
||||||
fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
|
fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
|
||||||
params = struct.pack(fmt, *args)
|
params = struct.pack(fmt, *args)
|
||||||
# enqueue cached function into stream
|
# enqueue cached function into stream
|
||||||
callable = drv_cache[key_str]
|
callable = drv_cache[key]
|
||||||
stream = torch.cuda.current_stream(device_idx).cuda_stream
|
stream = torch.cuda.current_stream(device_idx).cuda_stream
|
||||||
grid = grid(meta) if hasattr(grid, '__call__') else grid
|
grid = grid(meta) if hasattr(grid, '__call__') else grid
|
||||||
callable(stream, params, *grid)
|
callable(stream, params, *grid)
|
||||||
|
Reference in New Issue
Block a user