[PYTHON][KERNEL] Added thread-safety when caching custom torch op
This commit is contained in:
committed by
Philippe Tillet
parent
677ccfb44e
commit
4ae0e28b32
@@ -16,7 +16,8 @@ import setuptools
|
||||
import triton.frameworks as fw
|
||||
import triton.utils
|
||||
import triton._C.libtriton as libtriton
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
@contextlib.contextmanager
|
||||
def quiet():
|
||||
@@ -131,7 +132,7 @@ def _make_framework_op(arg_types):
|
||||
so = os.path.join(root, f'op{suffix}')
|
||||
cpp = os.path.join(root, f'op.cpp')
|
||||
# handle cached .so file
|
||||
if os.path.exists(so):
|
||||
if os.path.exists(so) and os.stat(so).st_size > 0:
|
||||
tt_mtime = os.stat(os.path.realpath(libtriton.__file__)).st_mtime
|
||||
so_mtime = os.stat(so).st_mtime
|
||||
# can use cached if libtriton is older than the .so
|
||||
@@ -139,14 +140,31 @@ def _make_framework_op(arg_types):
|
||||
fw.torch.ops.load_library(so)
|
||||
return getattr(fw.torch.ops.triton, name)
|
||||
# create torch source code
|
||||
src, _ = libtriton.make_torch_src(name, arg_types)
|
||||
with open(cpp, 'w+') as handle:
|
||||
lock = os.path.join(root, f'lock')
|
||||
try:
|
||||
fd = os.open(lock, os.O_CREAT|os.O_EXCL|os.O_RDWR)
|
||||
if os.path.exists(so):
|
||||
fw.torch.ops.load_library(so)
|
||||
os.remove(lock)
|
||||
return getattr(fw.torch.ops.triton, name)
|
||||
src, _ = libtriton.make_torch_src(name, arg_types)
|
||||
with open(cpp, 'w+') as handle:
|
||||
handle.writelines(src)
|
||||
# compile torch source code
|
||||
_build(cpp, root, 'op')
|
||||
fw.torch.ops.load_library(so)
|
||||
return getattr(fw.torch.ops.triton, name)
|
||||
# create torch.so
|
||||
_build(cpp, root, 'op')
|
||||
fw.torch.ops.load_library(so)
|
||||
os.remove(lock)
|
||||
return getattr(fw.torch.ops.triton, name)
|
||||
except FileExistsError:
|
||||
# spin until .so is fully written
|
||||
while os.path.exists(lock):
|
||||
time.sleep(0.01)
|
||||
fw.torch.ops.load_library(so)
|
||||
return getattr(fw.torch.ops.triton, name)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class kernel:
|
||||
|
||||
|
Reference in New Issue
Block a user