[FRONTEND] add an attr for masked load without explicit other (#55)
This commit is contained in:
@@ -1306,7 +1306,8 @@ void init_triton_ir(py::module &&m) {
|
||||
})
|
||||
.def("create_masked_load",
|
||||
[](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &mask,
|
||||
mlir::Value &other, mlir::triton::CacheModifier cacheModifier,
|
||||
std::optional<mlir::Value> &other,
|
||||
mlir::triton::CacheModifier cacheModifier,
|
||||
mlir::triton::EvictionPolicy evictionPolicy,
|
||||
bool isVolatile) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
@@ -1315,9 +1316,22 @@ void init_triton_ir(py::module &&m) {
|
||||
mlir::Type elementType = ptrType.getElementType()
|
||||
.dyn_cast<mlir::triton::PointerType>()
|
||||
.getPointeeType();
|
||||
return self.create<mlir::triton::LoadOp>(
|
||||
loc, mlir::RankedTensorType::get(shape, elementType), ptrs,
|
||||
mask, other, cacheModifier, evictionPolicy, isVolatile);
|
||||
mlir::Type resultType =
|
||||
mlir::RankedTensorType::get(shape, elementType);
|
||||
if (other.has_value()) {
|
||||
return self.create<mlir::triton::LoadOp>(
|
||||
loc, resultType, ptrs, mask, other.value(), cacheModifier,
|
||||
evictionPolicy, isVolatile, false);
|
||||
} else {
|
||||
mlir::Value dummy_other = self.create<mlir::arith::ConstantOp>(
|
||||
loc, resultType,
|
||||
mlir::DenseElementsAttr::get(resultType,
|
||||
self.getZeroAttr(elementType)));
|
||||
return self.create<mlir::triton::LoadOp>(
|
||||
loc, mlir::RankedTensorType::get(shape, elementType), ptrs,
|
||||
mask, dummy_other, cacheModifier, evictionPolicy, isVolatile,
|
||||
true);
|
||||
}
|
||||
})
|
||||
.def("create_masked_store",
|
||||
[](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &val,
|
||||
|
@@ -706,23 +706,17 @@ def load(ptr: tl.tensor,
|
||||
else:
|
||||
dst_ty = elt_ty
|
||||
|
||||
if not mask and not other:
|
||||
if not mask:
|
||||
if other:
|
||||
raise ValueError("`other` cannot be provided without `mask`")
|
||||
return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile),
|
||||
dst_ty)
|
||||
if not mask:
|
||||
raise ValueError("`other` cannot be provided without `mask`")
|
||||
|
||||
if not other:
|
||||
other_ir = ir.undef.get(elt_ty.to_ir(builder))
|
||||
if ptr.type.is_block():
|
||||
other_ir = builder.create_splat(other_ir, ptr.type.get_block_shapes())
|
||||
other = tl.tensor(other_ir, dst_ty)
|
||||
|
||||
return tl.tensor(builder.create_masked_load(ptr.handle,
|
||||
mask.handle,
|
||||
other.handle,
|
||||
cache, eviction, is_volatile),
|
||||
dst_ty)
|
||||
else:
|
||||
return tl.tensor(builder.create_masked_load(ptr.handle,
|
||||
mask.handle,
|
||||
other.handle if other else None,
|
||||
cache, eviction, is_volatile),
|
||||
dst_ty)
|
||||
|
||||
|
||||
def store(ptr: tl.tensor,
|
||||
|
Reference in New Issue
Block a user