[GENERAL] Added compatibility with pytorch 1.2.0 and powerpc
This commit is contained in:
committed by
Philippe Tillet
parent
9984ee8c7a
commit
04a9ea060b
@@ -21,7 +21,7 @@ endif()
|
|||||||
|
|
||||||
# Compiler flags
|
# Compiler flags
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
|
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=gnu++11")
|
||||||
|
|
||||||
# Tests
|
# Tests
|
||||||
if(BUILD_TESTS)
|
if(BUILD_TESTS)
|
||||||
|
@@ -18,72 +18,8 @@ import triton.utils
|
|||||||
import triton._C.libtriton as libtriton
|
import triton._C.libtriton as libtriton
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import torch.utils.cpp_extension
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def quiet():
|
|
||||||
old_stdout, old_stderr = sys.stdout, sys.stderr
|
|
||||||
sys.stdout, sys.stderr = io.StringIO(), io.StringIO()
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
sys.stdout, sys.stderr = old_stdout, old_stderr
|
|
||||||
|
|
||||||
def _build(src, path, name):
|
|
||||||
ccdir = os.path.join(libtriton.__file__, os.path.pardir)
|
|
||||||
ccdir = os.path.realpath(ccdir)
|
|
||||||
# include directories
|
|
||||||
triton_include_dirs = [os.path.join(ccdir, 'include')]
|
|
||||||
include_dirs = triton_include_dirs
|
|
||||||
# library directories
|
|
||||||
triton_library_dirs = [ccdir]
|
|
||||||
library_dirs = triton_library_dirs
|
|
||||||
# libraries
|
|
||||||
libraries = ['triton']
|
|
||||||
# add framework
|
|
||||||
extra_compile_args = []
|
|
||||||
if fw.has_torch():
|
|
||||||
prefix = os.path.dirname(fw.torch.__file__)
|
|
||||||
library_dirs += [os.path.join(prefix, 'lib')]
|
|
||||||
include_dirs += ['/usr/local/cuda/include/',
|
|
||||||
os.path.join(prefix, 'lib', 'include'),
|
|
||||||
os.path.join(prefix, 'lib', 'include', 'torch', 'csrc', 'api', 'include'),
|
|
||||||
os.path.join(prefix, 'include'),
|
|
||||||
os.path.join(prefix, 'include', 'torch', 'csrc', 'api', 'include')]
|
|
||||||
libraries += ['torch']
|
|
||||||
abi = fw.torch._C._GLIBCXX_USE_CXX11_ABI
|
|
||||||
extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
|
|
||||||
else:
|
|
||||||
assert False
|
|
||||||
# extra arguments
|
|
||||||
extra_link_args = []
|
|
||||||
# dependences
|
|
||||||
depends = [os.path.realpath(libtriton.__file__)]
|
|
||||||
# create extension module
|
|
||||||
ext = setuptools.Extension(
|
|
||||||
name = name,
|
|
||||||
language = 'c++',
|
|
||||||
sources = [src],
|
|
||||||
include_dirs = include_dirs,
|
|
||||||
extra_compile_args = extra_compile_args + ['-g0'],
|
|
||||||
extra_link_args = extra_link_args,
|
|
||||||
library_dirs = library_dirs,
|
|
||||||
libraries = libraries,
|
|
||||||
depends = depends
|
|
||||||
)
|
|
||||||
# build extension module
|
|
||||||
args = ['build_ext']
|
|
||||||
tmp = tempfile.mkdtemp()
|
|
||||||
args.append('--build-temp=' + tmp)
|
|
||||||
args.append('--build-lib=' + path)
|
|
||||||
args.append('-q')
|
|
||||||
args = dict(
|
|
||||||
name = name,
|
|
||||||
ext_modules = [ext],
|
|
||||||
script_args = args,
|
|
||||||
)
|
|
||||||
with quiet():
|
|
||||||
setuptools.setup(**args)
|
|
||||||
shutil.rmtree(tmp)
|
|
||||||
|
|
||||||
def _cvt_to_def_str(obj):
|
def _cvt_to_def_str(obj):
|
||||||
# bool
|
# bool
|
||||||
@@ -129,7 +65,7 @@ def _make_framework_op(arg_types):
|
|||||||
except FileExistsError:
|
except FileExistsError:
|
||||||
pass
|
pass
|
||||||
suffix = sysconfig.get_config_var('EXT_SUFFIX')
|
suffix = sysconfig.get_config_var('EXT_SUFFIX')
|
||||||
so = os.path.join(root, f'op{suffix}')
|
so = os.path.join(root, f'{name}.so')
|
||||||
cpp = os.path.join(root, f'op.cpp')
|
cpp = os.path.join(root, f'op.cpp')
|
||||||
# handle cached .so file
|
# handle cached .so file
|
||||||
if os.path.exists(so) and os.stat(so).st_size > 0:
|
if os.path.exists(so) and os.stat(so).st_size > 0:
|
||||||
@@ -141,27 +77,23 @@ def _make_framework_op(arg_types):
|
|||||||
return getattr(fw.torch.ops.triton, name)
|
return getattr(fw.torch.ops.triton, name)
|
||||||
# create torch source code
|
# create torch source code
|
||||||
lock = os.path.join(root, f'lock')
|
lock = os.path.join(root, f'lock')
|
||||||
try:
|
src, _ = libtriton.make_torch_src(name, arg_types)
|
||||||
fd = os.open(lock, os.O_CREAT|os.O_EXCL|os.O_RDWR)
|
with open(cpp, 'w+') as handle:
|
||||||
if os.path.exists(so):
|
handle.writelines(src)
|
||||||
fw.torch.ops.load_library(so)
|
# create torch.so
|
||||||
os.remove(lock)
|
src, _ = libtriton.make_torch_src(name, arg_types)
|
||||||
return getattr(fw.torch.ops.triton, name)
|
ccdir = os.path.join(libtriton.__file__, os.path.pardir)
|
||||||
src, _ = libtriton.make_torch_src(name, arg_types)
|
ccdir = os.path.realpath(ccdir)
|
||||||
with open(cpp, 'w+') as handle:
|
print('[TRITON] Compiling op...')
|
||||||
handle.writelines(src)
|
lib = torch.utils.cpp_extension.load_inline(name, src,
|
||||||
# create torch.so
|
extra_ldflags = [f'-L{ccdir}', '-ltriton'],
|
||||||
_build(cpp, root, 'op')
|
extra_include_paths = [os.path.join(ccdir, 'include')],
|
||||||
fw.torch.ops.load_library(so)
|
extra_cflags = [f'-std=gnu++11'],
|
||||||
os.remove(lock)
|
build_directory = root,
|
||||||
return getattr(fw.torch.ops.triton, name)
|
is_python_module = False,
|
||||||
except FileExistsError:
|
with_cuda = True)
|
||||||
# spin until .so is fully written
|
fw.torch.ops.load_library(so)
|
||||||
while os.path.exists(lock):
|
return getattr(fw.torch.ops.triton, name)
|
||||||
time.sleep(0.01)
|
|
||||||
fw.torch.ops.load_library(so)
|
|
||||||
return getattr(fw.torch.ops.triton, name)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user