Files
triton/lib/Dialect/Triton/IR/Ops.cpp

121 lines
4.1 KiB
C++
Raw Normal View History

2022-04-27 19:28:21 +08:00
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
2022-03-17 20:40:55 +08:00
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OperationSupport.h"
namespace mlir {
namespace triton {
// Type inference
// Returns the i1 counterpart of `type` with the same shape.
//
// A ranked tensor maps to a ranked tensor with the same shape and encoding
// whose element type is i1. Any other type is treated as a scalar and maps to
// plain i1 (a null Type would never match a real operand/result type).
static Type getI1SameShape(Type type) {
  auto i1Type = IntegerType::get(type.getContext(), 1);
  if (auto tensorType = type.dyn_cast<RankedTensorType>())
    return RankedTensorType::get(tensorType.getShape(), i1Type,
                                 tensorType.getEncoding());
  return i1Type;
}
// Returns the i32 counterpart of `type` with the same shape.
//
// A ranked tensor maps to a ranked tensor with the same shape and encoding
// whose element type is i32. Any other type is treated as a scalar and maps to
// plain i32 (a null Type would never match a real operand/result type).
static Type getI32SameShape(Type type) {
  auto i32Type = IntegerType::get(type.getContext(), 32);
  if (auto tensorType = type.dyn_cast<RankedTensorType>())
    return RankedTensorType::get(tensorType.getShape(), i32Type,
                                 tensorType.getEncoding());
  return i32Type;
}
// Returns the pointer counterpart of `type` with the same shape.
//
// A ranked tensor of T maps to a ranked tensor of pointer-to-T with the same
// shape and encoding. Any other type T is treated as a scalar and maps to a
// scalar pointer-to-T, built with the same `1` PointerType argument as the
// tensor case, instead of a null Type.
static Type getPointerTypeFromTensor(Type type) {
  if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
    PointerType ptrType = PointerType::get(tensorType.getElementType(), 1);
    return RankedTensorType::get(tensorType.getShape(), ptrType,
                                 tensorType.getEncoding());
  }
  return PointerType::get(type, 1);
}
2022-07-26 17:25:03 -07:00
} // namespace triton
} // namespace mlir
2022-03-17 20:40:55 +08:00
#define GET_OP_CLASSES
2022-04-27 19:41:07 +08:00
#include "triton/Dialect/Triton/IR/Ops.cpp.inc"
2022-03-17 20:40:55 +08:00
// enum attribute definitions
2022-04-27 19:41:07 +08:00
#include "triton/Dialect/Triton/IR/OpsEnums.cpp.inc"
2022-03-17 20:40:55 +08:00
namespace mlir {
namespace triton {
//-- StoreOp --
// Default mask
2022-07-26 17:25:03 -07:00
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::Value value) {
2022-03-17 20:40:55 +08:00
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
auto shape = ptrType.getShape();
::mlir::Value mask = builder.create<arith::ConstantOp>(
2022-07-26 17:25:03 -07:00
ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
DenseIntElementsAttr::get(
RankedTensorType::get(shape, builder.getI1Type()), true));
2022-03-17 20:40:55 +08:00
state.addOperands(ptr);
state.addOperands(value);
state.addOperands(mask);
}
//-- LoadOp --
2022-07-26 17:25:03 -07:00
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::triton::CacheModifier cache,
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
2022-03-17 20:40:55 +08:00
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
2022-07-26 17:25:03 -07:00
Type elementType =
ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
2022-03-17 20:40:55 +08:00
auto shape = ptrType.getShape();
// mask
::mlir::Value mask = builder.create<arith::ConstantOp>(
2022-07-26 17:25:03 -07:00
ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
DenseIntElementsAttr::get(
RankedTensorType::get(shape, builder.getI1Type()), true));
2022-03-17 20:40:55 +08:00
// other
Type resultType = RankedTensorType::get(shape, elementType);
::mlir::Value other = builder.create<arith::ConstantOp>(
2022-07-26 17:25:03 -07:00
ptr.getLoc(), resultType,
DenseElementsAttr::get(resultType, builder.getZeroAttr(elementType)));
2022-03-17 20:40:55 +08:00
state.addOperands(ptr);
state.addOperands(mask);
state.addOperands(other);
2022-07-26 17:25:03 -07:00
state.addAttribute(
cacheAttrName(state.name),
::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
state.addAttribute(
evictAttrName(state.name),
::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
state.addAttribute(isVolatileAttrName(state.name),
builder.getBoolAttr(isVolatile));
2022-03-17 20:40:55 +08:00
state.addTypes({resultType});
}
2022-04-27 13:45:56 +08:00
//-- DotOp --
//-- SplatOp --
// Folds a splat of a constant scalar into a splat elements attribute of the
// result type.
OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
  // operands[0] holds the constant value of `src` when the folder knows it,
  // which covers any constant-like producer, not only arith::ConstantOp.
  Attribute value = operands[0];
  // Only scalar integer/float constants are valid splat element values.
  if (!value || !value.isa<IntegerAttr, FloatAttr>())
    return {};
  auto shapedType = getType().cast<ShapedType>();
  return SplatElementsAttr::get(shapedType, ArrayRef<Attribute>(value));
}
2022-04-27 13:45:56 +08:00
//-- BroadcastOp --
// Folds a broadcast of a uniform constant into a splat of the result type.
// The constant may be a scalar or a splat elements attribute. A non-splat
// tensor constant is left alone, because broadcasting it does not yield a
// splat.
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  // operands[0] holds the constant value of `src` when the folder knows it.
  Attribute value = operands[0];
  if (!value)
    return {};
  auto shapedType = getType().cast<ShapedType>();
  if (auto denseAttr = value.dyn_cast<DenseElementsAttr>()) {
    // A whole elements attribute is not a valid splat element; only its single
    // splat value is, and only when every element is identical.
    if (!denseAttr.isSplat())
      return {};
    Attribute splatValue = denseAttr.getSplatValue<Attribute>();
    return SplatElementsAttr::get(shapedType, ArrayRef<Attribute>(splatValue));
  }
  if (value.isa<IntegerAttr, FloatAttr>())
    return SplatElementsAttr::get(shapedType, ArrayRef<Attribute>(value));
  return {};
}
2022-03-17 20:40:55 +08:00
} // namespace triton
} // namespace mlir