From e61dc7594275e25fbd45c25f69848a026a2e29ab Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Sun, 30 Oct 2022 14:10:02 -0700
Subject: [PATCH] [FRONTEND] Fixed inliner and got more tests to pass (#822)

This adds a `DialectInlinerInterface` to the Triton dialect. This, along
with a few other minor semantic changes, fixes our tests on call
instructions. Also added the option to use an "LLVM_SYSPATH" environment
variable to link against a local build of LLVM; this was useful for
debugging this issue.
---
 .../triton/Dialect/Triton/Transforms/Passes.h |   2 +-
 lib/Dialect/Triton/IR/Dialect.cpp             |  20 ++
 python/setup.py                               |   7 +-
 python/src/triton.cc                          |   7 +-
 python/tests/test_core.py                     | 300 +++++++++---------
 python/triton/compiler.py                     |   5 +-
 python/triton/language/core.py                |   6 +
 7 files changed, 192 insertions(+), 155 deletions(-)

diff --git a/include/triton/Dialect/Triton/Transforms/Passes.h b/include/triton/Dialect/Triton/Transforms/Passes.h
index 5dae1a498..c4176d3a6 100644
--- a/include/triton/Dialect/Triton/Transforms/Passes.h
+++ b/include/triton/Dialect/Triton/Transforms/Passes.h
@@ -8,7 +8,7 @@ namespace triton {
 
 std::unique_ptr<Pass> createCombineOpsPass();
 
-}
+} // namespace triton
 
 #define GEN_PASS_REGISTRATION
 #include "triton/Dialect/Triton/Transforms/Passes.h.inc"
diff --git a/lib/Dialect/Triton/IR/Dialect.cpp b/lib/Dialect/Triton/IR/Dialect.cpp
index 14fb30c21..437de2603 100644
--- a/lib/Dialect/Triton/IR/Dialect.cpp
+++ b/lib/Dialect/Triton/IR/Dialect.cpp
@@ -8,11 +8,30 @@
 
 #include "mlir/IR/DialectImplementation.h"
 
+#include "mlir/Transforms/InliningUtils.h"
 #include "triton/Dialect/Triton/IR/Dialect.cpp.inc"
 
 using namespace mlir;
 using namespace mlir::triton;
 
+//===----------------------------------------------------------------------===//
+// TritonDialect Dialect Interfaces
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct TritonInlinerInterface : public DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
+                       BlockAndValueMapping &valueMapping) const final {
+    return true;
+  }
+  bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
+                       BlockAndValueMapping &) const final {
+    return true;
+  }
+};
+} // namespace
+
 void TritonDialect::initialize() {
   registerTypes();
 
@@ -22,6 +41,7 @@ void TritonDialect::initialize() {
   >();
 
   // We can also add interface here.
+  addInterfaces<TritonInlinerInterface>();
 }
 
 Operation *TritonDialect::materializeConstant(OpBuilder &builder,
diff --git a/python/setup.py b/python/setup.py
index 6e4544d97..a39aa1e5b 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -38,12 +38,13 @@ class Package(NamedTuple):
     test_file: str
     include_flag: str
     lib_flag: str
+    syspath_var_name: str
 
 
 def get_pybind11_package_info():
     name = "pybind11-2.10.0"
     url = "https://github.com/pybind/pybind11/archive/refs/tags/v2.10.0.tar.gz"
-    return Package("pybind11", name, url, "include/pybind11/pybind11.h", "PYBIND11_INCLUDE_DIR", "")
+    return Package("pybind11", name, url, "include/pybind11/pybind11.h", "PYBIND11_INCLUDE_DIR", "", "PYBIND11_SYSPATH")
 
 
 def get_llvm_package_info():
@@ -57,7 +58,7 @@ def get_llvm_package_info():
     else:
         name = 'clang+llvm-14.0.0-x86_64-{}'.format(system_suffix)
     url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-14.0.0/{}.tar.xz".format(name)
-    return Package("llvm", name, url, "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR")
+    return Package("llvm", name, url, "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
 
 
 def get_thirdparty_packages(triton_cache_path):
@@ -67,6 +68,8 @@ def get_thirdparty_packages(triton_cache_path):
         package_root_dir = os.path.join(triton_cache_path, p.package)
         package_dir = os.path.join(package_root_dir, p.name)
         test_file_path = os.path.join(package_dir, p.test_file)
+        if p.syspath_var_name in os.environ:
+            package_dir = os.environ[p.syspath_var_name]
         if not os.path.exists(test_file_path):
             try:
                 shutil.rmtree(package_root_dir)
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 66ae425e6..24fc6406a 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -422,7 +422,12 @@ void init_triton_ir(py::module &&m) {
       .def("get_int32_attr", &mlir::OpBuilder::getI32IntegerAttr)
       // Use arith.ConstantOp to create constants
       // // Constants
-      // .def("get_int1", &ir::builder::get_int1, ret::reference)
+      .def("get_int1",
+           [](mlir::OpBuilder &self, bool v) -> mlir::Value {
+             auto loc = self.getUnknownLoc();
+             return mlir::Value(self.create<mlir::arith::ConstantIntOp>(
+                 loc, v, self.getI1Type()));
+           })
       .def("get_int32",
            [](mlir::OpBuilder &self, int64_t v) -> mlir::Value {
             auto loc = self.getUnknownLoc();
diff --git a/python/tests/test_core.py b/python/tests/test_core.py
index 56b1f36db..4a9b9ed10 100644
--- a/python/tests/test_core.py
+++ b/python/tests/test_core.py
@@ -1177,20 +1177,20 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
 # # ---------------
 
-# @pytest.mark.parametrize("start", [0, 1, 7, 16])
-# def test_arange(start, device='cuda'):
-#     BLOCK = 128
-#     z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
+@pytest.mark.parametrize("start", [0, 1, 7, 16])
+def test_arange(start, device='cuda'):
+    BLOCK = 128
+    z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
 
-#     @triton.jit
-#     def _kernel(z, BLOCK: tl.constexpr,
-#                 START: tl.constexpr, END: tl.constexpr):
-#         off = tl.arange(0, BLOCK)
-#         val = tl.arange(START, END)
-#         tl.store(z + off, val)
-#     _kernel[(1,)](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK)
-#     z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device)
-#     triton.testing.assert_almost_equal(z_tri, z_ref)
+    @triton.jit
+    def _kernel(z, BLOCK: tl.constexpr,
+                START: tl.constexpr, END: tl.constexpr):
+        off = tl.arange(0, BLOCK)
+        val = tl.arange(START, END)
+        tl.store(z + off, val)
+    _kernel[(1,)](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK)
+    z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device)
+    triton.testing.assert_almost_equal(z_tri, z_ref)
 
 # # ---------------
 # # test load
 # # ---------------
@@ -1248,47 +1248,47 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
 #     triton.testing.allclose(out, reference_out)
 
-# @pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
-# def test_load_cache_modifier(cache):
-#     src = torch.empty(128, device='cuda')
-#     dst = torch.empty(128, device='cuda')
+@pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
+def test_load_cache_modifier(cache):
+    src = torch.empty(128, device='cuda')
+    dst = torch.empty(128, device='cuda')
 
-#     @triton.jit
-#     def _kernel(dst, src, CACHE: tl.constexpr):
-#         offsets = tl.arange(0, 128)
-#         x = tl.load(src + offsets, cache_modifier=CACHE)
-#         tl.store(dst + offsets, x)
+    @triton.jit
+    def _kernel(dst, src, CACHE: tl.constexpr):
+        offsets = tl.arange(0, 128)
+        x = tl.load(src + offsets, cache_modifier=CACHE)
+        tl.store(dst + offsets, x)
 
-#     pgm = _kernel[(1,)](dst, src, CACHE=cache)
-#     ptx = pgm.asm['ptx']
-#     if cache == '':
-#         assert 'ld.global.ca' not in ptx
-#         assert 'ld.global.cg' not in ptx
-#     if cache == '.cg':
-#         assert 'ld.global.cg' in ptx
-#         assert 'ld.global.ca' not in ptx
-#     if cache == '.ca':
-#         assert 'ld.global.ca' in ptx
-#         assert 'ld.global.cg' not in ptx
+    pgm = _kernel[(1,)](dst, src, CACHE=cache)
+    ptx = pgm.asm['ptx']
+    if cache == '':
+        assert 'ld.global.ca' not in ptx
+        assert 'ld.global.cg' not in ptx
+    if cache == '.cg':
+        assert 'ld.global.cg' in ptx
+        assert 'ld.global.ca' not in ptx
+    if cache == '.ca':
+        assert 'ld.global.ca' in ptx
+        assert 'ld.global.cg' not in ptx
 
 
-# @pytest.mark.parametrize("N", [16, 10, 11, 1024])
-# def test_vectorization(N):
-#     src = torch.empty(1024, device='cuda')
-#     dst = torch.empty(1024, device='cuda')
+@pytest.mark.parametrize("N", [16, 10, 11, 1024])
+def test_vectorization(N):
+    src = torch.empty(1024, device='cuda')
+    dst = torch.empty(1024, device='cuda')
 
-#     @triton.jit
-#     def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
-#         offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-#         x = tl.load(src + offsets, mask=offsets < N)
-#         tl.store(dst + offsets, x, mask=offsets < N)
-#     pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
-#     ptx = pgm.asm["ptx"]
-#     if N % 16 == 0:
-#         assert "ld.global.v4.b32" in ptx
-#     else:
-#         assert "ld.global.b32" in ptx
-#     # triton.testing.assert_almost_equal(dst, src[:N])
+    @triton.jit
+    def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
+        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        x = tl.load(src + offsets, mask=offsets < N)
+        tl.store(dst + offsets, x, mask=offsets < N)
+    pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
+    ptx = pgm.asm["ptx"]
+    if N % 16 == 0:
+        assert "ld.global.v4.b32" in ptx
+    else:
+        assert "ld.global.b32" in ptx
+    # triton.testing.assert_almost_equal(dst, src[:N])
 
 # # ---------------
 # # test store
 # # ---------------
@@ -1335,145 +1335,149 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
 # # ----------------
 
 
-# def test_noop(device='cuda'):
-#     @triton.jit
-#     def kernel(x):
-#         pass
-#     x = to_triton(numpy_random((1,), dtype_str='int32'), device=device)
-#     kernel[(1, )](x)
+def test_noop(device='cuda'):
+    @triton.jit
+    def kernel(x):
+        pass
+    x = to_triton(numpy_random((1,), dtype_str='int32'), device=device)
+    kernel[(1, )](x)
 
 
-# @pytest.mark.parametrize("value, value_type", [
-#     (-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
-#     (2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
-#     (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
-# ])
-# def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
-#     spec_type = None
+@pytest.mark.parametrize("value, value_type", [
+    (-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
+    (2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
+    (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
+])
+def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
+    spec_type = None
 
-#     def cache_hook(*args, **kwargs):
-#         nonlocal spec_type
-#         spec_type = kwargs["compile"]["signature"][0]
-#     JITFunction.cache_hook = cache_hook
+    def cache_hook(*args, **kwargs):
+        nonlocal spec_type
+        spec_type = kwargs["compile"]["signature"][0]
+    JITFunction.cache_hook = cache_hook
 
-#     @triton.jit
-#     def kernel(VALUE, X):
-#         pass
+    @triton.jit
+    def kernel(VALUE, X):
+        pass
 
-#     x = torch.tensor([3.14159], device='cuda')
-#     pgm = kernel[(1, )](value, x)
+    x = torch.tensor([3.14159], device='cuda')
+    pgm = kernel[(1, )](value, x)
 
-#     JITFunction.cache_hook = None
-#     assert spec_type == value_type
+    JITFunction.cache_hook = None
+    assert spec_type == value_type
+
+# # --------------------
+# # value specialization
+# # --------------------
 
 
-# @pytest.mark.parametrize(
-#     "value, overflow",
-#     [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]
-# )
-# def test_value_specialization_overflow(value: int, overflow: bool, device='cuda') -> None:
+@pytest.mark.parametrize(
+    "value, overflow",
+    [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]
+)
+def test_value_specialization_overflow(value: int, overflow: bool, device='cuda') -> None:
 
-#     @triton.jit
-#     def kernel(VALUE, X):
-#         pass
+    @triton.jit
+    def kernel(VALUE, X):
+        pass
 
-#     x = torch.tensor([3.14159], device='cuda')
+    x = torch.tensor([3.14159], device='cuda')
 
-#     if overflow:
-#         with pytest.raises(OverflowError):
-#             kernel[(1, )](value, x)
-#     else:
-#         kernel[(1, )](value, x)
+    if overflow:
+        with pytest.raises(OverflowError):
+            kernel[(1, )](value, x)
+    else:
+        kernel[(1, )](value, x)
 
 
 # # ----------------
 # # test constexpr
 # # ----------------
 
 
-# @pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>'])
-# @pytest.mark.parametrize("is_lhs_constexpr", [False, True])
-# @pytest.mark.parametrize("is_rhs_constexpr", [True, False])
-# def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr):
+@pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>'])
+@pytest.mark.parametrize("is_lhs_constexpr", [False, True])
+@pytest.mark.parametrize("is_rhs_constexpr", [True, False])
+def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr):
 
-#     @triton.jit
-#     def kernel(Z, X, Y):
-#         x = tl.load(X)
-#         y = tl.load(Y)
-#         z = GENERATE_TEST_HERE
-#         tl.store(Z, z)
+    @triton.jit
+    def kernel(Z, X, Y):
+        x = tl.load(X)
+        y = tl.load(Y)
+        z = GENERATE_TEST_HERE
+        tl.store(Z, z)
 
-#     x_str = "3.14" if is_lhs_constexpr else "x"
-#     y_str = "4.13" if is_rhs_constexpr else "y"
-#     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{x_str} {op} {y_str}"})
-#     x = numpy_random((1,), dtype_str="float32")
-#     y = numpy_random((1,), dtype_str="float32")
-#     z = np.array(eval(f"{x_str} {op} {y_str}"))
-#     x_tri = to_triton(x)
-#     y_tri = to_triton(y)
-#     z_tri = to_triton(np.empty((1,), dtype=z.dtype))
-#     kernel[(1,)](z_tri, x_tri, y_tri)
-#     np.testing.assert_allclose(z, to_numpy(z_tri))
+    x_str = "3.14" if is_lhs_constexpr else "x"
+    y_str = "4.13" if is_rhs_constexpr else "y"
+    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{x_str} {op} {y_str}"})
+    x = numpy_random((1,), dtype_str="float32")
+    y = numpy_random((1,), dtype_str="float32")
+    z = np.array(eval(f"{x_str} {op} {y_str}"))
+    x_tri = to_triton(x)
+    y_tri = to_triton(y)
+    z_tri = to_triton(np.empty((1,), dtype=z.dtype))
+    kernel[(1,)](z_tri, x_tri, y_tri)
+    np.testing.assert_allclose(z, to_numpy(z_tri))
 
 
-# def test_constexpr_shape():
+def test_constexpr_shape():
 
-#     @triton.jit
-#     def kernel(X):
-#         off = tl.arange(0, 128 + 128)
-#         tl.store(X + off, off)
+    @triton.jit
+    def kernel(X):
+        off = tl.arange(0, 128 + 128)
+        tl.store(X + off, off)
 
-#     x_tri = to_triton(np.empty((256, ), dtype=np.int32))
-#     kernel[(1,)](x_tri)
-#     np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256))
+    x_tri = to_triton(np.empty((256, ), dtype=np.int32))
+    kernel[(1,)](x_tri)
+    np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256))
 
 
-# def test_constexpr_scalar_shape():
+def test_constexpr_scalar_shape():
 
-#     @triton.jit
-#     def kernel(X, s):
-#         off = tl.arange(0, 256)
-#         val = off % (256 // s)
-#         tl.store(X + off, val)
+    @triton.jit
+    def kernel(X, s):
+        off = tl.arange(0, 256)
+        val = off % (256 // s)
+        tl.store(X + off, val)
 
-#     x_tri = to_triton(np.empty((256, ), dtype=np.int32))
-#     kernel[(1,)](x_tri, 32)
-#     np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8)
+    x_tri = to_triton(np.empty((256, ), dtype=np.int32))
+    kernel[(1,)](x_tri, 32)
+    np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8)
 
 # # -------------
 # # test call
 # # -------------
 
 
-# @triton.jit
-# def val_multiplier(val, i):
-#     return val * i
+@triton.jit
+def val_multiplier(val, i):
+    return val * i
 
 
-# @triton.jit
-# def vecmul_kernel(ptr, n_elements, rep):
-#     pid = tl.program_id(axis=0)
-#     offsets = pid * 128 + tl.arange(0, 128)
-#     mask = offsets < n_elements
-#     vec = tl.load(ptr + offsets, mask=mask)
-#     for i in range(1, rep):
-#         vec = val_multiplier(vec, i)
-#     tl.store(ptr + offsets, vec, mask=mask)
+@triton.jit
+def vecmul_kernel(ptr, n_elements, rep):
+    pid = tl.program_id(axis=0)
+    offsets = pid * 128 + tl.arange(0, 128)
+    mask = offsets < n_elements
+    vec = tl.load(ptr + offsets, mask=mask)
+    for i in range(1, rep):
+        vec = val_multiplier(vec, i)
+    tl.store(ptr + offsets, vec, mask=mask)
 
 
-# def test_call():
+def test_call():
 
-#     @triton.jit
-#     def kernel(ptr, n_elements, num1, num2):
-#         vecmul_kernel(ptr, n_elements, num1)
-#         vecmul_kernel(ptr, n_elements, num2)
+    @triton.jit
+    def kernel(ptr, n_elements, num1, num2):
+        vecmul_kernel(ptr, n_elements, num1)
+        vecmul_kernel(ptr, n_elements, num2)
 
-#     size = 1024
-#     rand_val = numpy_random((size,), dtype_str="float32")
-#     rand_val_tri = to_triton(rand_val, device='cuda')
-#     kernel[(size // 128,)](rand_val_tri, size, 3, 5)
+    size = 1024
+    rand_val = numpy_random((size,), dtype_str="float32")
+    rand_val_tri = to_triton(rand_val, device='cuda')
+    kernel[(size // 128,)](rand_val_tri, size, 3, 5)
 
-#     ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4
-#     np.testing.assert_equal(to_numpy(rand_val_tri), ans)
+    ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4
+    np.testing.assert_equal(to_numpy(rand_val_tri), ans)
 
 # # -------------
 # # test if
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index f684ff691..8c93360ec 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -685,8 +685,7 @@ class CodeGenerator(ast.NodeVisitor):
         fn_name = mangle_fn(fn.__name__, arg_types, constants)
         # generate function def if necessary
         if not self.module.has_function(fn_name):
-            ret_type = triton.language.void
-            prototype = triton.language.function_type([ret_type], arg_types)
+            prototype = triton.language.function_type([], arg_types)
             gscope = sys.modules[fn.fn.__module__].__dict__
             generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types)
             generator.visit(fn.parse())
@@ -696,7 +695,7 @@ class CodeGenerator(ast.NodeVisitor):
         callee_ret_type = self.function_ret_types[fn_name]
         symbol = self.module.get_function(fn_name)
         call_op = self.builder.call(symbol, arg_vals)
-        if call_op.get_num_results() == 0:
+        if call_op.get_num_results() == 0 or callee_ret_type is None:
             return None
         elif call_op.get_num_results() == 1:
             return triton.language.tensor(call_op.get_result(0), callee_ret_type)
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index e7f6a744d..8c2708074 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -473,6 +473,11 @@ class tensor:
         other = _to_tensor(other, _builder)
         return semantic.mod(self, other, _builder)
 
+    @builtin
+    def __rmod__(self, other, _builder=None):
+        other = _to_tensor(other, _builder)
+        return semantic.mod(other, self, _builder)
+
     # unary operators
     @builtin
     def __neg__(self, _builder=None):
@@ -541,6 +546,7 @@ class tensor:
 
     @builtin
     def __rlt__(self, other, _builder=None):
+        other = _to_tensor(other, _builder)
         return semantic.less_than(other, self, _builder)
 
     # <=
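Usage note (not part of the patch): a minimal sketch, under illustrative names and a placeholder LLVM path, of how the two user-facing pieces of this change are exercised: pointing the build at a local LLVM via LLVM_SYSPATH, and calling one @triton.jit function from another (the new DialectInlinerInterface is what lets the MLIR inliner flatten such calls). It mirrors the test_call pattern added above rather than introducing any new API.

    # Build against a locally built LLVM instead of the downloaded tarball
    # (the path below is a placeholder for wherever your LLVM build lives):
    #
    #   LLVM_SYSPATH=$HOME/llvm-project/build pip install -e python

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def scale(x, s):
        # Helper called from a kernel; the inliner interface added in
        # Dialect.cpp is what allows this call to be inlined.
        return x * s


    @triton.jit
    def scale_kernel(ptr, n_elements, s):
        # Each program instance handles one 128-element block of the input.
        offsets = tl.program_id(axis=0) * 128 + tl.arange(0, 128)
        mask = offsets < n_elements
        x = tl.load(ptr + offsets, mask=mask)
        x = scale(x, s)
        tl.store(ptr + offsets, x, mask=mask)


    x = torch.randn(1024, device='cuda')
    scale_kernel[(1024 // 128,)](x, 1024, 2.0)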