[triton-mlir][BACKEND] Support masked load/store (#657)

This PR does

- fix some bugs to support masked load/store,
- refine frontend, and support the `and` and `or` syntax in mask(by
extending the BoolOp in python ast.visitor), e.g. `tl.store(...,
mask=offset<n and other_conditions)`,
- add `arith.cmpI` and `arith.cmpF` op conversion in backend(required by
mask),
- add more test cases in vecadd.
This commit is contained in:
Yan Chunwei
2022-10-10 13:29:53 +08:00
committed by GitHub
parent ccc5ab6ac9
commit 555f94f9b9
9 changed files with 396 additions and 74 deletions

View File

@@ -699,6 +699,28 @@ class CodeGenerator(ast.NodeVisitor):
def visit_Constant(self, node):
return triton.language.constexpr(node.value)
def visit_BoolOp(self, node: ast.BoolOp):
assert len(node.values) == 2
lhs = self.visit(node.values[0])
rhs = self.visit(node.values[1])
if isinstance(lhs, triton.language.constexpr):
lhs = lhs.value
if isinstance(rhs, triton.language.constexpr):
rhs = rhs.value
fn = {
ast.And: 'logical_and',
ast.Or: 'logical_or',
}[type(node.op)]
if self.is_triton_tensor(lhs):
return getattr(lhs, fn)(rhs, _builder=self.builder)
elif self.is_triton_tensor(rhs):
fn = fn[:2] + 'r' + fn[2:]
return getattr(rhs, fn)(lhs, _builder=self.builder)
else:
return getattr(lhs, fn)(rhs)
if sys.version_info < (3, 8):
def visit_NameConstant(self, node):
return triton.language.constexpr(node.value)