[FRONTEND] Handle for loops with negative constant steps (#896)

This commit is contained in:
Philippe Tillet
2022-11-20 11:37:38 +01:00
committed by GitHub
parent 6c5f646f4e
commit 4d64ffb5fe
2 changed files with 48 additions and 38 deletions

View File

@@ -472,8 +472,6 @@ class CodeGenerator(ast.NodeVisitor):
if type(node.op) == ast.Not: if type(node.op) == ast.Not:
assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment" assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment"
return triton.language.constexpr(not op) return triton.language.constexpr(not op)
if isinstance(op, triton.language.constexpr):
op = op.value
fn = { fn = {
ast.USub: '__neg__', ast.USub: '__neg__',
ast.UAdd: '__pos__', ast.UAdd: '__pos__',
@@ -563,27 +561,30 @@ class CodeGenerator(ast.NodeVisitor):
iterator = self.visit(node.iter.func) iterator = self.visit(node.iter.func)
if iterator != self.builtins['range']: if iterator != self.builtins['range']:
raise RuntimeError('Only `range` iterator currently supported') raise RuntimeError('Only `range` iterator currently supported')
# static for loops: all iterator arguments are constexpr # visit iterator arguments
# note: only `range` iterator is supported now
iter_args = [self.visit(arg) for arg in node.iter.args] iter_args = [self.visit(arg) for arg in node.iter.args]
# collect lower bound (lb), upper bound (ub), and step
lb = iter_args[0] if len(iter_args) > 1 else self.visit(ast.Num(0))
ub = iter_args[1] if len(iter_args) > 1 else self.visit(node.iter.args[0])
step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1))
# static for loops: all iterator arguments are constexpr
if isinstance(lb, triton.language.constexpr) and \
isinstance(ub, triton.language.constexpr) and \
isinstance(step, triton.language.constexpr):
sta_range = iterator(lb.value, ub.value, step.value)
static_unrolling = os.environ.get('TRITON_STATIC_LOOP_UNROLLING', False) static_unrolling = os.environ.get('TRITON_STATIC_LOOP_UNROLLING', False)
is_static = False if static_unrolling and len(range) <= 10:
if static_unrolling: for i in sta_range:
is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args])
if is_static:
iter_args = [arg.value for arg in iter_args]
range = iterator(*iter_args)
if len(range) <= 10:
for i in iterator(*iter_args):
self.lscope[node.target.id] = triton.language.constexpr(i) self.lscope[node.target.id] = triton.language.constexpr(i)
self.visit_compound_statement(node.body) self.visit_compound_statement(node.body)
for stmt in node.orelse: for stmt in node.orelse:
ast.NodeVisitor.generic_visit(self, stmt) ast.NodeVisitor.generic_visit(self, stmt)
return return
# handle negative constant step (not supported by scf.for in MLIR)
# collect lower bound (lb), upper bound (ub), and step if isinstance(step, triton.language.constexpr) and step.value < 0:
lb = self.visit(node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0)) step = triton.language.constexpr(-step.value)
ub = self.visit(node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0]) lb, ub = ub, lb
step = self.visit(node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1))
# lb/ub/step might be constexpr, we need to cast them to tensor # lb/ub/step might be constexpr, we need to cast them to tensor
lb = triton.language.core._to_tensor(lb, self.builder).handle lb = triton.language.core._to_tensor(lb, self.builder).handle
ub = triton.language.core._to_tensor(ub, self.builder).handle ub = triton.language.core._to_tensor(ub, self.builder).handle

View File

@@ -345,67 +345,76 @@ class constexpr:
return f"constexpr[{self.value}]" return f"constexpr[{self.value}]"
def __add__(self, other): def __add__(self, other):
return self.value + other.value return constexpr(self.value + other.value)
def __radd__(self, other): def __radd__(self, other):
return other.value + self.value return constexpr(other.value + self.value)
def __sub__(self, other): def __sub__(self, other):
return self.value - other.value return constexpr(self.value - other.value)
def __rsub__(self, other): def __rsub__(self, other):
return other.value - self.value return constexpr(other.value - self.value)
def __mul__(self, other): def __mul__(self, other):
return self.value * other.value return constexpr(self.value * other.value)
def __rmul__(self, other): def __rmul__(self, other):
return other.value * self.value return constexpr(other.value * self.value)
def __truediv__(self, other): def __truediv__(self, other):
return self.value / other.value return constexpr(self.value / other.value)
def __rtruediv__(self, other): def __rtruediv__(self, other):
return other.value / self.value return constexpr(other.value / self.value)
def __floordiv__(self, other): def __floordiv__(self, other):
return self.value // other.value return constexpr(self.value // other.value)
def __rfloordiv__(self, other): def __rfloordiv__(self, other):
return other.value // self.value return constexpr(other.value // self.value)
def __gt__(self, other): def __gt__(self, other):
return self.value > other.value return constexpr(self.value > other.value)
def __rgt__(self, other): def __rgt__(self, other):
return other.value > self.value return constexpr(other.value > self.value)
def __ge__(self, other): def __ge__(self, other):
return self.value >= other.value return constexpr(self.value >= other.value)
def __rge__(self, other): def __rge__(self, other):
return other.value >= self.value return constexpr(other.value >= self.value)
def __lt__(self, other): def __lt__(self, other):
return self.value < other.value return constexpr(self.value < other.value)
def __rlt__(self, other): def __rlt__(self, other):
return other.value < self.value return constexpr(other.value < self.value)
def __le__(self, other): def __le__(self, other):
return self.value <= other.value return constexpr(self.value <= other.value)
def __rle__(self, other): def __rle__(self, other):
return other.value <= self.value return constexpr(other.value <= self.value)
def __eq__(self, other): def __eq__(self, other):
return self.value == other.value return constexpr(self.value == other.value)
def __ne__(self, other): def __ne__(self, other):
return self.value != other.value return constexpr(self.value != other.value)
def __bool__(self): def __bool__(self):
return bool(self.value) return constexpr(bool(self.value))
def __neg__(self):
return constexpr(-self.value)
def __pos__(self):
return constexpr(+self.value)
def __invert__(self):
return constexpr(~self.value)
def __call__(self, *args, **kwds): def __call__(self, *args, **kwds):
return self.value(*args, **kwds) return self.value(*args, **kwds)