[FRONTEND] Ensure version_key is called at most once (#519)

Co-authored-by: hauntsaninja <>
This commit is contained in:
Shantanu
2022-05-23 13:40:08 -07:00
committed by GitHub
parent 205a493b10
commit 80f6a2698b

View File

@@ -11,6 +11,7 @@ import subprocess
import sys import sys
import tempfile import tempfile
import textwrap import textwrap
import threading
import time import time
import warnings import warnings
from typing import Dict, Set, Tuple, Union from typing import Dict, Set, Tuple, Union
@@ -1058,27 +1059,40 @@ class Autotuner:
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
@functools.lru_cache() _version_key_lock = threading.Lock()
_version_key = None
def version_key(): def version_key():
import pkgutil global _version_key
contents = []
# frontend if _version_key is not None:
with open(triton.code_gen.__file__, "rb") as f: return _version_key
contents += [hashlib.md5(f.read()).hexdigest()]
# backend with _version_key_lock:
with open(triton._C.libtriton.__file__, "rb") as f: if _version_key is not None:
contents += [hashlib.md5(f.read()).hexdigest()] return _version_key
# language
language_path = os.path.join(*triton.__path__, 'language') import pkgutil
for lib in pkgutil.iter_modules([language_path]): contents = []
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: # frontend
with open(triton.code_gen.__file__, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()] contents += [hashlib.md5(f.read()).hexdigest()]
# ptxas version # backend
try: with open(triton._C.libtriton.__file__, "rb") as f:
ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest() contents += [hashlib.md5(f.read()).hexdigest()]
except Exception: # language
ptxas_version = '' language_path = os.path.join(*triton.__path__, 'language')
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents) for lib in pkgutil.iter_modules([language_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
# ptxas version
try:
ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest()
except Exception:
ptxas_version = ''
_version_key = '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
return _version_key
class DependenciesFinder(ast.NodeVisitor): class DependenciesFinder(ast.NodeVisitor):