[FRONTEND] add an attr for masked load without explicit other (#55)

This commit is contained in:
Shintaro Iwasaki
2022-08-18 09:51:37 -07:00
committed by GitHub
parent fc58250a06
commit d69ce77b19
16 changed files with 71 additions and 54 deletions

View File

@@ -1306,7 +1306,8 @@ void init_triton_ir(py::module &&m) {
})
// Python binding "create_masked_load": builds a mlir::triton::LoadOp from a
// pointer tensor `ptrs`, a `mask`, and an optional `other` value.
// NOTE(review): this listing is a diff whose +/- markers were stripped, so
// pre-change and post-change lines appear side by side; the hunk line counts
// are consistent with the old/new reading given in the notes below.
.def("create_masked_load",
[](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &mask,
// Pre-change parameter line: `other` was a mandatory mlir::Value.
mlir::Value &other, mlir::triton::CacheModifier cacheModifier,
// Post-change parameter lines: `other` is now a std::optional, and
// cacheModifier moves to its own line.
std::optional<mlir::Value> &other,
mlir::triton::CacheModifier cacheModifier,
mlir::triton::EvictionPolicy evictionPolicy,
bool isVolatile) -> mlir::Value {
auto loc = self.getUnknownLoc();
@@ -1315,9 +1316,22 @@ void init_triton_ir(py::module &&m) {
// The lines that define `ptrType` and `shape` fall between the two hunks and
// are not visible here.
// Result element type = pointee type of the pointer elements.
// NOTE(review): the dyn_cast result is used without a null check; presumably
// the element type is always a triton PointerType here — confirm, or use cast.
mlir::Type elementType = ptrType.getElementType()
.dyn_cast<mlir::triton::PointerType>()
.getPointeeType();
// Pre-change body: a single LoadOp that always forwarded `other`.
return self.create<mlir::triton::LoadOp>(
loc, mlir::RankedTensorType::get(shape, elementType), ptrs,
mask, other, cacheModifier, evictionPolicy, isVolatile);
// Post-change body starts here: the result tensor type is computed once.
mlir::Type resultType =
mlir::RankedTensorType::get(shape, elementType);
if (other.has_value()) {
// Caller supplied `other`: forward it and pass `false` as the trailing
// LoadOp argument.
return self.create<mlir::triton::LoadOp>(
loc, resultType, ptrs, mask, other.value(), cacheModifier,
evictionPolicy, isVolatile, false);
} else {
// No `other` supplied: materialise an all-zero constant of the result type
// as a stand-in operand.
mlir::Value dummy_other = self.create<mlir::arith::ConstantOp>(
loc, resultType,
mlir::DenseElementsAttr::get(resultType,
self.getZeroAttr(elementType)));
// Same LoadOp as above, but with the stand-in operand and `true` as the
// trailing argument. NOTE(review): presumably that bool is the attribute from
// the commit title, marking that `other` was not given explicitly — confirm
// against the LoadOp builder. The type expression below repeats the one
// already stored in `resultType`.
return self.create<mlir::triton::LoadOp>(
loc, mlir::RankedTensorType::get(shape, elementType), ptrs,
mask, dummy_other, cacheModifier, evictionPolicy, isVolatile,
true);
}
})
.def("create_masked_store",
[](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &val,