[FRONTEND] add an attr for masked load without explicit other (#55)

This commit is contained in:
Shintaro Iwasaki
2022-08-18 09:51:37 -07:00
committed by GitHub
parent fc58250a06
commit d69ce77b19
16 changed files with 71 additions and 54 deletions

View File

@@ -1306,7 +1306,8 @@ void init_triton_ir(py::module &&m) {
})
.def("create_masked_load",
[](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &mask,
mlir::Value &other, mlir::triton::CacheModifier cacheModifier,
std::optional<mlir::Value> &other,
mlir::triton::CacheModifier cacheModifier,
mlir::triton::EvictionPolicy evictionPolicy,
bool isVolatile) -> mlir::Value {
auto loc = self.getUnknownLoc();
@@ -1315,9 +1316,22 @@ void init_triton_ir(py::module &&m) {
mlir::Type elementType = ptrType.getElementType()
.dyn_cast<mlir::triton::PointerType>()
.getPointeeType();
return self.create<mlir::triton::LoadOp>(
loc, mlir::RankedTensorType::get(shape, elementType), ptrs,
mask, other, cacheModifier, evictionPolicy, isVolatile);
mlir::Type resultType =
mlir::RankedTensorType::get(shape, elementType);
if (other.has_value()) {
return self.create<mlir::triton::LoadOp>(
loc, resultType, ptrs, mask, other.value(), cacheModifier,
evictionPolicy, isVolatile, false);
} else {
mlir::Value dummy_other = self.create<mlir::arith::ConstantOp>(
loc, resultType,
mlir::DenseElementsAttr::get(resultType,
self.getZeroAttr(elementType)));
return self.create<mlir::triton::LoadOp>(
loc, mlir::RankedTensorType::get(shape, elementType), ptrs,
mask, dummy_other, cacheModifier, evictionPolicy, isVolatile,
true);
}
})
.def("create_masked_store",
[](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &val,

View File

@@ -706,23 +706,17 @@ def load(ptr: tl.tensor,
else:
dst_ty = elt_ty
if not mask and not other:
if not mask:
if other:
raise ValueError("`other` cannot be provided without `mask`")
return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile),
dst_ty)
if not mask:
raise ValueError("`other` cannot be provided without `mask`")
if not other:
other_ir = ir.undef.get(elt_ty.to_ir(builder))
if ptr.type.is_block():
other_ir = builder.create_splat(other_ir, ptr.type.get_block_shapes())
other = tl.tensor(other_ir, dst_ty)
return tl.tensor(builder.create_masked_load(ptr.handle,
mask.handle,
other.handle,
cache, eviction, is_volatile),
dst_ty)
else:
return tl.tensor(builder.create_masked_load(ptr.handle,
mask.handle,
other.handle if other else None,
cache, eviction, is_volatile),
dst_ty)
def store(ptr: tl.tensor,