[TritonIR] Make mask operand optional (#74)

This commit is contained in:
Shintaro Iwasaki
2022-08-22 22:00:17 -07:00
committed by GitHub
parent de2dd04c8a
commit 0ebef11c77
14 changed files with 113 additions and 102 deletions

View File

@@ -1442,21 +1442,9 @@ void init_triton_ir(py::module &&m) {
mlir::triton::EvictionPolicy evictionPolicy,
bool isVolatile) -> mlir::Value {
auto loc = self.getUnknownLoc();
auto ptrType = ptrs.getType().dyn_cast<mlir::RankedTensorType>();
std::vector<int64_t> shape = ptrType.getShape();
mlir::Type elementType = ptrType.getElementType()
.dyn_cast<mlir::triton::PointerType>()
.getPointeeType();
mlir::Type resultType =
mlir::RankedTensorType::get(shape, elementType);
if (other.has_value()) {
return self.create<mlir::triton::LoadOp>(
loc, resultType, ptrs, mask, other.value(), cacheModifier,
evictionPolicy, isVolatile);
} else {
return self.create<mlir::triton::LoadOp>(
loc, ptrs, mask, cacheModifier, evictionPolicy, isVolatile);
}
return self.create<mlir::triton::LoadOp>(
loc, ptrs, mask, other.value_or(mlir::Value()), cacheModifier,
evictionPolicy, isVolatile);
})
.def("create_masked_store",
[](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &val,