From 04a9ea060b35789546294d1c9b194dbdbd1ad800 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 8 May 2020 23:26:38 -0400 Subject: [PATCH] [GENERAL] Added compatibility with pytorch 1.2.0 and powerpc --- CMakeLists.txt | 2 +- python/triton/kernel.py | 106 +++++++--------------------------------- 2 files changed, 20 insertions(+), 88 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index cef253883..3875cf348 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -21,7 +21,7 @@ endif() # Compiler flags include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=gnu++11") # Tests if(BUILD_TESTS) diff --git a/python/triton/kernel.py b/python/triton/kernel.py index b2afb4ef6..752bb30a7 100644 --- a/python/triton/kernel.py +++ b/python/triton/kernel.py @@ -18,72 +18,8 @@ import triton.utils import triton._C.libtriton as libtriton import os import time +import torch.utils.cpp_extension -@contextlib.contextmanager -def quiet(): - old_stdout, old_stderr = sys.stdout, sys.stderr - sys.stdout, sys.stderr = io.StringIO(), io.StringIO() - try: - yield - finally: - sys.stdout, sys.stderr = old_stdout, old_stderr - -def _build(src, path, name): - ccdir = os.path.join(libtriton.__file__, os.path.pardir) - ccdir = os.path.realpath(ccdir) - # include directories - triton_include_dirs = [os.path.join(ccdir, 'include')] - include_dirs = triton_include_dirs - # library directories - triton_library_dirs = [ccdir] - library_dirs = triton_library_dirs - # libraries - libraries = ['triton'] - # add framework - extra_compile_args = [] - if fw.has_torch(): - prefix = os.path.dirname(fw.torch.__file__) - library_dirs += [os.path.join(prefix, 'lib')] - include_dirs += ['/usr/local/cuda/include/', - os.path.join(prefix, 'lib', 'include'), - os.path.join(prefix, 'lib', 'include', 'torch', 'csrc', 'api', 'include'), - os.path.join(prefix, 'include'), - os.path.join(prefix, 'include', 'torch', 'csrc', 'api', 'include')] - libraries += ['torch'] - abi = fw.torch._C._GLIBCXX_USE_CXX11_ABI - extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)] - else: - assert False - # extra arguments - extra_link_args = [] - # dependences - depends = [os.path.realpath(libtriton.__file__)] - # create extension module - ext = setuptools.Extension( - name = name, - language = 'c++', - sources = [src], - include_dirs = include_dirs, - extra_compile_args = extra_compile_args + ['-g0'], - extra_link_args = extra_link_args, - library_dirs = library_dirs, - libraries = libraries, - depends = depends - ) - # build extension module - args = ['build_ext'] - tmp = tempfile.mkdtemp() - args.append('--build-temp=' + tmp) - args.append('--build-lib=' + path) - args.append('-q') - args = dict( - name = name, - ext_modules = [ext], - script_args = args, - ) - with quiet(): - setuptools.setup(**args) - shutil.rmtree(tmp) def _cvt_to_def_str(obj): # bool @@ -129,7 +65,7 @@ def _make_framework_op(arg_types): except FileExistsError: pass suffix = sysconfig.get_config_var('EXT_SUFFIX') - so = os.path.join(root, f'op{suffix}') + so = os.path.join(root, f'{name}.so') cpp = os.path.join(root, f'op.cpp') # handle cached .so file if os.path.exists(so) and os.stat(so).st_size > 0: @@ -141,27 +77,23 @@ def _make_framework_op(arg_types): return getattr(fw.torch.ops.triton, name) # create torch source code lock = os.path.join(root, f'lock') - try: - fd = os.open(lock, os.O_CREAT|os.O_EXCL|os.O_RDWR) - if os.path.exists(so): - fw.torch.ops.load_library(so) - os.remove(lock) - return getattr(fw.torch.ops.triton, name) - src, _ = libtriton.make_torch_src(name, arg_types) - with open(cpp, 'w+') as handle: - handle.writelines(src) - # create torch.so - _build(cpp, root, 'op') - fw.torch.ops.load_library(so) - os.remove(lock) - return getattr(fw.torch.ops.triton, name) - except FileExistsError: - # spin until .so is fully written - while os.path.exists(lock): - time.sleep(0.01) - fw.torch.ops.load_library(so) - return getattr(fw.torch.ops.triton, name) - + src, _ = libtriton.make_torch_src(name, arg_types) + with open(cpp, 'w+') as handle: + handle.writelines(src) + # create torch.so + src, _ = libtriton.make_torch_src(name, arg_types) + ccdir = os.path.join(libtriton.__file__, os.path.pardir) + ccdir = os.path.realpath(ccdir) + print('[TRITON] Compiling op...') + lib = torch.utils.cpp_extension.load_inline(name, src, + extra_ldflags = [f'-L{ccdir}', '-ltriton'], + extra_include_paths = [os.path.join(ccdir, 'include')], + extra_cflags = [f'-std=gnu++11'], + build_directory = root, + is_python_module = False, + with_cuda = True) + fw.torch.ops.load_library(so) + return getattr(fw.torch.ops.triton, name)