[PYTHON][KERNEL] Added thread-safety when caching custom torch op

This commit is contained in:
Philippe Tillet
2020-04-07 20:21:50 -04:00
committed by Philippe Tillet
parent 677ccfb44e
commit 4ae0e28b32

View File

@@ -16,7 +16,8 @@ import setuptools
import triton.frameworks as fw import triton.frameworks as fw
import triton.utils import triton.utils
import triton._C.libtriton as libtriton import triton._C.libtriton as libtriton
import os
import time
@contextlib.contextmanager @contextlib.contextmanager
def quiet(): def quiet():
@@ -131,7 +132,7 @@ def _make_framework_op(arg_types):
so = os.path.join(root, f'op{suffix}') so = os.path.join(root, f'op{suffix}')
cpp = os.path.join(root, f'op.cpp') cpp = os.path.join(root, f'op.cpp')
# handle cached .so file # handle cached .so file
- if os.path.exists(so):
+ if os.path.exists(so) and os.stat(so).st_size > 0:
tt_mtime = os.stat(os.path.realpath(libtriton.__file__)).st_mtime tt_mtime = os.stat(os.path.realpath(libtriton.__file__)).st_mtime
so_mtime = os.stat(so).st_mtime so_mtime = os.stat(so).st_mtime
# can use cached if libtriton is older than the .so # can use cached if libtriton is older than the .so
@@ -139,13 +140,30 @@ def _make_framework_op(arg_types):
fw.torch.ops.load_library(so) fw.torch.ops.load_library(so)
return getattr(fw.torch.ops.triton, name) return getattr(fw.torch.ops.triton, name)
# create torch source code # create torch source code
lock = os.path.join(root, f'lock')
try:
fd = os.open(lock, os.O_CREAT|os.O_EXCL|os.O_RDWR)
if os.path.exists(so):
fw.torch.ops.load_library(so)
os.remove(lock)
return getattr(fw.torch.ops.triton, name)
src, _ = libtriton.make_torch_src(name, arg_types) src, _ = libtriton.make_torch_src(name, arg_types)
with open(cpp, 'w+') as handle: with open(cpp, 'w+') as handle:
handle.writelines(src) handle.writelines(src)
- # compile torch source code
+ # create torch.so
_build(cpp, root, 'op') _build(cpp, root, 'op')
fw.torch.ops.load_library(so) fw.torch.ops.load_library(so)
os.remove(lock)
return getattr(fw.torch.ops.triton, name) return getattr(fw.torch.ops.triton, name)
except FileExistsError:
# spin until .so is fully written
while os.path.exists(lock):
time.sleep(0.01)
fw.torch.ops.load_library(so)
return getattr(fw.torch.ops.triton, name)
class kernel: class kernel: