From 955b027103c48af59901b24bf5df14d3e09ce25a Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 24 Jun 2020 13:51:46 -0400 Subject: [PATCH] [TRITON][KERNEL] Fixed issue for concurrent compilation of torch extensions --- python/triton/kernel.py | 96 +++++++++++++++++++++++++++++++++-------- 1 file changed, 77 insertions(+), 19 deletions(-) diff --git a/python/triton/kernel.py b/python/triton/kernel.py index 0b4318e31..21c0c4a34 100644 --- a/python/triton/kernel.py +++ b/python/triton/kernel.py @@ -8,6 +8,7 @@ import sys import weakref import contextlib import io +import torch.utils.cpp_extension # import for just-in-time compilation import distutils import setuptools.command.build_ext @@ -18,9 +19,54 @@ import triton.utils import triton._C.libtriton as libtriton import os import time -import torch.utils.cpp_extension import platform +@contextlib.contextmanager +def quiet(): + old_stdout, old_stderr = sys.stdout, sys.stderr + sys.stdout, sys.stderr = io.StringIO(), io.StringIO() + try: + yield + finally: + sys.stdout, sys.stderr = old_stdout, old_stderr + +def _build(src, path, name): + ccdir = os.path.join(libtriton.__file__, os.path.pardir) + ccdir = os.path.realpath(ccdir) + # include / libraries + include_dirs = [os.path.join(ccdir, 'include')] + library_dirs = [ccdir] + libraries = ['triton'] + # create extension module + abi = fw.torch._C._GLIBCXX_USE_CXX11_ABI + extra_compile_args = ['-fPIC', '-Wno-deprecated-declarations', '-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)] + extra_compile_args += ['-DTORCH_EXTENSION_NAME={}'.format(name)] + extra_compile_args += ['-DTORCH_API_INCLUDE_EXTENSION_H'] + + ext = torch.utils.cpp_extension.CUDAExtension( + name = name, + language = 'c++', + sources = [src], + include_dirs = include_dirs, + library_dirs = library_dirs, + libraries = libraries, + extra_compile_args = extra_compile_args, + depends = [os.path.realpath(libtriton.__file__)] + ) + # build extension module + args = ['build_ext'] + tmp = tempfile.mkdtemp() + args.append('--build-temp=' + tmp) + args.append('--build-lib=' + path) + args.append('-q') + args = dict( + name = name, + ext_modules = [ext], + script_args = args, + ) + with quiet(): + setuptools.setup(**args) + shutil.rmtree(tmp) def _cvt_to_def_str(obj): # bool @@ -77,26 +123,38 @@ def _make_framework_op(arg_types): fw.torch.ops.load_library(so) return getattr(fw.torch.ops.triton, name) # create torch source code - lock = os.path.join(root, f'lock') - src, _ = libtriton.make_torch_src(name, arg_types) - with open(cpp, 'w+') as handle: - handle.writelines(src) - # create torch.so - src, _ = libtriton.make_torch_src(name, arg_types) - ccdir = os.path.join(libtriton.__file__, os.path.pardir) - ccdir = os.path.realpath(ccdir) print('[TRITON] Compiling op...') - machine = platform.machine() - extra_cflags = ['-std=gnu++11'] if machine == 'ppc64le' else None - lib = torch.utils.cpp_extension.load(name, cpp, - extra_ldflags = [f'-L{ccdir}', '-ltriton'], - extra_include_paths = [os.path.join(ccdir, 'include')], - extra_cflags = extra_cflags, - build_directory = root, - is_python_module = False, - with_cuda = True) + baton = torch.utils.file_baton.FileBaton(os.path.join(root, 'lock')) + if baton.try_acquire(): + try: + src, _ = libtriton.make_torch_src(name, arg_types) + with open(cpp, 'w+') as handle: + handle.writelines(src) + ccdir = os.path.join(libtriton.__file__, os.path.pardir) + ccdir = os.path.realpath(ccdir) + #include_dirs = [os.path.join(ccdir, 'include')] + #library_dirs = [ccdir] + #_build(cpp, root, 'op') + #libraries = ['triton'] + machine = platform.machine() + torch.utils.cpp_extension._write_ninja_file_and_build( + name=name, + sources=[cpp], + extra_cflags=['-std=gnu++11'] if machine == 'ppc64le' else [], + extra_cuda_cflags=[], + extra_ldflags=[f'-L{ccdir}', '-ltriton'], + extra_include_paths=[os.path.join(ccdir, 'include')], + build_directory=root, + verbose=False, + with_cuda=True) + finally: + baton.release() + else: + baton.wait() + print('[TRITON] Done compiling...') fw.torch.ops.load_library(so) return getattr(fw.torch.ops.triton, name) + @@ -149,4 +207,4 @@ class kernel: # launch self.fw_op(self.op_id, device, bench, bench_id, *args) if bench > 0: - return libtriton.retrieve_scalar(bench_id) + return libtriton.retrieve_scalar(bench_id) \ No newline at end of file