[FRONTEND] Handle for loops with negative constant steps (#896)
This commit is contained in:
@@ -472,8 +472,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if type(node.op) == ast.Not:
|
||||
assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment"
|
||||
return triton.language.constexpr(not op)
|
||||
if isinstance(op, triton.language.constexpr):
|
||||
op = op.value
|
||||
fn = {
|
||||
ast.USub: '__neg__',
|
||||
ast.UAdd: '__pos__',
|
||||
@@ -563,27 +561,30 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
iterator = self.visit(node.iter.func)
|
||||
if iterator != self.builtins['range']:
|
||||
raise RuntimeError('Only `range` iterator currently supported')
|
||||
# static for loops: all iterator arguments are constexpr
|
||||
# visit iterator arguments
|
||||
# note: only `range` iterator is supported now
|
||||
iter_args = [self.visit(arg) for arg in node.iter.args]
|
||||
static_unrolling = os.environ.get('TRITON_STATIC_LOOP_UNROLLING', False)
|
||||
is_static = False
|
||||
if static_unrolling:
|
||||
is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args])
|
||||
if is_static:
|
||||
iter_args = [arg.value for arg in iter_args]
|
||||
range = iterator(*iter_args)
|
||||
if len(range) <= 10:
|
||||
for i in iterator(*iter_args):
|
||||
# collect lower bound (lb), upper bound (ub), and step
|
||||
lb = iter_args[0] if len(iter_args) > 1 else self.visit(ast.Num(0))
|
||||
ub = iter_args[1] if len(iter_args) > 1 else self.visit(node.iter.args[0])
|
||||
step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1))
|
||||
# static for loops: all iterator arguments are constexpr
|
||||
if isinstance(lb, triton.language.constexpr) and \
|
||||
isinstance(ub, triton.language.constexpr) and \
|
||||
isinstance(step, triton.language.constexpr):
|
||||
sta_range = iterator(lb.value, ub.value, step.value)
|
||||
static_unrolling = os.environ.get('TRITON_STATIC_LOOP_UNROLLING', False)
|
||||
if static_unrolling and len(range) <= 10:
|
||||
for i in sta_range:
|
||||
self.lscope[node.target.id] = triton.language.constexpr(i)
|
||||
self.visit_compound_statement(node.body)
|
||||
for stmt in node.orelse:
|
||||
ast.NodeVisitor.generic_visit(self, stmt)
|
||||
return
|
||||
|
||||
# collect lower bound (lb), upper bound (ub), and step
|
||||
lb = self.visit(node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0))
|
||||
ub = self.visit(node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0])
|
||||
step = self.visit(node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1))
|
||||
# handle negative constant step (not supported by scf.for in MLIR)
|
||||
if isinstance(step, triton.language.constexpr) and step.value < 0:
|
||||
step = triton.language.constexpr(-step.value)
|
||||
lb, ub = ub, lb
|
||||
# lb/ub/step might be constexpr, we need to cast them to tensor
|
||||
lb = triton.language.core._to_tensor(lb, self.builder).handle
|
||||
ub = triton.language.core._to_tensor(ub, self.builder).handle
|
||||
|
@@ -345,67 +345,76 @@ class constexpr:
|
||||
return f"constexpr[{self.value}]"
|
||||
|
||||
def __add__(self, other):
|
||||
return self.value + other.value
|
||||
return constexpr(self.value + other.value)
|
||||
|
||||
def __radd__(self, other):
|
||||
return other.value + self.value
|
||||
return constexpr(other.value + self.value)
|
||||
|
||||
def __sub__(self, other):
|
||||
return self.value - other.value
|
||||
return constexpr(self.value - other.value)
|
||||
|
||||
def __rsub__(self, other):
|
||||
return other.value - self.value
|
||||
return constexpr(other.value - self.value)
|
||||
|
||||
def __mul__(self, other):
|
||||
return self.value * other.value
|
||||
return constexpr(self.value * other.value)
|
||||
|
||||
def __rmul__(self, other):
|
||||
return other.value * self.value
|
||||
return constexpr(other.value * self.value)
|
||||
|
||||
def __truediv__(self, other):
|
||||
return self.value / other.value
|
||||
return constexpr(self.value / other.value)
|
||||
|
||||
def __rtruediv__(self, other):
|
||||
return other.value / self.value
|
||||
return constexpr(other.value / self.value)
|
||||
|
||||
def __floordiv__(self, other):
|
||||
return self.value // other.value
|
||||
return constexpr(self.value // other.value)
|
||||
|
||||
def __rfloordiv__(self, other):
|
||||
return other.value // self.value
|
||||
return constexpr(other.value // self.value)
|
||||
|
||||
def __gt__(self, other):
|
||||
return self.value > other.value
|
||||
return constexpr(self.value > other.value)
|
||||
|
||||
def __rgt__(self, other):
|
||||
return other.value > self.value
|
||||
return constexpr(other.value > self.value)
|
||||
|
||||
def __ge__(self, other):
|
||||
return self.value >= other.value
|
||||
return constexpr(self.value >= other.value)
|
||||
|
||||
def __rge__(self, other):
|
||||
return other.value >= self.value
|
||||
return constexpr(other.value >= self.value)
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.value < other.value
|
||||
return constexpr(self.value < other.value)
|
||||
|
||||
def __rlt__(self, other):
|
||||
return other.value < self.value
|
||||
return constexpr(other.value < self.value)
|
||||
|
||||
def __le__(self, other):
|
||||
return self.value <= other.value
|
||||
return constexpr(self.value <= other.value)
|
||||
|
||||
def __rle__(self, other):
|
||||
return other.value <= self.value
|
||||
return constexpr(other.value <= self.value)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.value == other.value
|
||||
return constexpr(self.value == other.value)
|
||||
|
||||
def __ne__(self, other):
|
||||
return self.value != other.value
|
||||
return constexpr(self.value != other.value)
|
||||
|
||||
def __bool__(self):
|
||||
return bool(self.value)
|
||||
return constexpr(bool(self.value))
|
||||
|
||||
def __neg__(self):
|
||||
return constexpr(-self.value)
|
||||
|
||||
def __pos__(self):
|
||||
return constexpr(+self.value)
|
||||
|
||||
def __invert__(self):
|
||||
return constexpr(~self.value)
|
||||
|
||||
def __call__(self, *args, **kwds):
|
||||
return self.value(*args, **kwds)
|
||||
|
Reference in New Issue
Block a user