[PYTHON][KERNEL] Added thread-safety when caching custom torch op
This commit is contained in:
committed by
Philippe Tillet
parent
677ccfb44e
commit
4ae0e28b32
@@ -16,7 +16,8 @@ import setuptools
|
|||||||
import triton.frameworks as fw
|
import triton.frameworks as fw
|
||||||
import triton.utils
|
import triton.utils
|
||||||
import triton._C.libtriton as libtriton
|
import triton._C.libtriton as libtriton
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def quiet():
|
def quiet():
|
||||||
@@ -131,7 +132,7 @@ def _make_framework_op(arg_types):
|
|||||||
so = os.path.join(root, f'op{suffix}')
|
so = os.path.join(root, f'op{suffix}')
|
||||||
cpp = os.path.join(root, f'op.cpp')
|
cpp = os.path.join(root, f'op.cpp')
|
||||||
# handle cached .so file
|
# handle cached .so file
|
||||||
if os.path.exists(so):
|
if os.path.exists(so) and os.stat(so).st_size > 0:
|
||||||
tt_mtime = os.stat(os.path.realpath(libtriton.__file__)).st_mtime
|
tt_mtime = os.stat(os.path.realpath(libtriton.__file__)).st_mtime
|
||||||
so_mtime = os.stat(so).st_mtime
|
so_mtime = os.stat(so).st_mtime
|
||||||
# can use cached if libtriton is older than the .so
|
# can use cached if libtriton is older than the .so
|
||||||
@@ -139,13 +140,30 @@ def _make_framework_op(arg_types):
|
|||||||
fw.torch.ops.load_library(so)
|
fw.torch.ops.load_library(so)
|
||||||
return getattr(fw.torch.ops.triton, name)
|
return getattr(fw.torch.ops.triton, name)
|
||||||
# create torch source code
|
# create torch source code
|
||||||
|
lock = os.path.join(root, f'lock')
|
||||||
|
try:
|
||||||
|
fd = os.open(lock, os.O_CREAT|os.O_EXCL|os.O_RDWR)
|
||||||
|
if os.path.exists(so):
|
||||||
|
fw.torch.ops.load_library(so)
|
||||||
|
os.remove(lock)
|
||||||
|
return getattr(fw.torch.ops.triton, name)
|
||||||
src, _ = libtriton.make_torch_src(name, arg_types)
|
src, _ = libtriton.make_torch_src(name, arg_types)
|
||||||
with open(cpp, 'w+') as handle:
|
with open(cpp, 'w+') as handle:
|
||||||
handle.writelines(src)
|
handle.writelines(src)
|
||||||
# compile torch source code
|
# create torch.so
|
||||||
_build(cpp, root, 'op')
|
_build(cpp, root, 'op')
|
||||||
fw.torch.ops.load_library(so)
|
fw.torch.ops.load_library(so)
|
||||||
|
os.remove(lock)
|
||||||
return getattr(fw.torch.ops.triton, name)
|
return getattr(fw.torch.ops.triton, name)
|
||||||
|
except FileExistsError:
|
||||||
|
# spin until .so is fully written
|
||||||
|
while os.path.exists(lock):
|
||||||
|
time.sleep(0.01)
|
||||||
|
fw.torch.ops.load_library(so)
|
||||||
|
return getattr(fw.torch.ops.triton, name)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class kernel:
|
class kernel:
|
||||||
|
Reference in New Issue
Block a user