[PYTHON][KERNEL] Added thread-safety when caching custom torch op

This commit is contained in:
Philippe Tillet
2020-04-07 20:21:50 -04:00
committed by Philippe Tillet
parent 677ccfb44e
commit 4ae0e28b32

View File

@@ -16,7 +16,8 @@ import setuptools
import triton.frameworks as fw
import triton.utils
import triton._C.libtriton as libtriton
import os
import time
@contextlib.contextmanager
def quiet():
@@ -131,7 +132,7 @@ def _make_framework_op(arg_types):
so = os.path.join(root, f'op{suffix}')
cpp = os.path.join(root, f'op.cpp')
# handle cached .so file
if os.path.exists(so):
if os.path.exists(so) and os.stat(so).st_size > 0:
tt_mtime = os.stat(os.path.realpath(libtriton.__file__)).st_mtime
so_mtime = os.stat(so).st_mtime
# can use cached if libtriton is older than the .so
@@ -139,14 +140,31 @@ def _make_framework_op(arg_types):
fw.torch.ops.load_library(so)
return getattr(fw.torch.ops.triton, name)
# create torch source code
src, _ = libtriton.make_torch_src(name, arg_types)
with open(cpp, 'w+') as handle:
lock = os.path.join(root, f'lock')
try:
fd = os.open(lock, os.O_CREAT|os.O_EXCL|os.O_RDWR)
if os.path.exists(so):
fw.torch.ops.load_library(so)
os.remove(lock)
return getattr(fw.torch.ops.triton, name)
src, _ = libtriton.make_torch_src(name, arg_types)
with open(cpp, 'w+') as handle:
handle.writelines(src)
# compile torch source code
_build(cpp, root, 'op')
fw.torch.ops.load_library(so)
return getattr(fw.torch.ops.triton, name)
# create torch.so
_build(cpp, root, 'op')
fw.torch.ops.load_library(so)
os.remove(lock)
return getattr(fw.torch.ops.triton, name)
except FileExistsError:
# spin until .so is fully written
while os.path.exists(lock):
time.sleep(0.01)
fw.torch.ops.load_library(so)
return getattr(fw.torch.ops.triton, name)
class kernel: