[TRITON][KERNEL] Fix issue with concurrent compilation of torch extensions
This commit is contained in:
Philippe Tillet
2020-06-24 13:51:46 -04:00
committed by Philippe Tillet
parent 8bdfbe2514
commit 955b027103

View File

@@ -8,6 +8,7 @@ import sys
import weakref
import contextlib
import io
import torch.utils.cpp_extension
# import for just-in-time compilation
import distutils
import setuptools.command.build_ext
@@ -18,9 +19,54 @@ import triton.utils
import triton._C.libtriton as libtriton
import os
import time
import torch.utils.cpp_extension
import platform
@contextlib.contextmanager
def quiet():
    """Silence stdout/stderr for the duration of the ``with`` block.

    Both streams are swapped for in-memory ``StringIO`` buffers and are
    always restored on exit, even if the body raises.
    """
    saved = (sys.stdout, sys.stderr)
    sys.stdout = io.StringIO()
    sys.stderr = io.StringIO()
    try:
        yield
    finally:
        sys.stdout, sys.stderr = saved
def _build(src, path, name):
    """Compile the C++ source file *src* into a torch extension module.

    The extension is named *name*, linked against the bundled ``triton``
    shared library that lives next to ``libtriton``, and the built module
    is written to directory *path*. The setuptools build runs inside
    :func:`quiet` so its chatter is suppressed.
    """
    # imported here so the function is self-contained even if the module's
    # import block does not pull these in
    import shutil
    import tempfile
    ccdir = os.path.join(libtriton.__file__, os.path.pardir)
    ccdir = os.path.realpath(ccdir)
    # include / libraries
    include_dirs = [os.path.join(ccdir, 'include')]
    library_dirs = [ccdir]
    libraries = ['triton']
    # create extension module
    # _GLIBCXX_USE_CXX11_ABI is a bool; the preprocessor macro must be 0/1,
    # not True/False, so cast to int (this matches what
    # torch.utils.cpp_extension itself emits).
    abi = int(fw.torch._C._GLIBCXX_USE_CXX11_ABI)
    extra_compile_args = ['-fPIC', '-Wno-deprecated-declarations', '-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
    extra_compile_args += ['-DTORCH_EXTENSION_NAME={}'.format(name)]
    extra_compile_args += ['-DTORCH_API_INCLUDE_EXTENSION_H']
    ext = torch.utils.cpp_extension.CUDAExtension(
        name = name,
        language = 'c++',
        sources = [src],
        include_dirs = include_dirs,
        library_dirs = library_dirs,
        libraries = libraries,
        extra_compile_args = extra_compile_args,
        # rebuild whenever the bundled triton library changes
        depends = [os.path.realpath(libtriton.__file__)]
    )
    # build extension module
    script_args = ['build_ext']
    tmp = tempfile.mkdtemp()
    script_args.append('--build-temp=' + tmp)
    script_args.append('--build-lib=' + path)
    script_args.append('-q')
    setup_kwargs = dict(
        name = name,
        ext_modules = [ext],
        script_args = script_args,
    )
    try:
        with quiet():
            setuptools.setup(**setup_kwargs)
    finally:
        # always remove the scratch build dir, even if compilation failed
        shutil.rmtree(tmp, ignore_errors=True)
def _cvt_to_def_str(obj): def _cvt_to_def_str(obj):
# bool # bool
@@ -77,30 +123,42 @@ def _make_framework_op(arg_types):
fw.torch.ops.load_library(so) fw.torch.ops.load_library(so)
return getattr(fw.torch.ops.triton, name) return getattr(fw.torch.ops.triton, name)
# create torch source code # create torch source code
lock = os.path.join(root, f'lock') print('[TRITON] Compiling op...')
baton = torch.utils.file_baton.FileBaton(os.path.join(root, 'lock'))
if baton.try_acquire():
try:
src, _ = libtriton.make_torch_src(name, arg_types) src, _ = libtriton.make_torch_src(name, arg_types)
with open(cpp, 'w+') as handle: with open(cpp, 'w+') as handle:
handle.writelines(src) handle.writelines(src)
# create torch.so
src, _ = libtriton.make_torch_src(name, arg_types)
ccdir = os.path.join(libtriton.__file__, os.path.pardir) ccdir = os.path.join(libtriton.__file__, os.path.pardir)
ccdir = os.path.realpath(ccdir) ccdir = os.path.realpath(ccdir)
print('[TRITON] Compiling op...') #include_dirs = [os.path.join(ccdir, 'include')]
#library_dirs = [ccdir]
#_build(cpp, root, 'op')
#libraries = ['triton']
machine = platform.machine() machine = platform.machine()
extra_cflags = ['-std=gnu++11'] if machine == 'ppc64le' else None torch.utils.cpp_extension._write_ninja_file_and_build(
lib = torch.utils.cpp_extension.load(name, cpp, name=name,
sources=[cpp],
extra_cflags=['-std=gnu++11'] if machine == 'ppc64le' else [],
extra_cuda_cflags=[],
extra_ldflags=[f'-L{ccdir}', '-ltriton'], extra_ldflags=[f'-L{ccdir}', '-ltriton'],
extra_include_paths=[os.path.join(ccdir, 'include')], extra_include_paths=[os.path.join(ccdir, 'include')],
extra_cflags = extra_cflags,
build_directory=root, build_directory=root,
is_python_module = False, verbose=False,
with_cuda=True) with_cuda=True)
finally:
baton.release()
else:
baton.wait()
print('[TRITON] Done compiling...')
fw.torch.ops.load_library(so) fw.torch.ops.load_library(so)
return getattr(fw.torch.ops.triton, name) return getattr(fw.torch.ops.triton, name)
class kernel: class kernel:
def __init__(self, src, defines = dict(), num_warps = [2, 4, 8]): def __init__(self, src, defines = dict(), num_warps = [2, 4, 8]):