[TRITON][KERNEL] Fixed issue for concurrent compilation of torch
extensions
This commit is contained in:
committed by
Philippe Tillet
parent
8bdfbe2514
commit
955b027103
@@ -8,6 +8,7 @@ import sys
|
|||||||
import weakref
|
import weakref
|
||||||
import contextlib
|
import contextlib
|
||||||
import io
|
import io
|
||||||
|
import torch.utils.cpp_extension
|
||||||
# import for just-in-time compilation
|
# import for just-in-time compilation
|
||||||
import distutils
|
import distutils
|
||||||
import setuptools.command.build_ext
|
import setuptools.command.build_ext
|
||||||
@@ -18,9 +19,54 @@ import triton.utils
|
|||||||
import triton._C.libtriton as libtriton
|
import triton._C.libtriton as libtriton
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import torch.utils.cpp_extension
|
|
||||||
import platform
|
import platform
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def quiet():
|
||||||
|
old_stdout, old_stderr = sys.stdout, sys.stderr
|
||||||
|
sys.stdout, sys.stderr = io.StringIO(), io.StringIO()
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
sys.stdout, sys.stderr = old_stdout, old_stderr
|
||||||
|
|
||||||
|
def _build(src, path, name):
|
||||||
|
ccdir = os.path.join(libtriton.__file__, os.path.pardir)
|
||||||
|
ccdir = os.path.realpath(ccdir)
|
||||||
|
# include / libraries
|
||||||
|
include_dirs = [os.path.join(ccdir, 'include')]
|
||||||
|
library_dirs = [ccdir]
|
||||||
|
libraries = ['triton']
|
||||||
|
# create extension module
|
||||||
|
abi = fw.torch._C._GLIBCXX_USE_CXX11_ABI
|
||||||
|
extra_compile_args = ['-fPIC', '-Wno-deprecated-declarations', '-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
|
||||||
|
extra_compile_args += ['-DTORCH_EXTENSION_NAME={}'.format(name)]
|
||||||
|
extra_compile_args += ['-DTORCH_API_INCLUDE_EXTENSION_H']
|
||||||
|
|
||||||
|
ext = torch.utils.cpp_extension.CUDAExtension(
|
||||||
|
name = name,
|
||||||
|
language = 'c++',
|
||||||
|
sources = [src],
|
||||||
|
include_dirs = include_dirs,
|
||||||
|
library_dirs = library_dirs,
|
||||||
|
libraries = libraries,
|
||||||
|
extra_compile_args = extra_compile_args,
|
||||||
|
depends = [os.path.realpath(libtriton.__file__)]
|
||||||
|
)
|
||||||
|
# build extension module
|
||||||
|
args = ['build_ext']
|
||||||
|
tmp = tempfile.mkdtemp()
|
||||||
|
args.append('--build-temp=' + tmp)
|
||||||
|
args.append('--build-lib=' + path)
|
||||||
|
args.append('-q')
|
||||||
|
args = dict(
|
||||||
|
name = name,
|
||||||
|
ext_modules = [ext],
|
||||||
|
script_args = args,
|
||||||
|
)
|
||||||
|
with quiet():
|
||||||
|
setuptools.setup(**args)
|
||||||
|
shutil.rmtree(tmp)
|
||||||
|
|
||||||
def _cvt_to_def_str(obj):
|
def _cvt_to_def_str(obj):
|
||||||
# bool
|
# bool
|
||||||
@@ -77,30 +123,42 @@ def _make_framework_op(arg_types):
|
|||||||
fw.torch.ops.load_library(so)
|
fw.torch.ops.load_library(so)
|
||||||
return getattr(fw.torch.ops.triton, name)
|
return getattr(fw.torch.ops.triton, name)
|
||||||
# create torch source code
|
# create torch source code
|
||||||
lock = os.path.join(root, f'lock')
|
|
||||||
src, _ = libtriton.make_torch_src(name, arg_types)
|
|
||||||
with open(cpp, 'w+') as handle:
|
|
||||||
handle.writelines(src)
|
|
||||||
# create torch.so
|
|
||||||
src, _ = libtriton.make_torch_src(name, arg_types)
|
|
||||||
ccdir = os.path.join(libtriton.__file__, os.path.pardir)
|
|
||||||
ccdir = os.path.realpath(ccdir)
|
|
||||||
print('[TRITON] Compiling op...')
|
print('[TRITON] Compiling op...')
|
||||||
machine = platform.machine()
|
baton = torch.utils.file_baton.FileBaton(os.path.join(root, 'lock'))
|
||||||
extra_cflags = ['-std=gnu++11'] if machine == 'ppc64le' else None
|
if baton.try_acquire():
|
||||||
lib = torch.utils.cpp_extension.load(name, cpp,
|
try:
|
||||||
extra_ldflags = [f'-L{ccdir}', '-ltriton'],
|
src, _ = libtriton.make_torch_src(name, arg_types)
|
||||||
extra_include_paths = [os.path.join(ccdir, 'include')],
|
with open(cpp, 'w+') as handle:
|
||||||
extra_cflags = extra_cflags,
|
handle.writelines(src)
|
||||||
build_directory = root,
|
ccdir = os.path.join(libtriton.__file__, os.path.pardir)
|
||||||
is_python_module = False,
|
ccdir = os.path.realpath(ccdir)
|
||||||
with_cuda = True)
|
#include_dirs = [os.path.join(ccdir, 'include')]
|
||||||
|
#library_dirs = [ccdir]
|
||||||
|
#_build(cpp, root, 'op')
|
||||||
|
#libraries = ['triton']
|
||||||
|
machine = platform.machine()
|
||||||
|
torch.utils.cpp_extension._write_ninja_file_and_build(
|
||||||
|
name=name,
|
||||||
|
sources=[cpp],
|
||||||
|
extra_cflags=['-std=gnu++11'] if machine == 'ppc64le' else [],
|
||||||
|
extra_cuda_cflags=[],
|
||||||
|
extra_ldflags=[f'-L{ccdir}', '-ltriton'],
|
||||||
|
extra_include_paths=[os.path.join(ccdir, 'include')],
|
||||||
|
build_directory=root,
|
||||||
|
verbose=False,
|
||||||
|
with_cuda=True)
|
||||||
|
finally:
|
||||||
|
baton.release()
|
||||||
|
else:
|
||||||
|
baton.wait()
|
||||||
|
print('[TRITON] Done compiling...')
|
||||||
fw.torch.ops.load_library(so)
|
fw.torch.ops.load_library(so)
|
||||||
return getattr(fw.torch.ops.triton, name)
|
return getattr(fw.torch.ops.triton, name)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class kernel:
|
class kernel:
|
||||||
|
|
||||||
def __init__(self, src, defines = dict(), num_warps = [2, 4, 8]):
|
def __init__(self, src, defines = dict(), num_warps = [2, 4, 8]):
|
||||||
|
Reference in New Issue
Block a user