diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc
index d79e5d9d1..4060f23bb 100644
--- a/lib/ir/builder.cc
+++ b/lib/ir/builder.cc
@@ -1,4 +1,3 @@
-#include
 #include
 #include
 #include
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index c8bfedab4..4570a5c61 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -1032,28 +1032,44 @@ def test_value_specialization_overflow(value: int, overflow: bool, device='cuda'
             kernel[(1, )](value, x)
     else:
         kernel[(1, )](value, x)
 
-# -------------------------
-# test dynamic parallelism
-# -------------------------
-@triton.jit
-def mult(x, alpha):
-    tl.store(x + tl.program_id(0), alpha)
+# ----------------
+# test constexpr
+# ----------------
+
+@pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>'])
+@pytest.mark.parametrize("is_lhs_constexpr", [False, True])
+@pytest.mark.parametrize("is_rhs_constexpr", [True, False])
+def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr):
+
+    @triton.jit
+    def kernel(Z, X, Y):
+        x = tl.load(X)
+        y = tl.load(Y)
+        z = GENERATE_TEST_HERE
+        tl.store(Z, z)
+
+    x_str = "3.14" if is_lhs_constexpr else "x"
+    y_str = "4.13" if is_rhs_constexpr else "y"
+    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{x_str} {op} {y_str}"})
+    x = numpy_random((1,), dtype_str="float32")
+    y = numpy_random((1,), dtype_str="float32")
+    z = np.array(eval(f"{x_str} {op} {y_str}"))
+    x_tri = to_triton(x)
+    y_tri = to_triton(y)
+    z_tri = to_triton(np.empty((1,), dtype=z.dtype))
+    kernel[(1,)](z_tri, x_tri, y_tri)
+    np.testing.assert_allclose(z, to_numpy(z_tri))
 
-@triton.jit
-def stub(X, alpha, grid_0, grid_1, grid_2):
-    tl.launch(mult, [X, alpha], [grid_0, grid_1, grid_2])
+def test_constexpr_shape():
+
+    @triton.jit
+    def kernel(X):
+        off = tl.arange(0, 128 + 128)
+        tl.store(X + off, off)
 
-# def test_dyn_par(cond=True, device='cuda'):
-#     n_pids = 10
-#     # pids = torch.arange(n_pids, device=device)
-#     # alpha = 2.0
-#     # x_ref = pids * alpha
-#     x_tri = torch.full((10,), fill_value=-1., device=device)
-#     # cond = torch.tensor([cond], device=device)
-#     stub[(1,)](x_tri, 3.14, n_pids, 1, 1)
-#     print(x_tri)
-#     # triton.testing.assert_almost_equal(x_ref, x_tri)
+    x_tri = to_triton(np.empty((256, ), dtype=np.int32))
+    kernel[(1,)](x_tri)
+    np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256))
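
The two tests added above lean on the rest of this patch: a binary op whose operands are both compile-time constants must now fold to a `constexpr`, and the folded value must remain usable where Triton requires a compile-time constant (here, the bounds of `tl.arange`). As a sanity check outside the pytest harness, here is a minimal sketch of the same behavior, assuming a CUDA device and the `torch` package (the kernel name and sizes are illustrative, not part of the patch):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def iota_kernel(X):
        # `64 + 64` is constexpr + constexpr, so it folds at compile time
        # and is a legal bound for tl.arange
        off = tl.arange(0, 64 + 64)
        tl.store(X + off, off)

    x = torch.empty(128, dtype=torch.int32, device='cuda')
    iota_kernel[(1,)](x)
    assert torch.equal(x.cpu(), torch.arange(128, dtype=torch.int32))
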
diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index 51e3577ae..311fb85f0 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -390,12 +390,14 @@ class CodeGenerator(ast.NodeVisitor):
         return tuple(args)
 
     def visit_BinOp(self, node):
+        # visit operand
        lhs = self.visit(node.left)
        rhs = self.visit(node.right)
-        if isinstance(lhs, triton.language.constexpr):
-            lhs = lhs.value
-        if isinstance(rhs, triton.language.constexpr):
-            rhs = rhs.value
+        is_lhs_constexpr = isinstance(lhs, triton.language.constexpr)
+        is_rhs_constexpr = isinstance(rhs, triton.language.constexpr)
+        lhs = lhs.value if is_lhs_constexpr else lhs
+        rhs = rhs.value if is_rhs_constexpr else rhs
+        # get function name
        fn = {
            ast.Add: '__add__',
            ast.Sub: '__sub__',
@@ -410,6 +412,10 @@ class CodeGenerator(ast.NodeVisitor):
            ast.BitOr: '__or__',
            ast.BitXor: '__xor__',
        }[type(node.op)]
+        # return a new constexpr if both arg are constexprs
+        if is_lhs_constexpr and is_rhs_constexpr:
+            return triton.language.constexpr(getattr(lhs, fn)(rhs))
+        # call operator
        if is_triton_tensor(lhs):
            return getattr(lhs, fn)(rhs, _builder=self.builder)
        elif is_triton_tensor(rhs):
@@ -468,14 +474,16 @@ class CodeGenerator(ast.NodeVisitor):
        assert len(node.ops) == 1
        lhs = self.visit(node.left)
        rhs = self.visit(node.comparators[0])
-        if isinstance(lhs, triton.language.constexpr):
-            lhs = lhs.value
-        if isinstance(rhs, triton.language.constexpr):
-            rhs = rhs.value
+        is_lhs_constexpr = isinstance(lhs, triton.language.constexpr)
+        is_rhs_constexpr = isinstance(rhs, triton.language.constexpr)
+        lhs = lhs.value if is_lhs_constexpr else lhs
+        rhs = rhs.value if is_rhs_constexpr else rhs
+        # handle `is`` and `is not``
        if type(node.ops[0]) == ast.Is:
            return triton.language.constexpr(lhs is rhs)
        if type(node.ops[0]) == ast.IsNot:
            return triton.language.constexpr(lhs is not rhs)
+        # function name
        fn = {
            ast.Eq: '__eq__',
            ast.NotEq: '__ne__',
@@ -484,29 +492,32 @@ class CodeGenerator(ast.NodeVisitor):
            ast.Gt: '__gt__',
            ast.GtE: '__ge__',
        }[type(node.ops[0])]
+        # return a new constexpr if both arg are constexprs
+        if is_lhs_constexpr and is_rhs_constexpr:
+            return triton.language.constexpr(getattr(lhs, fn)(rhs))
+        # call operator
        if is_triton_tensor(lhs):
            return getattr(lhs, fn)(rhs, _builder=self.builder)
        elif is_triton_tensor(rhs):
            fn = fn[:2] + 'r' + fn[2:]
            return getattr(rhs, fn)(lhs, _builder=self.builder)
        else:
-            return getattr(lhs, fn)(rhs)
+            assert False
 
    def visit_UnaryOp(self, node):
        op = self.visit(node.operand)
        if type(node.op) == ast.Not:
            assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment"
            return triton.language.constexpr(not op)
-        if isinstance(op, triton.language.constexpr):
-            op = op.value
        fn = {
            ast.USub: '__neg__',
            ast.UAdd: '__pos__',
            ast.Invert: '__invert__',
        }[type(node.op)]
-        if is_triton_tensor(op):
-            return getattr(op, fn)(_builder=self.builder)
-        return getattr(op, fn)()
+        if isinstance(op, triton.language.constexpr):
+            return triton.language.constexpr(getattr(op.value, fn)())
+        assert is_triton_tensor(op)
+        return getattr(op, fn)(_builder=self.builder)
 
    def visit_While(self, node):
        current_bb = self.builder.get_insert_block()
@@ -656,6 +667,10 @@ class CodeGenerator(ast.NodeVisitor):
            args = [arg.value if isinstance(arg, triton.language.constexpr) else arg
                    for arg in args]
            ret = fn(*args, **kws)
+            if isinstance(ret, (bool, int, float)):
+                ret = triton.language.core.constexpr(ret)
+            else:
+                ret = triton.language.core._to_tensor(ret, self.builder)
            # special case: dynamic parallelism
            # in this case the core primitive returns a proxy
            # if isinstance(ret, triton.language.core.LaunchProxy):
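
The code-generator change is the core of the patch: `visit_BinOp` and `visit_Compare` now remember which operands were `constexpr` before unwrapping them, evaluate the operator on the plain Python values when both sides are constant, and re-wrap the result as a `constexpr`; builtin calls returning a plain `bool`/`int`/`float` are likewise wrapped. A self-contained sketch of that folding rule, with a stand-in `Constexpr` class (not Triton's own) so it runs without a GPU:

    import ast

    class Constexpr:
        # minimal stand-in for triton.language.constexpr in this sketch
        def __init__(self, value):
            self.value = value
        def __repr__(self):
            return f"constexpr[{self.value}]"

    _FN = {ast.Add: '__add__', ast.Sub: '__sub__', ast.Mult: '__mul__'}

    def fold_binop(op, lhs, rhs):
        # same shape as visit_BinOp above: unwrap constexpr operands,
        # re-wrap the result only if *both* operands were constexpr
        is_lhs_constexpr = isinstance(lhs, Constexpr)
        is_rhs_constexpr = isinstance(rhs, Constexpr)
        lhs = lhs.value if is_lhs_constexpr else lhs
        rhs = rhs.value if is_rhs_constexpr else rhs
        fn = _FN[type(op)]
        if is_lhs_constexpr and is_rhs_constexpr:
            return Constexpr(getattr(lhs, fn)(rhs))
        # in Triton, the remaining cases dispatch to tensor operators
        return getattr(lhs, fn)(rhs)

    print(fold_binop(ast.Add(), Constexpr(128), Constexpr(128)))  # constexpr[256]
    print(fold_binop(ast.Mult(), 3.0, Constexpr(2)))              # 6.0

Note the new `else: assert False` in `visit_Compare`: after this change, a comparison where neither side is a tensor and not both sides are constexpr is treated as unreachable rather than silently evaluated.
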
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index 6aa1b68cd..75d75e8ea 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -337,68 +337,6 @@ class constexpr:
     def __repr__(self) -> str:
         return f"constexpr[{self.value}]"
 
-    def __add__(self, other):
-        return self.value + other.value
-
-    def __radd__(self, other):
-        return other.value + self.value
-
-    def __sub__(self, other):
-        return self.value - other.value
-
-    def __rsub__(self, other):
-        return other.value - self.value
-
-    def __mul__(self, other):
-        return self.value * other.value
-
-    def __rmul__(self, other):
-        return other.value * self.value
-
-    def __truediv__(self, other):
-        return self.value / other.value
-
-    def __rtruediv__(self, other):
-        return other.value / self.value
-
-    def __floordiv__(self, other):
-        return self.value // other.value
-
-    def __rfloordiv__(self, other):
-        return other.value // self.value
-
-    #
-
-    def __gt__(self, other):
-        return self.value > other.value
-
-    def __rgt__(self, other):
-        return other.value > self.value
-
-    def __ge__(self, other):
-        return self.value >= other.value
-
-    def __rge__(self, other):
-        return other.value >= self.value
-
-    def __lt__(self, other):
-        return self.value < other.value
-
-    def __rlt__(self, other):
-        return other.value < self.value
-
-    def __le__(self, other):
-        return self.value <= other.value
-
-    def __rle__(self, other):
-        return other.value <= self.value
-
-    def __eq__(self, other):
-        return self.value == other.value
-
-    def __ne__(self, other):
-        return self.value != other.value
-
     def __bool__(self):
         return bool(self.value)
 
@@ -496,6 +434,11 @@ class tensor:
         other = _to_tensor(other, _builder)
         return semantic.mod(self, other, _builder)
 
+    @builtin
+    def __rmod__(self, other, _builder=None):
+        other = _to_tensor(other, _builder)
+        return semantic.mod(other, self, _builder)
+
     # unary operators
     @builtin
     def __neg__(self, _builder=None):
@@ -564,6 +507,7 @@ class tensor:
 
     @builtin
     def __rlt__(self, other, _builder=None):
+        other = _to_tensor(other, _builder)
         return semantic.less_than(other, self, _builder)
 
     # <=
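
Removing `constexpr`'s arithmetic and comparison dunders is safe because the code generator now folds constant-constant ops itself, and it also retires a latent failure mode: each removed method called `other.value`, so mixing a `constexpr` with a plain Python number raised `AttributeError`. A short illustration of the old hazard, using a local stand-in class rather than Triton's:

    class OldConstexpr:
        def __init__(self, value):
            self.value = value
        def __add__(self, other):
            return self.value + other.value  # blows up if other is a plain int

    try:
        OldConstexpr(3) + 4
    except AttributeError as exc:
        print(exc)  # 'int' object has no attribute 'value'

On the tensor side, the new `__rmod__` makes `scalar % tensor` work inside kernels, and `__rlt__` now converts its scalar operand with `_to_tensor` before calling `semantic.less_than`, matching the other reflected comparisons.
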