diff --git a/python/triton/kernel.py b/python/triton/kernel.py index 5c1a1ccda..594dbc907 100644 --- a/python/triton/kernel.py +++ b/python/triton/kernel.py @@ -16,7 +16,8 @@ import setuptools import triton.frameworks as fw import triton.utils import triton._C.libtriton as libtriton - +import os +import time @contextlib.contextmanager def quiet(): @@ -131,7 +132,7 @@ def _make_framework_op(arg_types): so = os.path.join(root, f'op{suffix}') cpp = os.path.join(root, f'op.cpp') # handle cached .so file - if os.path.exists(so): + if os.path.exists(so) and os.stat(so).st_size > 0: tt_mtime = os.stat(os.path.realpath(libtriton.__file__)).st_mtime so_mtime = os.stat(so).st_mtime # can use cached if libtriton is older than the .so @@ -139,14 +140,31 @@ def _make_framework_op(arg_types): fw.torch.ops.load_library(so) return getattr(fw.torch.ops.triton, name) # create torch source code - src, _ = libtriton.make_torch_src(name, arg_types) - with open(cpp, 'w+') as handle: + lock = os.path.join(root, f'lock') + try: + fd = os.open(lock, os.O_CREAT|os.O_EXCL|os.O_RDWR) + if os.path.exists(so): + fw.torch.ops.load_library(so) + os.remove(lock) + return getattr(fw.torch.ops.triton, name) + src, _ = libtriton.make_torch_src(name, arg_types) + with open(cpp, 'w+') as handle: handle.writelines(src) - # compile torch source code - _build(cpp, root, 'op') - fw.torch.ops.load_library(so) - return getattr(fw.torch.ops.triton, name) + # create torch.so + _build(cpp, root, 'op') + fw.torch.ops.load_library(so) + os.remove(lock) + return getattr(fw.torch.ops.triton, name) + except FileExistsError: + # spin until .so is fully written + while os.path.exists(lock): + time.sleep(0.01) + fw.torch.ops.load_library(so) + return getattr(fw.torch.ops.triton, name) + + + class kernel: