[FRONTEND] Handle for loops with negative constant steps (#896)
This commit is contained in:
@@ -472,8 +472,6 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
if type(node.op) == ast.Not:
|
if type(node.op) == ast.Not:
|
||||||
assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment"
|
assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment"
|
||||||
return triton.language.constexpr(not op)
|
return triton.language.constexpr(not op)
|
||||||
if isinstance(op, triton.language.constexpr):
|
|
||||||
op = op.value
|
|
||||||
fn = {
|
fn = {
|
||||||
ast.USub: '__neg__',
|
ast.USub: '__neg__',
|
||||||
ast.UAdd: '__pos__',
|
ast.UAdd: '__pos__',
|
||||||
@@ -563,27 +561,30 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
iterator = self.visit(node.iter.func)
|
iterator = self.visit(node.iter.func)
|
||||||
if iterator != self.builtins['range']:
|
if iterator != self.builtins['range']:
|
||||||
raise RuntimeError('Only `range` iterator currently supported')
|
raise RuntimeError('Only `range` iterator currently supported')
|
||||||
# static for loops: all iterator arguments are constexpr
|
# visit iterator arguments
|
||||||
|
# note: only `range` iterator is supported now
|
||||||
iter_args = [self.visit(arg) for arg in node.iter.args]
|
iter_args = [self.visit(arg) for arg in node.iter.args]
|
||||||
|
# collect lower bound (lb), upper bound (ub), and step
|
||||||
|
lb = iter_args[0] if len(iter_args) > 1 else self.visit(ast.Num(0))
|
||||||
|
ub = iter_args[1] if len(iter_args) > 1 else self.visit(node.iter.args[0])
|
||||||
|
step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1))
|
||||||
|
# static for loops: all iterator arguments are constexpr
|
||||||
|
if isinstance(lb, triton.language.constexpr) and \
|
||||||
|
isinstance(ub, triton.language.constexpr) and \
|
||||||
|
isinstance(step, triton.language.constexpr):
|
||||||
|
sta_range = iterator(lb.value, ub.value, step.value)
|
||||||
static_unrolling = os.environ.get('TRITON_STATIC_LOOP_UNROLLING', False)
|
static_unrolling = os.environ.get('TRITON_STATIC_LOOP_UNROLLING', False)
|
||||||
is_static = False
|
if static_unrolling and len(range) <= 10:
|
||||||
if static_unrolling:
|
for i in sta_range:
|
||||||
is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args])
|
|
||||||
if is_static:
|
|
||||||
iter_args = [arg.value for arg in iter_args]
|
|
||||||
range = iterator(*iter_args)
|
|
||||||
if len(range) <= 10:
|
|
||||||
for i in iterator(*iter_args):
|
|
||||||
self.lscope[node.target.id] = triton.language.constexpr(i)
|
self.lscope[node.target.id] = triton.language.constexpr(i)
|
||||||
self.visit_compound_statement(node.body)
|
self.visit_compound_statement(node.body)
|
||||||
for stmt in node.orelse:
|
for stmt in node.orelse:
|
||||||
ast.NodeVisitor.generic_visit(self, stmt)
|
ast.NodeVisitor.generic_visit(self, stmt)
|
||||||
return
|
return
|
||||||
|
# handle negative constant step (not supported by scf.for in MLIR)
|
||||||
# collect lower bound (lb), upper bound (ub), and step
|
if isinstance(step, triton.language.constexpr) and step.value < 0:
|
||||||
lb = self.visit(node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0))
|
step = triton.language.constexpr(-step.value)
|
||||||
ub = self.visit(node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0])
|
lb, ub = ub, lb
|
||||||
step = self.visit(node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1))
|
|
||||||
# lb/ub/step might be constexpr, we need to cast them to tensor
|
# lb/ub/step might be constexpr, we need to cast them to tensor
|
||||||
lb = triton.language.core._to_tensor(lb, self.builder).handle
|
lb = triton.language.core._to_tensor(lb, self.builder).handle
|
||||||
ub = triton.language.core._to_tensor(ub, self.builder).handle
|
ub = triton.language.core._to_tensor(ub, self.builder).handle
|
||||||
|
@@ -345,67 +345,76 @@ class constexpr:
|
|||||||
return f"constexpr[{self.value}]"
|
return f"constexpr[{self.value}]"
|
||||||
|
|
||||||
def __add__(self, other):
|
def __add__(self, other):
|
||||||
return self.value + other.value
|
return constexpr(self.value + other.value)
|
||||||
|
|
||||||
def __radd__(self, other):
|
def __radd__(self, other):
|
||||||
return other.value + self.value
|
return constexpr(other.value + self.value)
|
||||||
|
|
||||||
def __sub__(self, other):
|
def __sub__(self, other):
|
||||||
return self.value - other.value
|
return constexpr(self.value - other.value)
|
||||||
|
|
||||||
def __rsub__(self, other):
|
def __rsub__(self, other):
|
||||||
return other.value - self.value
|
return constexpr(other.value - self.value)
|
||||||
|
|
||||||
def __mul__(self, other):
|
def __mul__(self, other):
|
||||||
return self.value * other.value
|
return constexpr(self.value * other.value)
|
||||||
|
|
||||||
def __rmul__(self, other):
|
def __rmul__(self, other):
|
||||||
return other.value * self.value
|
return constexpr(other.value * self.value)
|
||||||
|
|
||||||
def __truediv__(self, other):
|
def __truediv__(self, other):
|
||||||
return self.value / other.value
|
return constexpr(self.value / other.value)
|
||||||
|
|
||||||
def __rtruediv__(self, other):
|
def __rtruediv__(self, other):
|
||||||
return other.value / self.value
|
return constexpr(other.value / self.value)
|
||||||
|
|
||||||
def __floordiv__(self, other):
|
def __floordiv__(self, other):
|
||||||
return self.value // other.value
|
return constexpr(self.value // other.value)
|
||||||
|
|
||||||
def __rfloordiv__(self, other):
|
def __rfloordiv__(self, other):
|
||||||
return other.value // self.value
|
return constexpr(other.value // self.value)
|
||||||
|
|
||||||
def __gt__(self, other):
|
def __gt__(self, other):
|
||||||
return self.value > other.value
|
return constexpr(self.value > other.value)
|
||||||
|
|
||||||
def __rgt__(self, other):
|
def __rgt__(self, other):
|
||||||
return other.value > self.value
|
return constexpr(other.value > self.value)
|
||||||
|
|
||||||
def __ge__(self, other):
|
def __ge__(self, other):
|
||||||
return self.value >= other.value
|
return constexpr(self.value >= other.value)
|
||||||
|
|
||||||
def __rge__(self, other):
|
def __rge__(self, other):
|
||||||
return other.value >= self.value
|
return constexpr(other.value >= self.value)
|
||||||
|
|
||||||
def __lt__(self, other):
|
def __lt__(self, other):
|
||||||
return self.value < other.value
|
return constexpr(self.value < other.value)
|
||||||
|
|
||||||
def __rlt__(self, other):
|
def __rlt__(self, other):
|
||||||
return other.value < self.value
|
return constexpr(other.value < self.value)
|
||||||
|
|
||||||
def __le__(self, other):
|
def __le__(self, other):
|
||||||
return self.value <= other.value
|
return constexpr(self.value <= other.value)
|
||||||
|
|
||||||
def __rle__(self, other):
|
def __rle__(self, other):
|
||||||
return other.value <= self.value
|
return constexpr(other.value <= self.value)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return self.value == other.value
|
return constexpr(self.value == other.value)
|
||||||
|
|
||||||
def __ne__(self, other):
|
def __ne__(self, other):
|
||||||
return self.value != other.value
|
return constexpr(self.value != other.value)
|
||||||
|
|
||||||
def __bool__(self):
|
def __bool__(self):
|
||||||
return bool(self.value)
|
return constexpr(bool(self.value))
|
||||||
|
|
||||||
|
def __neg__(self):
|
||||||
|
return constexpr(-self.value)
|
||||||
|
|
||||||
|
def __pos__(self):
|
||||||
|
return constexpr(+self.value)
|
||||||
|
|
||||||
|
def __invert__(self):
|
||||||
|
return constexpr(~self.value)
|
||||||
|
|
||||||
def __call__(self, *args, **kwds):
|
def __call__(self, *args, **kwds):
|
||||||
return self.value(*args, **kwds)
|
return self.value(*args, **kwds)
|
||||||
|
Reference in New Issue
Block a user