[FRONTEND] Ensure version_key is called at most once (#519)

Co-authored-by: hauntsaninja <>
This commit is contained in:
Shantanu
2022-05-23 13:40:08 -07:00
committed by GitHub
parent 205a493b10
commit 80f6a2698b

View File

@@ -11,6 +11,7 @@ import subprocess
import sys
import tempfile
import textwrap
import threading
import time
import warnings
from typing import Dict, Set, Tuple, Union
@@ -1058,27 +1059,40 @@ class Autotuner:
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
@functools.lru_cache()
_version_key_lock = threading.Lock()
_version_key = None
def version_key():
import pkgutil
contents = []
# frontend
with open(triton.code_gen.__file__, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
# backend
with open(triton._C.libtriton.__file__, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
# language
language_path = os.path.join(*triton.__path__, 'language')
for lib in pkgutil.iter_modules([language_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
global _version_key
if _version_key is not None:
return _version_key
with _version_key_lock:
if _version_key is not None:
return _version_key
import pkgutil
contents = []
# frontend
with open(triton.code_gen.__file__, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
# ptxas version
try:
ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest()
except Exception:
ptxas_version = ''
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
# backend
with open(triton._C.libtriton.__file__, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
# language
language_path = os.path.join(*triton.__path__, 'language')
for lib in pkgutil.iter_modules([language_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
# ptxas version
try:
ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest()
except Exception:
ptxas_version = ''
_version_key = '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
return _version_key
class DependenciesFinder(ast.NodeVisitor):