[FRONTEND] Improved constexpr handling (#493)

Philippe Tillet
2022-04-12 00:02:54 -07:00
committed by GitHub
parent 14b0fd4cfb
commit 76bfac9f15
4 changed files with 70 additions and 96 deletions
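
Summary of the change (as read from the diff below): binary, comparison, and unary expressions whose operands are all constexpr are now folded to a new constexpr inside the code generator, builtin calls that return plain Python scalars are wrapped as constexpr, and the operator shims on the constexpr class are removed. A minimal sketch of the user-visible effect, using a hypothetical kernel written against the public triton.language API:

    @triton.jit
    def copy256(X, Y):
        # "128 + 128" is folded to a constexpr by the frontend,
        # so tl.arange still receives a compile-time constant
        off = tl.arange(0, 128 + 128)
        tl.store(Y + off, tl.load(X + off))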

View File

@@ -1,4 +1,3 @@
-#include <bits/types/clock_t.h>
 #include <string>
 #include <algorithm>
 #include <iostream>

View File

@@ -1032,28 +1032,44 @@ def test_value_specialization_overflow(value: int, overflow: bool, device='cuda'
             kernel[(1, )](value, x)
     else:
         kernel[(1, )](value, x)
 
-# -------------------------
-# test dynamic parallelism
-# -------------------------
+# ----------------
+# test constexpr
+# ----------------
 
-@triton.jit
-def mult(x, alpha):
-    tl.store(x + tl.program_id(0), alpha)
+@pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>'])
+@pytest.mark.parametrize("is_lhs_constexpr", [False, True])
+@pytest.mark.parametrize("is_rhs_constexpr", [True, False])
+def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr):
+
+    @triton.jit
+    def kernel(Z, X, Y):
+        x = tl.load(X)
+        y = tl.load(Y)
+        z = GENERATE_TEST_HERE
+        tl.store(Z, z)
+
+    x_str = "3.14" if is_lhs_constexpr else "x"
+    y_str = "4.13" if is_rhs_constexpr else "y"
+    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{x_str} {op} {y_str}"})
+    x = numpy_random((1,), dtype_str="float32")
+    y = numpy_random((1,), dtype_str="float32")
+    z = np.array(eval(f"{x_str} {op} {y_str}"))
+    x_tri = to_triton(x)
+    y_tri = to_triton(y)
+    z_tri = to_triton(np.empty((1,), dtype=z.dtype))
+    kernel[(1,)](z_tri, x_tri, y_tri)
+    np.testing.assert_allclose(z, to_numpy(z_tri))
 
-@triton.jit
-def stub(X, alpha, grid_0, grid_1, grid_2):
-    tl.launch(mult, [X, alpha], [grid_0, grid_1, grid_2])
 
-# def test_dyn_par(cond=True, device='cuda'):
-#     n_pids = 10
-#     # pids = torch.arange(n_pids, device=device)
-#     # alpha = 2.0
-#     # x_ref = pids * alpha
-#     x_tri = torch.full((10,), fill_value=-1., device=device)
-#     # cond = torch.tensor([cond], device=device)
-#     stub[(1,)](x_tri, 3.14, n_pids, 1, 1)
-#     print(x_tri)
-#     # triton.testing.assert_almost_equal(x_ref, x_tri)
+def test_constexpr_shape():
+
+    @triton.jit
+    def kernel(X):
+        off = tl.arange(0, 128 + 128)
+        tl.store(X + off, off)
+
+    x_tri = to_triton(np.empty((256, ), dtype=np.int32))
+    kernel[(1,)](x_tri)
+    np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256))
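
For illustration, one parametrization of the new test (op='*', is_lhs_constexpr=True, is_rhs_constexpr=False): patch_kernel substitutes GENERATE_TEST_HERE so the compiled kernel body is effectively

    @triton.jit
    def kernel(Z, X, Y):
        x = tl.load(X)
        y = tl.load(Y)
        z = 3.14 * y   # constexpr literal on the left, tensor on the right
        tl.store(Z, z)

and the device result is then compared against NumPy's evaluation of the same expression string.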

View File

@@ -390,12 +390,14 @@ class CodeGenerator(ast.NodeVisitor):
         return tuple(args)
 
     def visit_BinOp(self, node):
+        # visit operand
         lhs = self.visit(node.left)
         rhs = self.visit(node.right)
-        if isinstance(lhs, triton.language.constexpr):
-            lhs = lhs.value
-        if isinstance(rhs, triton.language.constexpr):
-            rhs = rhs.value
+        is_lhs_constexpr = isinstance(lhs, triton.language.constexpr)
+        is_rhs_constexpr = isinstance(rhs, triton.language.constexpr)
+        lhs = lhs.value if is_lhs_constexpr else lhs
+        rhs = rhs.value if is_rhs_constexpr else rhs
+        # get function name
         fn = {
             ast.Add: '__add__',
             ast.Sub: '__sub__',
@@ -410,6 +412,10 @@ class CodeGenerator(ast.NodeVisitor):
             ast.BitOr: '__or__',
             ast.BitXor: '__xor__',
         }[type(node.op)]
+        # return a new constexpr if both arg are constexprs
+        if is_lhs_constexpr and is_rhs_constexpr:
+            return triton.language.constexpr(getattr(lhs, fn)(rhs))
+        # call operator
         if is_triton_tensor(lhs):
             return getattr(lhs, fn)(rhs, _builder=self.builder)
         elif is_triton_tensor(rhs):
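
Note on the new fold path: by this point lhs and rhs have already been unwrapped to plain Python values, so getattr(lhs, fn)(rhs) is ordinary Python operator dispatch and the result is re-wrapped. Roughly equivalent plain-Python behaviour, with illustrative values:

    lhs, rhs, fn = 3.14, 4.13, '__mul__'
    folded = triton.language.constexpr(getattr(lhs, fn)(rhs))   # constexpr wrapping 3.14 * 4.13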
@@ -468,14 +474,16 @@ class CodeGenerator(ast.NodeVisitor):
         assert len(node.ops) == 1
         lhs = self.visit(node.left)
         rhs = self.visit(node.comparators[0])
-        if isinstance(lhs, triton.language.constexpr):
-            lhs = lhs.value
-        if isinstance(rhs, triton.language.constexpr):
-            rhs = rhs.value
+        is_lhs_constexpr = isinstance(lhs, triton.language.constexpr)
+        is_rhs_constexpr = isinstance(rhs, triton.language.constexpr)
+        lhs = lhs.value if is_lhs_constexpr else lhs
+        rhs = rhs.value if is_rhs_constexpr else rhs
+        # handle `is`` and `is not``
         if type(node.ops[0]) == ast.Is:
             return triton.language.constexpr(lhs is rhs)
         if type(node.ops[0]) == ast.IsNot:
             return triton.language.constexpr(lhs is not rhs)
+        # function name
         fn = {
             ast.Eq: '__eq__',
             ast.NotEq: '__ne__',
@@ -484,29 +492,32 @@ class CodeGenerator(ast.NodeVisitor):
             ast.Gt: '__gt__',
             ast.GtE: '__ge__',
         }[type(node.ops[0])]
+        # return a new constexpr if both arg are constexprs
+        if is_lhs_constexpr and is_rhs_constexpr:
+            return triton.language.constexpr(getattr(lhs, fn)(rhs))
+        # call operator
         if is_triton_tensor(lhs):
             return getattr(lhs, fn)(rhs, _builder=self.builder)
         elif is_triton_tensor(rhs):
             fn = fn[:2] + 'r' + fn[2:]
             return getattr(rhs, fn)(lhs, _builder=self.builder)
         else:
-            return getattr(lhs, fn)(rhs)
+            assert False
 
     def visit_UnaryOp(self, node):
         op = self.visit(node.operand)
         if type(node.op) == ast.Not:
             assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment"
             return triton.language.constexpr(not op)
-        if isinstance(op, triton.language.constexpr):
-            op = op.value
         fn = {
             ast.USub: '__neg__',
             ast.UAdd: '__pos__',
             ast.Invert: '__invert__',
         }[type(node.op)]
-        if is_triton_tensor(op):
-            return getattr(op, fn)(_builder=self.builder)
-        return getattr(op, fn)()
+        if isinstance(op, triton.language.constexpr):
+            return triton.language.constexpr(getattr(op.value, fn)())
+        assert is_triton_tensor(op)
+        return getattr(op, fn)(_builder=self.builder)
 
     def visit_While(self, node):
         current_bb = self.builder.get_insert_block()
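
Unary ops now follow the same pattern: a constexpr operand is folded eagerly and anything else must be a Triton tensor. Illustrative fold, with a made-up value:

    op = triton.language.constexpr(7)
    triton.language.constexpr(getattr(op.value, '__neg__')())   # constexpr[-7]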
@@ -656,6 +667,10 @@ class CodeGenerator(ast.NodeVisitor):
             args = [arg.value if isinstance(arg, triton.language.constexpr) else arg
                     for arg in args]
             ret = fn(*args, **kws)
+            if isinstance(ret, (bool, int, float)):
+                ret = triton.language.core.constexpr(ret)
+            else:
+                ret = triton.language.core._to_tensor(ret, self.builder)
             # special case: dynamic parallelism
             # in this case the core primitive returns a proxy
             # if isinstance(ret, triton.language.core.LaunchProxy):

View File

@@ -337,68 +337,6 @@ class constexpr:
     def __repr__(self) -> str:
         return f"constexpr[{self.value}]"
 
-    def __add__(self, other):
-        return self.value + other.value
-
-    def __radd__(self, other):
-        return other.value + self.value
-
-    def __sub__(self, other):
-        return self.value - other.value
-
-    def __rsub__(self, other):
-        return other.value - self.value
-
-    def __mul__(self, other):
-        return self.value * other.value
-
-    def __rmul__(self, other):
-        return other.value * self.value
-
-    def __truediv__(self, other):
-        return self.value / other.value
-
-    def __rtruediv__(self, other):
-        return other.value / self.value
-
-    def __floordiv__(self, other):
-        return self.value // other.value
-
-    def __rfloordiv__(self, other):
-        return other.value // self.value
-
-    #
-    def __gt__(self, other):
-        return self.value > other.value
-
-    def __rgt__(self, other):
-        return other.value > self.value
-
-    def __ge__(self, other):
-        return self.value >= other.value
-
-    def __rge__(self, other):
-        return other.value >= self.value
-
-    def __lt__(self, other):
-        return self.value < other.value
-
-    def __rlt__(self, other):
-        return other.value < self.value
-
-    def __le__(self, other):
-        return self.value <= other.value
-
-    def __rle__(self, other):
-        return other.value <= self.value
-
-    def __eq__(self, other):
-        return self.value == other.value
-
-    def __ne__(self, other):
-        return self.value != other.value
-
     def __bool__(self):
         return bool(self.value)
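
One likely motivation for removing these shims (an inference from the code, not stated in the commit): they assumed the other operand was also a constexpr, so mixing a constexpr with a plain Python number failed.

    # old behaviour, for illustration:
    triton.language.constexpr(2) + 3   # __add__ evaluates (3).value -> AttributeError
    # after this commit the code generator unwraps constexpr operands itself
    # (visit_BinOp / visit_Compare above), so the class no longer needs operator methods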
@@ -496,6 +434,11 @@ class tensor:
         other = _to_tensor(other, _builder)
         return semantic.mod(self, other, _builder)
 
+    @builtin
+    def __rmod__(self, other, _builder=None):
+        other = _to_tensor(other, _builder)
+        return semantic.mod(other, self, _builder)
+
     # unary operators
     @builtin
     def __neg__(self, _builder=None):
@@ -564,6 +507,7 @@ class tensor:
     @builtin
     def __rlt__(self, other, _builder=None):
+        other = _to_tensor(other, _builder)
         return semantic.less_than(other, self, _builder)
 
     # <=
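
The new tensor.__rmod__ and the _to_tensor conversion added to __rlt__ back the reflected dispatch in visit_BinOp / visit_Compare above: when only the right operand is a tensor, the generator calls the reflected method with the raw Python value. A hypothetical kernel fragment, for illustration:

    x = tl.load(X)
    y = 2.0 % x   # dispatches to x.__rmod__(2.0); the Python scalar is converted via _to_tensor
    z = 2.0 < x   # dispatches to x.__rlt__(2.0), which now also converts the scalar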