[TritonIR] Make mask operand optional (#74)
This commit is contained in:
@@ -1442,21 +1442,9 @@ void init_triton_ir(py::module &&m) {
|
||||
mlir::triton::EvictionPolicy evictionPolicy,
|
||||
bool isVolatile) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
auto ptrType = ptrs.getType().dyn_cast<mlir::RankedTensorType>();
|
||||
std::vector<int64_t> shape = ptrType.getShape();
|
||||
mlir::Type elementType = ptrType.getElementType()
|
||||
.dyn_cast<mlir::triton::PointerType>()
|
||||
.getPointeeType();
|
||||
mlir::Type resultType =
|
||||
mlir::RankedTensorType::get(shape, elementType);
|
||||
if (other.has_value()) {
|
||||
return self.create<mlir::triton::LoadOp>(
|
||||
loc, resultType, ptrs, mask, other.value(), cacheModifier,
|
||||
evictionPolicy, isVolatile);
|
||||
} else {
|
||||
return self.create<mlir::triton::LoadOp>(
|
||||
loc, ptrs, mask, cacheModifier, evictionPolicy, isVolatile);
|
||||
}
|
||||
return self.create<mlir::triton::LoadOp>(
|
||||
loc, ptrs, mask, other.value_or(mlir::Value()), cacheModifier,
|
||||
evictionPolicy, isVolatile);
|
||||
})
|
||||
.def("create_masked_store",
|
||||
[](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &val,
|
||||
|
Reference in New Issue
Block a user