diff --git a/python/triton/kernel.py b/python/triton/kernel.py index 21c0c4a34..19729f677 100644 --- a/python/triton/kernel.py +++ b/python/triton/kernel.py @@ -39,7 +39,7 @@ def _build(src, path, name): libraries = ['triton'] # create extension module abi = fw.torch._C._GLIBCXX_USE_CXX11_ABI - extra_compile_args = ['-fPIC', '-Wno-deprecated-declarations', '-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)] + extra_compile_args = ['-fPIC', '-Wno-deprecated-declarations', f'-D_GLIBCXX_USE_CXX11_ABI={str(int(abi))}'] extra_compile_args += ['-DTORCH_EXTENSION_NAME={}'.format(name)] extra_compile_args += ['-DTORCH_API_INCLUDE_EXTENSION_H'] @@ -112,7 +112,7 @@ def _make_framework_op(arg_types): except FileExistsError: pass suffix = sysconfig.get_config_var('EXT_SUFFIX') - so = os.path.join(root, f'{name}.so') + so = os.path.join(root, f'op{suffix}') cpp = os.path.join(root, f'op.cpp') # handle cached .so file if os.path.exists(so) and os.stat(so).st_size > 0: @@ -132,21 +132,7 @@ def _make_framework_op(arg_types): handle.writelines(src) ccdir = os.path.join(libtriton.__file__, os.path.pardir) ccdir = os.path.realpath(ccdir) - #include_dirs = [os.path.join(ccdir, 'include')] - #library_dirs = [ccdir] - #_build(cpp, root, 'op') - #libraries = ['triton'] - machine = platform.machine() - torch.utils.cpp_extension._write_ninja_file_and_build( - name=name, - sources=[cpp], - extra_cflags=['-std=gnu++11'] if machine == 'ppc64le' else [], - extra_cuda_cflags=[], - extra_ldflags=[f'-L{ccdir}', '-ltriton'], - extra_include_paths=[os.path.join(ccdir, 'include')], - build_directory=root, - verbose=False, - with_cuda=True) + _build(cpp, root, 'op') finally: baton.release() else: