diff --git a/python/triton/compiler.py b/python/triton/compiler.py index b98675068..e6c595ea7 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -472,8 +472,6 @@ class CodeGenerator(ast.NodeVisitor): if type(node.op) == ast.Not: assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment" return triton.language.constexpr(not op) - if isinstance(op, triton.language.constexpr): - op = op.value fn = { ast.USub: '__neg__', ast.UAdd: '__pos__', @@ -563,27 +561,30 @@ class CodeGenerator(ast.NodeVisitor): iterator = self.visit(node.iter.func) if iterator != self.builtins['range']: raise RuntimeError('Only `range` iterator currently supported') - # static for loops: all iterator arguments are constexpr + # visit iterator arguments + # note: only `range` iterator is supported now iter_args = [self.visit(arg) for arg in node.iter.args] - static_unrolling = os.environ.get('TRITON_STATIC_LOOP_UNROLLING', False) - is_static = False - if static_unrolling: - is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args]) - if is_static: - iter_args = [arg.value for arg in iter_args] - range = iterator(*iter_args) - if len(range) <= 10: - for i in iterator(*iter_args): + # collect lower bound (lb), upper bound (ub), and step + lb = iter_args[0] if len(iter_args) > 1 else self.visit(ast.Num(0)) + ub = iter_args[1] if len(iter_args) > 1 else self.visit(node.iter.args[0]) + step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1)) + # static for loops: all iterator arguments are constexpr + if isinstance(lb, triton.language.constexpr) and \ + isinstance(ub, triton.language.constexpr) and \ + isinstance(step, triton.language.constexpr): + sta_range = iterator(lb.value, ub.value, step.value) + static_unrolling = os.environ.get('TRITON_STATIC_LOOP_UNROLLING', False) + if static_unrolling and len(range) <= 10: + for i in sta_range: self.lscope[node.target.id] = triton.language.constexpr(i) self.visit_compound_statement(node.body) for stmt in node.orelse: ast.NodeVisitor.generic_visit(self, stmt) return - - # collect lower bound (lb), upper bound (ub), and step - lb = self.visit(node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0)) - ub = self.visit(node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0]) - step = self.visit(node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1)) + # handle negative constant step (not supported by scf.for in MLIR) + if isinstance(step, triton.language.constexpr) and step.value < 0: + step = triton.language.constexpr(-step.value) + lb, ub = ub, lb # lb/ub/step might be constexpr, we need to cast them to tensor lb = triton.language.core._to_tensor(lb, self.builder).handle ub = triton.language.core._to_tensor(ub, self.builder).handle diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 6066ee4f5..a5e8166e6 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -345,67 +345,76 @@ class constexpr: return f"constexpr[{self.value}]" def __add__(self, other): - return self.value + other.value + return constexpr(self.value + other.value) def __radd__(self, other): - return other.value + self.value + return constexpr(other.value + self.value) def __sub__(self, other): - return self.value - other.value + return constexpr(self.value - other.value) def __rsub__(self, other): - return other.value - self.value + return constexpr(other.value - self.value) def __mul__(self, other): - return self.value * other.value + return constexpr(self.value * other.value) def __rmul__(self, other): - return other.value * self.value + return constexpr(other.value * self.value) def __truediv__(self, other): - return self.value / other.value + return constexpr(self.value / other.value) def __rtruediv__(self, other): - return other.value / self.value + return constexpr(other.value / self.value) def __floordiv__(self, other): - return self.value // other.value + return constexpr(self.value // other.value) def __rfloordiv__(self, other): - return other.value // self.value + return constexpr(other.value // self.value) def __gt__(self, other): - return self.value > other.value + return constexpr(self.value > other.value) def __rgt__(self, other): - return other.value > self.value + return constexpr(other.value > self.value) def __ge__(self, other): - return self.value >= other.value + return constexpr(self.value >= other.value) def __rge__(self, other): - return other.value >= self.value + return constexpr(other.value >= self.value) def __lt__(self, other): - return self.value < other.value + return constexpr(self.value < other.value) def __rlt__(self, other): - return other.value < self.value + return constexpr(other.value < self.value) def __le__(self, other): - return self.value <= other.value + return constexpr(self.value <= other.value) def __rle__(self, other): - return other.value <= self.value + return constexpr(other.value <= self.value) def __eq__(self, other): - return self.value == other.value + return constexpr(self.value == other.value) def __ne__(self, other): - return self.value != other.value + return constexpr(self.value != other.value) def __bool__(self): - return bool(self.value) + return constexpr(bool(self.value)) + + def __neg__(self): + return constexpr(-self.value) + + def __pos__(self): + return constexpr(+self.value) + + def __invert__(self): + return constexpr(~self.value) def __call__(self, *args, **kwds): return self.value(*args, **kwds)