[FRONTEND] Handle for loops with negative constant steps (#896)

This commit is contained in:
Philippe Tillet
2022-11-20 11:37:38 +01:00
committed by GitHub
parent 6c5f646f4e
commit 4d64ffb5fe
2 changed files with 48 additions and 38 deletions

View File

@@ -472,8 +472,6 @@ class CodeGenerator(ast.NodeVisitor):
if type(node.op) == ast.Not:
assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment"
return triton.language.constexpr(not op)
if isinstance(op, triton.language.constexpr):
op = op.value
fn = {
ast.USub: '__neg__',
ast.UAdd: '__pos__',
@@ -563,27 +561,30 @@ class CodeGenerator(ast.NodeVisitor):
iterator = self.visit(node.iter.func)
if iterator != self.builtins['range']:
raise RuntimeError('Only `range` iterator currently supported')
# static for loops: all iterator arguments are constexpr
# visit iterator arguments
# note: only `range` iterator is supported now
iter_args = [self.visit(arg) for arg in node.iter.args]
static_unrolling = os.environ.get('TRITON_STATIC_LOOP_UNROLLING', False)
is_static = False
if static_unrolling:
is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args])
if is_static:
iter_args = [arg.value for arg in iter_args]
range = iterator(*iter_args)
if len(range) <= 10:
for i in iterator(*iter_args):
# collect lower bound (lb), upper bound (ub), and step
lb = iter_args[0] if len(iter_args) > 1 else self.visit(ast.Num(0))
ub = iter_args[1] if len(iter_args) > 1 else self.visit(node.iter.args[0])
step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1))
# static for loops: all iterator arguments are constexpr
if isinstance(lb, triton.language.constexpr) and \
isinstance(ub, triton.language.constexpr) and \
isinstance(step, triton.language.constexpr):
sta_range = iterator(lb.value, ub.value, step.value)
static_unrolling = os.environ.get('TRITON_STATIC_LOOP_UNROLLING', False)
if static_unrolling and len(range) <= 10:
for i in sta_range:
self.lscope[node.target.id] = triton.language.constexpr(i)
self.visit_compound_statement(node.body)
for stmt in node.orelse:
ast.NodeVisitor.generic_visit(self, stmt)
return
# collect lower bound (lb), upper bound (ub), and step
lb = self.visit(node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0))
ub = self.visit(node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0])
step = self.visit(node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1))
# handle negative constant step (not supported by scf.for in MLIR)
if isinstance(step, triton.language.constexpr) and step.value < 0:
step = triton.language.constexpr(-step.value)
lb, ub = ub, lb
# lb/ub/step might be constexpr, we need to cast them to tensor
lb = triton.language.core._to_tensor(lb, self.builder).handle
ub = triton.language.core._to_tensor(ub, self.builder).handle

View File

@@ -345,67 +345,76 @@ class constexpr:
return f"constexpr[{self.value}]"
def __add__(self, other):
return self.value + other.value
return constexpr(self.value + other.value)
def __radd__(self, other):
return other.value + self.value
return constexpr(other.value + self.value)
def __sub__(self, other):
return self.value - other.value
return constexpr(self.value - other.value)
def __rsub__(self, other):
return other.value - self.value
return constexpr(other.value - self.value)
def __mul__(self, other):
return self.value * other.value
return constexpr(self.value * other.value)
def __rmul__(self, other):
return other.value * self.value
return constexpr(other.value * self.value)
def __truediv__(self, other):
return self.value / other.value
return constexpr(self.value / other.value)
def __rtruediv__(self, other):
return other.value / self.value
return constexpr(other.value / self.value)
def __floordiv__(self, other):
return self.value // other.value
return constexpr(self.value // other.value)
def __rfloordiv__(self, other):
return other.value // self.value
return constexpr(other.value // self.value)
def __gt__(self, other):
return self.value > other.value
return constexpr(self.value > other.value)
def __rgt__(self, other):
return other.value > self.value
return constexpr(other.value > self.value)
def __ge__(self, other):
return self.value >= other.value
return constexpr(self.value >= other.value)
def __rge__(self, other):
return other.value >= self.value
return constexpr(other.value >= self.value)
def __lt__(self, other):
return self.value < other.value
return constexpr(self.value < other.value)
def __rlt__(self, other):
return other.value < self.value
return constexpr(other.value < self.value)
def __le__(self, other):
return self.value <= other.value
return constexpr(self.value <= other.value)
def __rle__(self, other):
return other.value <= self.value
return constexpr(other.value <= self.value)
def __eq__(self, other):
return self.value == other.value
return constexpr(self.value == other.value)
def __ne__(self, other):
return self.value != other.value
return constexpr(self.value != other.value)
def __bool__(self):
return bool(self.value)
return constexpr(bool(self.value))
def __neg__(self):
return constexpr(-self.value)
def __pos__(self):
return constexpr(+self.value)
def __invert__(self):
return constexpr(~self.value)
def __call__(self, *args, **kwds):
return self.value(*args, **kwds)