Fix OpBuilder

This commit is contained in:
Yan Da
2022-04-07 20:01:31 +08:00
parent 6b4da6f016
commit 040a2b6c75
5 changed files with 41 additions and 20 deletions

View File

@@ -33,7 +33,8 @@ void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, :
}
//-- LoadOp --
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr) {
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr,
::mlir::triton::CacheModifier cache, ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
Type elementType = ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
auto shape = ptrType.getShape();
@@ -57,6 +58,9 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::
state.addOperands(ptr);
state.addOperands(mask);
state.addOperands(other);
state.addAttribute(cacheAttrName(state.name), ::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
state.addAttribute(evictAttrName(state.name), ::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
state.addAttribute(isVolatileAttrName(state.name), builder.getBoolAttr(isVolatile));
state.addTypes({resultType});
}