// File: triton/lib/Dialect/Triton/IR/Ops.cpp
// (C++ source; 347 lines, 12 KiB in the original listing)

#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OperationSupport.h"
namespace mlir {
namespace triton {
// Type inference
/// Returns the i1 counterpart of `type`.
///
/// If `type` is a ranked tensor, the result is a ranked tensor with the same
/// shape and encoding whose element type is i1; for any other type the result
/// is the scalar i1 type.
static Type getI1SameShape(Type type) {
  IntegerType boolTy = IntegerType::get(type.getContext(), 1);
  auto rankedTy = type.dyn_cast<RankedTensorType>();
  if (!rankedTy)
    return boolTy;
  return RankedTensorType::get(rankedTy.getShape(), boolTy,
                               rankedTy.getEncoding());
}
/// Returns the i32 counterpart of `type`.
///
/// If `type` is a ranked tensor, the result is a ranked tensor with the same
/// shape and encoding whose element type is i32; for any other type the result
/// is the scalar i32 type.
static Type getI32SameShape(Type type) {
  IntegerType intTy = IntegerType::get(type.getContext(), 32);
  auto rankedTy = type.dyn_cast<RankedTensorType>();
  if (!rankedTy)
    return intTy;
  return RankedTensorType::get(rankedTy.getShape(), intTy,
                               rankedTy.getEncoding());
}
/// Returns the pointer type that corresponds to the value type `type`.
///
/// A ranked tensor maps to a tensor of pointers to its element type, keeping
/// the shape and encoding; any other type maps to a plain pointer to `type`.
/// Pointers are created with `1` as the second `PointerType::get` argument
/// (presumably the address space -- confirm against PointerType).
static Type getPointerTypeSameShape(Type type) {
  auto rankedTy = type.dyn_cast<RankedTensorType>();
  if (!rankedTy)
    return PointerType::get(type, 1);

  PointerType elemPtrTy = PointerType::get(rankedTy.getElementType(), 1);
  return RankedTensorType::get(rankedTy.getShape(), elemPtrTy,
                               rankedTy.getEncoding());
}
// Parser & printer for assembly forms
/// Parses the custom assembly form of LoadOp:
///
///   operand-list attr-dict? `:` result-type
///
/// Only the result type is written in the assembly; operand types are derived
/// from it by position: operand 0 is the pointer, operand 1 (if present) is
/// the i1 mask, and operand 2 (if present) is the `other` value, which has the
/// result type. The operand segment sizes attribute is synthesized from the
/// number of operands that were parsed.
ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 4> operands;
  Type resultType;
  SMLoc operandsLoc = parser.getCurrentLocation();

  if (parser.parseOperandList(operands))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  if (parser.parseColon())
    return failure();
  if (parser.parseCustomTypeWithFallback(resultType))
    return failure();
  result.addTypes({resultType});

  // Which optional operands are present is decided purely by operand count.
  const bool hasMask = operands.size() >= 2;
  const bool hasOther = operands.size() >= 3;

  SmallVector<Type> operandTypes;
  operandTypes.push_back(getPointerTypeSameShape(resultType)); // ptr
  if (hasMask)
    operandTypes.push_back(getI1SameShape(resultType)); // mask
  if (hasOther)
    operandTypes.push_back(resultType); // other

  if (parser.resolveOperands(operands, operandTypes, operandsLoc,
                             result.operands))
    return failure();

  // The segment sizes are never printed, so rebuild them here.
  result.addAttribute(LoadOp::operand_segment_sizesAttrName(result.name),
                      parser.getBuilder().getI32VectorAttr(
                          {1, hasMask ? 1 : 0, hasOther ? 1 : 0}));
  return success();
}
/// Prints the custom assembly form of LoadOp:
///
///   ` ` operands attr-dict ` : ` result-type
///
/// The operand segment sizes attribute is elided because the parser rebuilds
/// it from the operand count.
void printLoadOp(OpAsmPrinter &printer, LoadOp loadOp) {
  Operation *op = loadOp.getOperation();
  printer << " " << op->getOperands();
  printer.printOptionalAttrDict(op->getAttrs(),
                                {loadOp.operand_segment_sizesAttrName()});
  printer << " : ";
  printer.printStrippedAttrOrType(loadOp.result().getType());
}
/// Parses the custom assembly form of StoreOp:
///
///   operand-list attr-dict? `:` value-type
///
/// Only the stored value's type is written in the assembly; operand types are
/// derived from it by position: operand 0 is the pointer, operand 1 is the
/// value, and operand 2 (if present) is the i1 mask.
ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 4> operands;
  Type storedTy;
  SMLoc operandsLoc = parser.getCurrentLocation();

  if (parser.parseOperandList(operands))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  if (parser.parseColon())
    return failure();
  if (parser.parseCustomTypeWithFallback(storedTy))
    return failure();

  SmallVector<Type> operandTypes;
  operandTypes.push_back(getPointerTypeSameShape(storedTy)); // ptr
  operandTypes.push_back(storedTy);                          // value
  const bool hasMask = operands.size() >= 3;
  if (hasMask)
    operandTypes.push_back(getI1SameShape(storedTy)); // mask

  if (parser.resolveOperands(operands, operandTypes, operandsLoc,
                             result.operands))
    return failure();
  return success();
}
/// Prints the custom assembly form of StoreOp:
///
///   ` ` operands attr-dict ` : ` value-type
///
/// No attributes are elided.
void printStoreOp(OpAsmPrinter &printer, StoreOp storeOp) {
  Operation *op = storeOp.getOperation();
  printer << " " << op->getOperands();
  printer.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{});
  printer << " : ";
  printer.printStrippedAttrOrType(storeOp.value().getType());
}
} // namespace triton
} // namespace mlir
#define GET_OP_CLASSES
#include "triton/Dialect/Triton/IR/Ops.cpp.inc"
// enum attribute definitions
#include "triton/Dialect/Triton/IR/OpsEnums.cpp.inc"
namespace mlir {
namespace triton {
//-- FpToFpOp --
/// Returns true when a cast between `inputs` and `outputs` is supported.
///
/// Exactly one input and one output type are required. When both are ranked
/// tensors their element types are compared; otherwise the types are compared
/// as given. The cast is compatible when exactly one side is the Triton fp8
/// type and the other side is f16, bf16, f32 or f64 (in either direction).
bool FpToFpOp::areCastCompatible(::mlir::TypeRange inputs,
                                 ::mlir::TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;

  ::mlir::Type inTy = inputs.front();
  ::mlir::Type outTy = outputs.front();

  // Only unwrap to element types when *both* sides are ranked tensors.
  auto inTensorTy = inTy.dyn_cast<mlir::RankedTensorType>();
  auto outTensorTy = outTy.dyn_cast<mlir::RankedTensorType>();
  if (inTensorTy && outTensorTy) {
    inTy = inTensorTy.getElementType();
    outTy = outTensorTy.getElementType();
  }

  // Identify the fp8 side; the destination takes precedence when it is fp8.
  const bool outIsFp8 = outTy.isa<mlir::triton::Float8Type>();
  ::mlir::Type fp8Side = outIsFp8 ? outTy : inTy;
  ::mlir::Type otherSide = outIsFp8 ? inTy : outTy;
  if (!fp8Side.isa<mlir::triton::Float8Type>())
    return false;

  // The non-fp8 side must be one of the standard float types.
  if (otherSide.isF16())
    return true;
  if (otherSide.isBF16())
    return true;
  if (otherSide.isF32())
    return true;
  return otherSide.isF64();
}
//-- StoreOp --
/// Convenience builder for an unmasked store: forwards to the builder that
/// takes a mask, passing a null Value for the mask.
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                    ::mlir::Value ptr, ::mlir::Value value) {
  ::mlir::Value noMask;
  StoreOp::build(builder, state, ptr, value, noMask);
}
//-- LoadOp --
/// Computes the result type of a load from a value of type `ptrType`.
///
/// - For a scalar pointer, the result is its pointee type.
/// - For a ranked tensor of pointers, the result is a ranked tensor of the
///   pointee type with the same shape and the same encoding.
///
/// The encoding is carried over so that the builder agrees with the rest of
/// this file: the assembly parser derives the pointer type from the result
/// type with an identical encoding, and the other shape helpers likewise keep
/// the encoding. For tensors without an encoding the result is unchanged.
///
/// `builder` is unused; it is kept so the signature stays the same for callers.
/// `ptrType` must be a pointer or a ranked tensor of pointers (the casts below
/// assert otherwise).
static Type getLoadOpResultType(::mlir::OpBuilder &builder, Type ptrType) {
  (void)builder;
  auto ptrTensorType = ptrType.dyn_cast<RankedTensorType>();
  if (!ptrTensorType)
    return ptrType.cast<PointerType>().getPointeeType();

  Type pointeeType =
      ptrTensorType.getElementType().cast<PointerType>().getPointeeType();
  return RankedTensorType::get(ptrTensorType.getShape(), pointeeType,
                               ptrTensorType.getEncoding());
}
/// Convenience builder for an unmasked load: forwards to the full builder with
/// null Values for both `mask` and `other`.
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                   ::mlir::Value ptr, ::mlir::triton::CacheModifier cache,
                   ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
  ::mlir::Value noMask;
  ::mlir::Value noOther;
  LoadOp::build(builder, state, ptr, noMask, noOther, cache, evict,
                isVolatile);
}
/// Convenience builder for a masked load without an `other` value: forwards to
/// the full builder with a null Value for `other`.
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                   ::mlir::Value ptr, ::mlir::Value mask,
                   ::mlir::triton::CacheModifier cache,
                   ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
  ::mlir::Value noOther;
  LoadOp::build(builder, state, ptr, mask, noOther, cache, evict, isVolatile);
}
/// Full builder for LoadOp.
///
/// \param ptr        pointer (or tensor of pointers) to load from; always
///                   added as the first operand.
/// \param mask       optional mask; pass a null Value to omit it.
/// \param other      optional `other` value; pass a null Value to omit it. It
///                   is only attached when `mask` is also provided.
/// \param cache      stored as the cache modifier attribute.
/// \param evict      stored as the eviction policy attribute.
/// \param isVolatile stored as the volatile flag attribute.
///
/// The result type is derived from the type of `ptr`.
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                   ::mlir::Value ptr, ::mlir::Value mask, ::mlir::Value other,
                   ::mlir::triton::CacheModifier cache,
                   ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
  // `other` is only ever attached together with a mask. Compute the presence
  // flags once and use them for both the operand list and the segment sizes,
  // so the two can never disagree (previously an `other` without a mask was
  // counted in the segment sizes but not added as an operand).
  const bool hasMask = static_cast<bool>(mask);
  const bool hasOther = hasMask && static_cast<bool>(other);

  Type resultType = getLoadOpResultType(builder, ptr.getType());

  state.addOperands(ptr);
  if (hasMask)
    state.addOperands(mask);
  if (hasOther)
    state.addOperands(other);

  state.addAttribute(
      operand_segment_sizesAttrName(state.name),
      builder.getI32VectorAttr({1, hasMask ? 1 : 0, hasOther ? 1 : 0}));
  state.addAttribute(
      cacheAttrName(state.name),
      ::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
  state.addAttribute(
      evictAttrName(state.name),
      ::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
  state.addAttribute(isVolatileAttrName(state.name),
                     builder.getBoolAttr(isVolatile));
  state.addTypes({resultType});
}
//-- DotOp --
/// Infers the result type of DotOp and checks operand encodings.
///
/// The result type is the type of the accumulator (operand 2). If operand 0
/// carries an encoding, operand 1 must carry one too, and the dialect that
/// owns operand 0's encoding is asked (via DialectInferLayoutInterface) to
/// validate both operand encodings against the accumulator's encoding.
///
/// Returns failure (emitting a diagnostic when `location` is set) if operand 1
/// lacks an encoding while operand 0 has one, if the encoding's dialect does
/// not implement DialectInferLayoutInterface, or if either encoding check
/// fails. Operands 0-2 must be ranked tensors (the casts below assert
/// otherwise).
mlir::LogicalResult mlir::triton::DotOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // type is the same as the accumulator
  auto accTy = operands[2].getType().cast<RankedTensorType>();
  inferredReturnTypes.push_back(accTy);

  // verify encodings
  auto aEnc = operands[0].getType().cast<RankedTensorType>().getEncoding();
  auto bEnc = operands[1].getType().cast<RankedTensorType>().getEncoding();
  auto retEnc = accTy.getEncoding();
  if (!aEnc)
    return mlir::success();

  // Report a proper error instead of relying on an assert, which is compiled
  // out in release builds and would let a null encoding through.
  if (!bEnc)
    return emitOptionalError(
        location, "DotOp operand 1 has no encoding while operand 0 has one");

  Dialect &dialect = aEnc.getDialect();
  auto *layoutInterface = dyn_cast<DialectInferLayoutInterface>(&dialect);
  // dyn_cast yields null when the dialect does not implement the interface.
  if (!layoutInterface)
    return emitOptionalError(location,
                             "encoding dialect does not implement "
                             "DialectInferLayoutInterface for DotOp");

  if (layoutInterface->inferDotOpEncoding(aEnc, 0, retEnc, location).failed())
    return mlir::failure();
  if (layoutInterface->inferDotOpEncoding(bEnc, 1, retEnc, location).failed())
    return mlir::failure();
  return mlir::success();
}
//-- ReduceOp --
/// Infers the result type of ReduceOp.
///
/// The result shape is the shape of operand 0 with dimension `axis` removed.
/// The element type is i32 when the reduction kind produces an index (see
/// ReduceOp::withIndex) and the operand's element type otherwise. If no
/// dimensions remain, the result is the scalar element type; otherwise it is a
/// ranked tensor whose encoding (when the operand has one) is inferred through
/// the DialectInferLayoutInterface of the encoding's dialect.
///
/// Returns failure (emitting a diagnostic when `location` is set) if `axis` is
/// outside [0, rank), if the encoding's dialect does not implement
/// DialectInferLayoutInterface, or if layout inference fails. Operand 0 must
/// be a ranked tensor, and the "redOp" and "axis" attributes must be present
/// (the casts below assert otherwise).
mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // infer shape
  Value arg = operands[0];
  auto argTy = arg.getType().cast<RankedTensorType>();
  auto argEltTy = argTy.getElementType();
  auto i32Ty = IntegerType::get(argEltTy.getContext(), 32);
  auto redOp =
      attributes.get("redOp").cast<mlir::triton::RedOpAttr>().getValue();
  bool withIndex = mlir::triton::ReduceOp::withIndex(redOp);
  auto retEltTy = withIndex ? i32Ty : argEltTy;
  auto retShape = argTy.getShape().vec();
  int axis = attributes.get("axis").cast<IntegerAttr>().getInt();

  // Erasing at an out-of-range position is undefined behavior, so reject a
  // bad axis up front.
  if (axis < 0 || axis >= static_cast<int>(retShape.size()))
    return emitOptionalError(location, "ReduceOp axis ", axis,
                             " is out of range for the operand rank");
  retShape.erase(retShape.begin() + axis);

  if (retShape.empty()) {
    // 0d-tensor -> scalar
    inferredReturnTypes.push_back(retEltTy);
    return mlir::success();
  }

  // nd-tensor where n >= 1
  // infer encoding
  Attribute argEncoding = argTy.getEncoding();
  Attribute retEncoding;
  if (argEncoding) {
    Dialect &dialect = argEncoding.getDialect();
    auto *inferLayoutInterface =
        dyn_cast<DialectInferLayoutInterface>(&dialect);
    // dyn_cast yields null when the dialect does not implement the interface.
    if (!inferLayoutInterface)
      return emitOptionalError(location,
                               "encoding dialect does not implement "
                               "DialectInferLayoutInterface for ReduceOp");
    // Report failure through the normal diagnostic path (as ExpandDimsOp
    // does) rather than aborting the whole process.
    if (inferLayoutInterface
            ->inferReduceOpEncoding(argEncoding, axis, retEncoding)
            .failed())
      return emitOptionalError(location,
                               "failed to infer layout for ReduceOp");
  }

  // create type
  inferredReturnTypes.push_back(
      RankedTensorType::get(retShape, retEltTy, retEncoding));
  return mlir::success();
}
/// Returns true when `redOp` is one of the ARG* reduction kinds (ARGMIN,
/// ARGMAX, ARGUMIN, ARGUMAX, ARGFMIN, ARGFMAX) and false for every other kind.
bool mlir::triton::ReduceOp::withIndex(mlir::triton::RedOp redOp) {
  switch (redOp) {
  case mlir::triton::RedOp::ARGMIN:
  case mlir::triton::RedOp::ARGMAX:
  case mlir::triton::RedOp::ARGUMIN:
  case mlir::triton::RedOp::ARGUMAX:
  case mlir::triton::RedOp::ARGFMIN:
  case mlir::triton::RedOp::ARGFMAX:
    return true;
  default:
    return false;
  }
}
//-- SplatOp --
/// Folds a splat whose source is defined by an arith.constant into a splat
/// elements attribute of the result type holding that constant's value.
/// Returns an empty result (no fold) when the source is not defined by an
/// arith.constant. The `operands` attributes are not consulted.
OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
  if (auto constOp = src().getDefiningOp<arith::ConstantOp>()) {
    auto resultTy = getType().cast<ShapedType>();
    return SplatElementsAttr::get(resultTy, {constOp.getValue()});
  }
  return {};
}
//-- ExpandDimsOp --
/// Infers the result type of ExpandDimsOp.
///
/// The result is a ranked tensor with the operand's element type and the
/// operand's shape with a unit dimension inserted at position `axis`. When the
/// operand has an encoding, the result encoding is inferred through the
/// DialectInferLayoutInterface of the encoding's dialect.
///
/// Returns failure (emitting a diagnostic when `loc` is set) if `axis` is
/// outside [0, rank], if the encoding's dialect does not implement
/// DialectInferLayoutInterface, or if layout inference fails. Operand 0 must
/// be a ranked tensor and the "axis" attribute must be present (the casts
/// below assert otherwise).
mlir::LogicalResult mlir::triton::ExpandDimsOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> loc, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // infer shape
  auto arg = operands[0];
  auto argTy = arg.getType().cast<RankedTensorType>();
  auto retShape = argTy.getShape().vec();
  int axis = attributes.get("axis").cast<IntegerAttr>().getInt();

  // Inserting at an out-of-range position is undefined behavior, so reject a
  // bad axis up front. `axis == rank` is valid and appends a trailing
  // dimension.
  if (axis < 0 || axis > static_cast<int>(retShape.size()))
    return emitOptionalError(loc, "ExpandDimsOp axis ", axis,
                             " is out of range for the operand rank");
  retShape.insert(retShape.begin() + axis, 1);

  // infer encoding
  Attribute argEncoding = argTy.getEncoding();
  Attribute retEncoding;
  if (argEncoding) {
    Dialect &dialect = argEncoding.getDialect();
    auto *inferLayoutInterface =
        dyn_cast<DialectInferLayoutInterface>(&dialect);
    // dyn_cast yields null when the dialect does not implement the interface.
    if (!inferLayoutInterface)
      return emitOptionalError(loc,
                               "encoding dialect does not implement "
                               "DialectInferLayoutInterface for ExpandDimsOp");
    if (inferLayoutInterface
            ->inferExpandDimsOpEncoding(argEncoding, axis, retEncoding, loc)
            .failed())
      return emitOptionalError(loc, "failed to infer layout for ExpandDimsOp");
  }

  // create type
  auto argEltTy = argTy.getElementType();
  inferredReturnTypes.push_back(
      RankedTensorType::get(retShape, argEltTy, retEncoding));
  return mlir::success();
}
//-- BroadcastOp --
/// Folds a broadcast whose source is defined by an arith.constant.
///
/// - If the constant is a dense elements attribute that is a splat, the result
///   is a splat of the same value with the broadcast's result type.
/// - If the constant's type is an integer, index or float, the result is a
///   splat of that value with the broadcast's result type.
/// - In every other case (including a non-splat dense constant or a source
///   that is not an arith.constant) nothing is folded.
///
/// The `operands` attributes are not consulted.
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  auto constOp = src().getDefiningOp<arith::ConstantOp>();
  if (!constOp)
    return {};

  auto resultTy = getType().cast<ShapedType>();
  auto constAttr = constOp.getValue();

  if (auto dense = constAttr.dyn_cast<DenseElementsAttr>()) {
    if (dense.isSplat())
      return SplatElementsAttr::get(resultTy,
                                    dense.getSplatValue<Attribute>());
    return {};
  }

  if (constAttr.getType().isIntOrIndexOrFloat())
    return SplatElementsAttr::get(resultTy, constAttr);

  return {};
}
} // namespace triton
} // namespace mlir