diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index 93b6d4681..7c7cda673 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -602,13 +602,13 @@ class Kernel:
             self.fn.cache_key, version_key(), compute_capability, types_key,
             attr_key, num_warps, num_stages, meta_key, const_key
         )
-        key_str = repr(key)
+        key = repr(key)
 
         # get cached binary
         drv_cache = self.fn.drv_cache
 
-        if key_str not in drv_cache:
-            hashed_key = hashlib.md5(key_str.encode("utf-8")).hexdigest()
+        if key not in drv_cache:
+            hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()
 
             # create cache directory
             cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
@@ -641,12 +641,12 @@ class Kernel:
                             pickle.dump({"binary": binary, "key": key}, f)
                         os.rename(bin_cache_path + ".tmp", bin_cache_path)
 
-            drv_cache[key_str] = LoadedBinary(device_idx, binary)
+            drv_cache[key] = LoadedBinary(device_idx, binary)
         # pack arguments
         fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
         params = struct.pack(fmt, *args)
         # enqueue cached function into stream
-        callable = drv_cache[key_str]
+        callable = drv_cache[key]
         stream = torch.cuda.current_stream(device_idx).cuda_stream
         grid = grid(meta) if hasattr(grid, '__call__') else grid
         callable(stream, params, *grid)
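
For reference, below is a minimal sketch of the caching scheme this patch converges on: a single repr()-string key is used both for the in-memory binary cache and, via its md5 digest, for the on-disk cache filename. The helper name get_binary, the compile_fn callback, the module-level drv_cache dict, and the default cache_dir are illustrative assumptions for this sketch, not Triton's actual API; the real code also takes a file lock around the disk accesses, which is omitted here.

import hashlib
import os
import pickle

drv_cache = {}  # hypothetical in-memory cache: key string -> compiled binary

def get_binary(key_parts, compile_fn, cache_dir="/tmp/triton/"):
    key = repr(key_parts)  # one canonical key string for both caches
    if key in drv_cache:
        return drv_cache[key]
    # md5 digest of the key string names the on-disk cache file
    hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()
    os.makedirs(cache_dir, exist_ok=True)
    path = os.path.join(cache_dir, hashed_key)
    if os.path.exists(path):
        with open(path, "rb") as f:
            binary = pickle.load(f)["binary"]
    else:
        binary = compile_fn()  # expensive compilation step
        with open(path + ".tmp", "wb") as f:
            pickle.dump({"binary": binary, "key": key}, f)
        os.rename(path + ".tmp", path)  # atomically publish the cache file
    drv_cache[key] = binary
    return binary

# e.g. binary = get_binary(("src-hash", (8, 0), 4), lambda: b"compiled-code")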