#include "triton/ir/Dialect.h"
#include "triton/ir/Types.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OperationSupport.h"

namespace mlir {
namespace triton {

// Type inference

// Returns a ranked tensor of i1 with the same shape as `type`.
// Yields a null Type when `type` is not a ranked tensor.
static Type getI1SameShape(Type type) {
  auto i1Type = IntegerType::get(type.getContext(), 1);
  if (auto tensorType = type.dyn_cast<RankedTensorType>())
    return RankedTensorType::get(tensorType.getShape(), i1Type);
  return Type();
}

// Returns a ranked tensor of i32 with the same shape as `type`.
// Yields a null Type when `type` is not a ranked tensor.
static Type getI32SameShape(Type type) {
  auto i32Type = IntegerType::get(type.getContext(), 32);
  if (auto tensorType = type.dyn_cast<RankedTensorType>())
    return RankedTensorType::get(tensorType.getShape(), i32Type);
  return Type();
}

// Returns a ranked tensor with the same shape as `type` whose elements are
// pointers to the element type of `type`.
// Yields a null Type when `type` is not a ranked tensor.
static Type getPointerTypeFromTensor(Type type) {
  if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
    Type elementType = tensorType.getElementType();
    auto shape = tensorType.getShape();
    // NOTE(review): the second argument (1) is presumably the pointer's
    // address space -- confirm against PointerType::get.
    PointerType ptrType = PointerType::get(elementType, 1);
    return RankedTensorType::get(shape, ptrType);
  }
  return Type();
}

} // namespace triton
} // namespace mlir

#define GET_OP_CLASSES
#include "triton/ir/Ops.cpp.inc"

// enum attribute definitions
#include "triton/ir/OpsEnums.cpp.inc"

namespace mlir {
namespace triton {

// Creates a constant i1 tensor of the given shape with every element set to
// true. Shared by the StoreOp and LoadOp builders, which use it as the
// default mask when the caller does not supply one.
static ::mlir::Value createAllTrueMask(::mlir::OpBuilder &builder,
                                       ::mlir::Location loc,
                                       ::mlir::ArrayRef<int64_t> shape) {
  auto maskType = RankedTensorType::get(shape, builder.getI1Type());
  return builder.create<arith::ConstantOp>(
      loc, maskType, DenseIntElementsAttr::get(maskType, true));
}

//-- StoreOp --
// Default mask
//
// Builds a StoreOp from `ptr` and `value` only, synthesizing an all-true mask
// with the same shape as `ptr`. Operands are added in the order
// (ptr, value, mask).
//
// `ptr` must have a tensor type.
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                    ::mlir::Value ptr, ::mlir::Value value) {
  // cast (not dyn_cast): the result is used unconditionally, so a non-tensor
  // `ptr` should trip an assertion rather than dereference a null type.
  TensorType ptrType = ptr.getType().cast<TensorType>();
  auto shape = ptrType.getShape();
  ::mlir::Value mask = createAllTrueMask(builder, ptr.getLoc(), shape);
  state.addOperands(ptr);
  state.addOperands(value);
  state.addOperands(mask);
}

//-- LoadOp --
//
// Builds a LoadOp from `ptr` and the cache / eviction / volatility settings,
// synthesizing the two remaining operands:
//   - mask:  an all-true i1 tensor with the same shape as `ptr`;
//   - other: a zero-filled constant tensor of the result type.
// Operands are added in the order (ptr, mask, other). The result type is a
// ranked tensor with the shape of `ptr` and the pointee type as element type.
//
// `ptr` must be a tensor whose element type is a PointerType.
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                   ::mlir::Value ptr, ::mlir::triton::CacheModifier cache,
                   ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
  // cast (not dyn_cast): both results are used unconditionally, so an
  // unexpected type should trip an assertion rather than dereference null.
  TensorType ptrType = ptr.getType().cast<TensorType>();
  Type elementType =
      ptrType.getElementType().cast<PointerType>().getPointeeType();
  auto shape = ptrType.getShape();

  // mask
  ::mlir::Value mask = createAllTrueMask(builder, ptr.getLoc(), shape);

  // other
  Type resultType = RankedTensorType::get(shape, elementType);
  ::mlir::Value other = builder.create<arith::ConstantOp>(
      ptr.getLoc(), resultType,
      DenseElementsAttr::get(resultType, builder.getZeroAttr(elementType)));

  state.addOperands(ptr);
  state.addOperands(mask);
  state.addOperands(other);
  state.addAttribute(
      cacheAttrName(state.name),
      ::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
  state.addAttribute(
      evictAttrName(state.name),
      ::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
  state.addAttribute(isVolatileAttrName(state.name),
                     builder.getBoolAttr(isVolatile));
  state.addTypes({resultType});
}

} // namespace triton
} // namespace mlir