[FRONTEND] Improved constexpr handling (#493)
@@ -1,4 +1,3 @@
#include <bits/types/clock_t.h>
#include <string>
#include <algorithm>
#include <iostream>
@@ -1032,28 +1032,44 @@ def test_value_specialization_overflow(value: int, overflow: bool, device='cuda'
            kernel[(1, )](value, x)
        else:
            kernel[(1, )](value, x)

# -------------------------
# test dynamic parallelism
# -------------------------


@triton.jit
def mult(x, alpha):
    tl.store(x + tl.program_id(0), alpha)
# ----------------
# test constexpr
# ----------------


@pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>'])
@pytest.mark.parametrize("is_lhs_constexpr", [False, True])
@pytest.mark.parametrize("is_rhs_constexpr", [True, False])
def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr):

    @triton.jit
    def kernel(Z, X, Y):
        x = tl.load(X)
        y = tl.load(Y)
        z = GENERATE_TEST_HERE
        tl.store(Z, z)

    x_str = "3.14" if is_lhs_constexpr else "x"
    y_str = "4.13" if is_rhs_constexpr else "y"
    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{x_str} {op} {y_str}"})
    x = numpy_random((1,), dtype_str="float32")
    y = numpy_random((1,), dtype_str="float32")
    z = np.array(eval(f"{x_str} {op} {y_str}"))
    x_tri = to_triton(x)
    y_tri = to_triton(y)
    z_tri = to_triton(np.empty((1,), dtype=z.dtype))
    kernel[(1,)](z_tri, x_tri, y_tri)
    np.testing.assert_allclose(z, to_numpy(z_tri))


@triton.jit
def stub(X, alpha, grid_0, grid_1, grid_2):
    tl.launch(mult, [X, alpha], [grid_0, grid_1, grid_2])
def test_constexpr_shape():

    @triton.jit
    def kernel(X):
        off = tl.arange(0, 128 + 128)
        tl.store(X + off, off)

# def test_dyn_par(cond=True, device='cuda'):
#     n_pids = 10
#     # pids = torch.arange(n_pids, device=device)
#     # alpha = 2.0
#     # x_ref = pids * alpha
#     x_tri = torch.full((10,), fill_value=-1., device=device)
#     # cond = torch.tensor([cond], device=device)
#     stub[(1,)](x_tri, 3.14, n_pids, 1, 1)
#     print(x_tri)
#     # triton.testing.assert_almost_equal(x_ref, x_tri)
    x_tri = to_triton(np.empty((256, ), dtype=np.int32))
    kernel[(1,)](x_tri)
    np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256))
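For context, `test_constexpr_shape` relies on the constexpr arithmetic handled by the code-generator changes further down: `128 + 128` must collapse into a single compile-time constant before it reaches `tl.arange`, which only accepts constant bounds. A hypothetical standalone sketch of the same idea, assuming the `tl.constexpr` argument annotation and a CUDA device:

import torch
import triton
import triton.language as tl


@triton.jit
def iota_kernel(X, BLOCK: tl.constexpr):
    # BLOCK + BLOCK is constexpr-on-constexpr arithmetic, folded at compile
    # time, so it is a legal shape argument for tl.arange.
    off = tl.arange(0, BLOCK + BLOCK)
    tl.store(X + off, off)


x = torch.empty(256, dtype=torch.int32, device='cuda')
iota_kernel[(1,)](x, BLOCK=128)
print(x)  # expected: 0 .. 255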
@@ -390,12 +390,14 @@ class CodeGenerator(ast.NodeVisitor):
        return tuple(args)

    def visit_BinOp(self, node):
        # visit operand
        lhs = self.visit(node.left)
        rhs = self.visit(node.right)
        if isinstance(lhs, triton.language.constexpr):
            lhs = lhs.value
        if isinstance(rhs, triton.language.constexpr):
            rhs = rhs.value
        is_lhs_constexpr = isinstance(lhs, triton.language.constexpr)
        is_rhs_constexpr = isinstance(rhs, triton.language.constexpr)
        lhs = lhs.value if is_lhs_constexpr else lhs
        rhs = rhs.value if is_rhs_constexpr else rhs
        # get function name
        fn = {
            ast.Add: '__add__',
            ast.Sub: '__sub__',
@@ -410,6 +412,10 @@ class CodeGenerator(ast.NodeVisitor):
            ast.BitOr: '__or__',
            ast.BitXor: '__xor__',
        }[type(node.op)]
        # return a new constexpr if both arg are constexprs
        if is_lhs_constexpr and is_rhs_constexpr:
            return triton.language.constexpr(getattr(lhs, fn)(rhs))
        # call operator
        if is_triton_tensor(lhs):
            return getattr(lhs, fn)(rhs, _builder=self.builder)
        elif is_triton_tensor(rhs):
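In the new `visit_BinOp`, each operand is first checked for being a `constexpr` and recorded, then unwrapped to its raw Python value; the expression only goes through the IR builder when at least one side is a real tensor. A simplified stand-alone illustration of that dispatch (a stand-in `Constexpr` class, not the real `triton.language.constexpr`):

class Constexpr:
    # stand-in for triton.language.constexpr: a thin wrapper around a Python value
    def __init__(self, value):
        self.value = value


def fold_binop(lhs, rhs, fn):
    # same shape as visit_BinOp: remember which side is constant, unwrap,
    # and fold in plain Python when both sides are constants
    is_lhs_constexpr = isinstance(lhs, Constexpr)
    is_rhs_constexpr = isinstance(rhs, Constexpr)
    lhs = lhs.value if is_lhs_constexpr else lhs
    rhs = rhs.value if is_rhs_constexpr else rhs
    if is_lhs_constexpr and is_rhs_constexpr:
        return Constexpr(getattr(lhs, fn)(rhs))
    # in the real code generator this branch dispatches to the tensor's
    # operator, which emits IR through self.builder
    return getattr(lhs, fn)(rhs)


print(fold_binop(Constexpr(3.14), Constexpr(4.13), '__add__').value)  # ~7.27
print(fold_binop(Constexpr(3.0), 4.0, '__mul__'))                     # 12.0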
@@ -468,14 +474,16 @@ class CodeGenerator(ast.NodeVisitor):
        assert len(node.ops) == 1
        lhs = self.visit(node.left)
        rhs = self.visit(node.comparators[0])
        if isinstance(lhs, triton.language.constexpr):
            lhs = lhs.value
        if isinstance(rhs, triton.language.constexpr):
            rhs = rhs.value
        is_lhs_constexpr = isinstance(lhs, triton.language.constexpr)
        is_rhs_constexpr = isinstance(rhs, triton.language.constexpr)
        lhs = lhs.value if is_lhs_constexpr else lhs
        rhs = rhs.value if is_rhs_constexpr else rhs
        # handle `is` and `is not`
        if type(node.ops[0]) == ast.Is:
            return triton.language.constexpr(lhs is rhs)
        if type(node.ops[0]) == ast.IsNot:
            return triton.language.constexpr(lhs is not rhs)
        # function name
        fn = {
            ast.Eq: '__eq__',
            ast.NotEq: '__ne__',
@@ -484,29 +492,32 @@ class CodeGenerator(ast.NodeVisitor):
            ast.Gt: '__gt__',
            ast.GtE: '__ge__',
        }[type(node.ops[0])]
        # return a new constexpr if both arg are constexprs
        if is_lhs_constexpr and is_rhs_constexpr:
            return triton.language.constexpr(getattr(lhs, fn)(rhs))
        # call operator
        if is_triton_tensor(lhs):
            return getattr(lhs, fn)(rhs, _builder=self.builder)
        elif is_triton_tensor(rhs):
            fn = fn[:2] + 'r' + fn[2:]
            return getattr(rhs, fn)(lhs, _builder=self.builder)
        else:
            return getattr(lhs, fn)(rhs)
        assert False

    def visit_UnaryOp(self, node):
        op = self.visit(node.operand)
        if type(node.op) == ast.Not:
            assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment"
            return triton.language.constexpr(not op)
        if isinstance(op, triton.language.constexpr):
            op = op.value
        fn = {
            ast.USub: '__neg__',
            ast.UAdd: '__pos__',
            ast.Invert: '__invert__',
        }[type(node.op)]
        if is_triton_tensor(op):
        if isinstance(op, triton.language.constexpr):
            return triton.language.constexpr(getattr(op.value, fn)())
        assert is_triton_tensor(op)
        return getattr(op, fn)(_builder=self.builder)
        return getattr(op, fn)()

    def visit_While(self, node):
        current_bb = self.builder.get_insert_block()
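`visit_Compare` gets the same treatment, plus explicit folding of `is` / `is not` into a fresh `constexpr`. The practical payoff is that comparisons between compile-time values resolve during compilation, so a kernel can branch on them without emitting IR for the dead side. A hypothetical kernel sketch (same imports as above; assumes a compile-time `if` on the folded boolean, which is how these constexprs are typically consumed):

@triton.jit
def scale_kernel(X, Y, N, SCALE: tl.constexpr, BLOCK: tl.constexpr):
    off = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = off < N
    x = tl.load(X + off, mask=mask)
    # SCALE and 1 are both compile-time constants, so `SCALE != 1` folds into
    # a constexpr bool and the branch is resolved during compilation.
    if SCALE != 1:
        x = x * SCALE
    tl.store(Y + off, x, mask=mask)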
@@ -656,6 +667,10 @@ class CodeGenerator(ast.NodeVisitor):
            args = [arg.value if isinstance(arg, triton.language.constexpr) else arg
                    for arg in args]
            ret = fn(*args, **kws)
            if isinstance(ret, (bool, int, float)):
                ret = triton.language.core.constexpr(ret)
            else:
                ret = triton.language.core._to_tensor(ret, self.builder)
            # special case: dynamic parallelism
            # in this case the core primitive returns a proxy
            # if isinstance(ret, triton.language.core.LaunchProxy):
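The call path mirrors the operator paths: `constexpr` arguments are unwrapped before the Python-level function runs, and plain `bool`/`int`/`float` results are re-wrapped so later folding still sees a `constexpr`. A simplified stand-alone sketch of that round trip (stand-in class again, not the actual `CodeGenerator`):

class Constexpr:
    # same stand-in wrapper as in the earlier sketch
    def __init__(self, value):
        self.value = value


def call_builtin(fn, args):
    # unwrap constexpr arguments so the plain Python function sees raw values
    args = [a.value if isinstance(a, Constexpr) else a for a in args]
    ret = fn(*args)
    # re-wrap plain scalars so downstream compile-time folding keeps working
    if isinstance(ret, (bool, int, float)):
        return Constexpr(ret)
    return ret


print(call_builtin(int, [Constexpr(3.7)]).value)                 # 3
print(call_builtin(isinstance, [Constexpr(3.7), float]).value)   # True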
@@ -337,68 +337,6 @@ class constexpr:
    def __repr__(self) -> str:
        return f"constexpr[{self.value}]"

    def __add__(self, other):
        return self.value + other.value

    def __radd__(self, other):
        return other.value + self.value

    def __sub__(self, other):
        return self.value - other.value

    def __rsub__(self, other):
        return other.value - self.value

    def __mul__(self, other):
        return self.value * other.value

    def __rmul__(self, other):
        return other.value * self.value

    def __truediv__(self, other):
        return self.value / other.value

    def __rtruediv__(self, other):
        return other.value / self.value

    def __floordiv__(self, other):
        return self.value // other.value

    def __rfloordiv__(self, other):
        return other.value // self.value

    #

    def __gt__(self, other):
        return self.value > other.value

    def __rgt__(self, other):
        return other.value > self.value

    def __ge__(self, other):
        return self.value >= other.value

    def __rge__(self, other):
        return other.value >= self.value

    def __lt__(self, other):
        return self.value < other.value

    def __rlt__(self, other):
        return other.value < self.value

    def __le__(self, other):
        return self.value <= other.value

    def __rle__(self, other):
        return other.value <= self.value

    def __eq__(self, other):
        return self.value == other.value

    def __ne__(self, other):
        return self.value != other.value

    def __bool__(self):
        return bool(self.value)

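The removed operators above all reach unconditionally for `other.value`, so they only worked when both operands were `constexpr`; mixing a `constexpr` with a plain Python number raised `AttributeError`. With the code generator now unwrapping `constexpr` operands itself, these methods become redundant, which is presumably why they could be dropped. A minimal reproduction of the old failure mode (stand-in class, not the real one):

class OldConstexpr:
    # mirrors the removed behaviour: blindly assumes `other` is also a constexpr
    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        return self.value + other.value


print(OldConstexpr(3) + OldConstexpr(4))   # 7
try:
    OldConstexpr(3) + 4
except AttributeError as exc:
    print(exc)   # 'int' object has no attribute 'value'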
@@ -496,6 +434,11 @@ class tensor:
        other = _to_tensor(other, _builder)
        return semantic.mod(self, other, _builder)

    @builtin
    def __rmod__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.mod(other, self, _builder)

    # unary operators
    @builtin
    def __neg__(self, _builder=None):
@@ -564,6 +507,7 @@ class tensor:

    @builtin
    def __rlt__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.less_than(other, self, _builder)

    # <=
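Reflected operators such as the `__rmod__` added above (and `__rlt__` in this hunk) are what the `fn[:2] + 'r' + fn[2:]` rewrite in `visit_BinOp` / `visit_Compare` dispatches to when only the right-hand operand is a tensor, e.g. the `3.14 % y` case exercised by `test_bin_op_constexpr`. A hypothetical kernel using that shape (same imports as above):

@triton.jit
def rmod_kernel(X, Z):
    x = tl.load(X)
    # constant on the left, tensor on the right: the reflected-name rewrite
    # turns '__mod__' into '__rmod__' and calls it on the tensor operand
    tl.store(Z, 3.14 % x)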