From 4ae0e28b326a4eaf115f8d035bb757b3411c6d12 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 7 Apr 2020 20:21:50 -0400 Subject: [PATCH] [PYTHON][KERNEL] Added thread-safety when caching custom torch op --- python/triton/kernel.py | 34 ++++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/python/triton/kernel.py b/python/triton/kernel.py index 5c1a1ccda..594dbc907 100644 --- a/python/triton/kernel.py +++ b/python/triton/kernel.py @@ -16,7 +16,8 @@ import setuptools import triton.frameworks as fw import triton.utils import triton._C.libtriton as libtriton - +import os +import time @contextlib.contextmanager def quiet(): @@ -131,7 +132,7 @@ def _make_framework_op(arg_types): so = os.path.join(root, f'op{suffix}') cpp = os.path.join(root, f'op.cpp') # handle cached .so file - if os.path.exists(so): + if os.path.exists(so) and os.stat(so).st_size > 0: tt_mtime = os.stat(os.path.realpath(libtriton.__file__)).st_mtime so_mtime = os.stat(so).st_mtime # can use cached if libtriton is older than the .so @@ -139,14 +140,31 @@ def _make_framework_op(arg_types): fw.torch.ops.load_library(so) return getattr(fw.torch.ops.triton, name) # create torch source code - src, _ = libtriton.make_torch_src(name, arg_types) - with open(cpp, 'w+') as handle: + lock = os.path.join(root, f'lock') + try: + fd = os.open(lock, os.O_CREAT|os.O_EXCL|os.O_RDWR) + if os.path.exists(so): + fw.torch.ops.load_library(so) + os.remove(lock) + return getattr(fw.torch.ops.triton, name) + src, _ = libtriton.make_torch_src(name, arg_types) + with open(cpp, 'w+') as handle: handle.writelines(src) - # compile torch source code - _build(cpp, root, 'op') - fw.torch.ops.load_library(so) - return getattr(fw.torch.ops.triton, name) + # create torch.so + _build(cpp, root, 'op') + fw.torch.ops.load_library(so) + os.remove(lock) + return getattr(fw.torch.ops.triton, name) + except FileExistsError: + # spin until .so is fully written + while os.path.exists(lock): + time.sleep(0.01) + fw.torch.ops.load_library(so) + return getattr(fw.torch.ops.triton, name) + + + class kernel: