diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 7f6982329..40c3ddee4 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -726,24 +726,24 @@ class Autotuner: @functools.lru_cache() def version_key(): + import pkgutil + contents = [] + # frontend with open(triton.code_gen.__file__, "rb") as f: - frontend_contents = hashlib.md5(f.read()).hexdigest() + contents += [hashlib.md5(f.read()).hexdigest()] + # backend with open(triton._C.libtriton.__file__, "rb") as f: - backend_contents = hashlib.md5(f.read()).hexdigest() - - try: - nvcc_version = hashlib.md5(subprocess.check_output(["nvcc", "--version"])).hexdigest() - except Exception: - nvcc_version = None + contents += [hashlib.md5(f.read()).hexdigest()] + # language + for lib in pkgutil.iter_modules(triton.language.__path__): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.md5(f.read()).hexdigest()] + # ptxas version try: ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest() except Exception: ptxas_version = None - - return ( - triton.__version__, frontend_contents, backend_contents, - nvcc_version, ptxas_version - ) + return (triton.__version__, ptxas_version) + tuple(contents) class JITFunction: