[FRONTEND] Better cache hook (#400)
Added an additional `repr` argument to the cache hook, which represents a human-readable string representation of the signature and argument attributes associated with the compiled binary.
This commit is contained in:
@@ -700,7 +700,17 @@ class Kernel:
|
||||
pickle.dump({"binary": binary, "key": key}, f)
|
||||
os.rename(bin_cache_path + ".tmp", bin_cache_path)
|
||||
if JITFunction.cache_hook is not None:
|
||||
JITFunction.cache_hook(key=key, binary=binary)
|
||||
name = self.fn.fn.__name__
|
||||
info = key.split('-')[-3:]
|
||||
num_warps, num_stages, sig = info[0], info[1], info[2].split('_')[1:]
|
||||
# make signature human-readable
|
||||
arg_reprs = []
|
||||
for arg_name, arg_sig in zip(self.fn.arg_names, sig):
|
||||
arg_reprs.append(f'{arg_name}: {arg_sig}')
|
||||
# assemble the repr
|
||||
arg_reprs = ", ".join(arg_reprs)
|
||||
repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})"
|
||||
JITFunction.cache_hook(key=key, binary=binary, repr=repr)
|
||||
|
||||
self.fn.bin_cache[key] = LoadedBinary(device_idx, binary)
|
||||
|
||||
|
Reference in New Issue
Block a user