From 80f6a2698b56997f25d4be0a10191eb196fcb021 Mon Sep 17 00:00:00 2001 From: Shantanu <12621235+hauntsaninja@users.noreply.github.com> Date: Mon, 23 May 2022 13:40:08 -0700 Subject: [PATCH] [FRONTEND] Ensure version_key is called at most once (#519) Co-authored-by: hauntsaninja <> --- python/triton/code_gen.py | 52 +++++++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 35c097017..82ace0105 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -11,6 +11,7 @@ import subprocess import sys import tempfile import textwrap +import threading import time import warnings from typing import Dict, Set, Tuple, Union @@ -1058,27 +1059,40 @@ class Autotuner: return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) -@functools.lru_cache() +_version_key_lock = threading.Lock() +_version_key = None + + def version_key(): - import pkgutil - contents = [] - # frontend - with open(triton.code_gen.__file__, "rb") as f: - contents += [hashlib.md5(f.read()).hexdigest()] - # backend - with open(triton._C.libtriton.__file__, "rb") as f: - contents += [hashlib.md5(f.read()).hexdigest()] - # language - language_path = os.path.join(*triton.__path__, 'language') - for lib in pkgutil.iter_modules([language_path]): - with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + global _version_key + + if _version_key is not None: + return _version_key + + with _version_key_lock: + if _version_key is not None: + return _version_key + + import pkgutil + contents = [] + # frontend + with open(triton.code_gen.__file__, "rb") as f: contents += [hashlib.md5(f.read()).hexdigest()] - # ptxas version - try: - ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest() - except Exception: - ptxas_version = '' - return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents) + # backend + with open(triton._C.libtriton.__file__, "rb") as f: + contents += [hashlib.md5(f.read()).hexdigest()] + # language + language_path = os.path.join(*triton.__path__, 'language') + for lib in pkgutil.iter_modules([language_path]): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.md5(f.read()).hexdigest()] + # ptxas version + try: + ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest() + except Exception: + ptxas_version = '' + _version_key = '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents) + return _version_key class DependenciesFinder(ast.NodeVisitor):