[triton-mlir][BACKEND] Support masked load/store (#657)
This PR does - fix some bugs to support masked load/store, - refine frontend, and support the `and` and `or` syntax in mask(by extending the BoolOp in python ast.visitor), e.g. `tl.store(..., mask=offset<n and other_conditions)`, - add `arith.cmpI` and `arith.cmpF` op conversion in backend(required by mask), - add more test cases in vecadd.
This commit is contained in:
@@ -699,6 +699,28 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
def visit_Constant(self, node):
|
||||
return triton.language.constexpr(node.value)
|
||||
|
||||
def visit_BoolOp(self, node: ast.BoolOp):
|
||||
assert len(node.values) == 2
|
||||
lhs = self.visit(node.values[0])
|
||||
rhs = self.visit(node.values[1])
|
||||
if isinstance(lhs, triton.language.constexpr):
|
||||
lhs = lhs.value
|
||||
if isinstance(rhs, triton.language.constexpr):
|
||||
rhs = rhs.value
|
||||
|
||||
fn = {
|
||||
ast.And: 'logical_and',
|
||||
ast.Or: 'logical_or',
|
||||
}[type(node.op)]
|
||||
|
||||
if self.is_triton_tensor(lhs):
|
||||
return getattr(lhs, fn)(rhs, _builder=self.builder)
|
||||
elif self.is_triton_tensor(rhs):
|
||||
fn = fn[:2] + 'r' + fn[2:]
|
||||
return getattr(rhs, fn)(lhs, _builder=self.builder)
|
||||
else:
|
||||
return getattr(lhs, fn)(rhs)
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
def visit_NameConstant(self, node):
|
||||
return triton.language.constexpr(node.value)
|
||||
|
Reference in New Issue
Block a user