diff --git a/python/triton/kernel.py b/python/triton/kernel.py index 6ef630c2a..0b4318e31 100644 --- a/python/triton/kernel.py +++ b/python/triton/kernel.py @@ -19,6 +19,7 @@ import triton._C.libtriton as libtriton import os import time import torch.utils.cpp_extension +import platform def _cvt_to_def_str(obj): @@ -85,10 +86,12 @@ def _make_framework_op(arg_types): ccdir = os.path.join(libtriton.__file__, os.path.pardir) ccdir = os.path.realpath(ccdir) print('[TRITON] Compiling op...') + machine = platform.machine() + extra_cflags = ['-std=gnu++11'] if machine == 'ppc64le' else None lib = torch.utils.cpp_extension.load(name, cpp, extra_ldflags = [f'-L{ccdir}', '-ltriton'], extra_include_paths = [os.path.join(ccdir, 'include')], - #extra_cflags = [f'-std=gnu++11'], + extra_cflags = extra_cflags, build_directory = root, is_python_module = False, with_cuda = True) @@ -146,4 +149,4 @@ class kernel: # launch self.fw_op(self.op_id, device, bench, bench_id, *args) if bench > 0: - return libtriton.retrieve_scalar(bench_id) \ No newline at end of file + return libtriton.retrieve_scalar(bench_id)