special encoding for broadcast

Author: Yan Da
Date:   2022-06-18 21:16:45 +08:00
parent 53cf93ce6a
commit 9d1b5e3f79
6 changed files with 248 additions and 72 deletions


@@ -10,7 +10,7 @@ using namespace mlir::triton::gpu;
 // parse an array of integers
 static LogicalResult parseIntArrayAttr(AsmParser &parser,
                                        const NamedAttribute &attr,
-                                       SmallVector<unsigned, 2> &res,
+                                       /*SmallVector<unsigned, 2>*/auto &res,
                                        StringRef desc) {
   auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
   if (!arrayAttr) {
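The only change in this hunk is that the result parameter becomes `auto &res` (a C++20 abbreviated function template), so one helper definition can populate integer vectors of different SmallVector types. A minimal sketch of a call site inside the `for (const NamedAttribute &attr : dict)` loop of one of the parsers below; the attribute names and descriptions are illustrative, since the real call sites are outside the hunks shown here:

  SmallVector<unsigned, 2> threadTileSize;
  SmallVector<unsigned> broadcastAxis; // a different SmallVector type also binds to `auto &res`
  if (parseIntArrayAttr(parser, attr, threadTileSize, "thread tile size").failed())
    return {};
  if (parseIntArrayAttr(parser, attr, broadcastAxis, "broadcast axis").failed())
    return {};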
@@ -36,8 +36,7 @@ static LogicalResult parseIntArrayAttr(AsmParser &parser,
 #define GET_ATTRDEF_CLASSES
 #include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
 
-Attribute
-TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
+static Attribute parseBlocked(AsmParser &parser, Type type) {
   if (parser.parseLess().failed())
     return {};
   // Parse the data as a dictionary
@@ -51,28 +50,7 @@ TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
   SmallVector<unsigned, 2> warpTileSize;
   SmallVector<unsigned, 2> blockTileSize;
   SmallVector<unsigned, 2> order;
-  // parse an array of integers
-  // auto parseIntArrayAttr = [&parser](const NamedAttribute &attr,
-  //                                    SmallVector<unsigned, 2> &res,
-  //                                    StringRef desc) -> LogicalResult {
-  //   auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
-  //   if (!arrayAttr) {
-  //     parser.emitError(parser.getNameLoc(), "expected an array for ")
-  //       << desc;
-  //     return failure();
-  //   }
-  //   for (Attribute i : arrayAttr) {
-  //     auto intAttr = i.dyn_cast<IntegerAttr>();
-  //     if (!intAttr) {
-  //       parser.emitError(parser.getNameLoc(), "expected an integer value in ")
-  //         << desc;
-  //       return failure();
-  //     }
-  //     res.push_back(intAttr.getUInt());
-  //   }
-  //   return success();
-  // };
+  SmallVector<unsigned, 2> broadcastAxis;
 
   for (const NamedAttribute &attr : dict) {
     if (attr.getName() == "threadTileSize") {
@@ -98,20 +76,39 @@ TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
                                            threadTileSize,
                                            warpTileSize,
                                            blockTileSize,
-                                           order);
+                                           order,
+                                           broadcastAxis);
 }
 
-void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
+static void printBlocked(AsmPrinter &printer, auto *attr) {
   printer << "<{"
-          << "threadTileSize = [" << getThreadTileSize() << "]"
-          << ", warpTileSize = [" << getWarpTileSize() << "]"
-          << ", blockTileSize = [" << getBlockTileSize() << "]"
-          << ", order = [" << getOrder() << "]"
+          << "threadTileSize = [" << attr->getThreadTileSize() << "]"
+          << ", warpTileSize = [" << attr->getWarpTileSize() << "]"
+          << ", blockTileSize = [" << attr->getBlockTileSize() << "]"
+          << ", order = [" << attr->getOrder() << "]"
+          << ", broadcastAxis = [" << attr->getBroadcastAxis() << "]"
           << "}>";
 }
 
 Attribute
-TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
+TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
+  return parseBlocked(parser, type);
+}
+
+void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
+  printBlocked(printer, this);
+}
+
+Attribute
+TritonGPUBlockedMulticastEncodingAttr::parse(AsmParser &parser, Type type) {
+  return parseBlocked(parser, type);
+}
+
+void TritonGPUBlockedMulticastEncodingAttr::print(AsmPrinter &printer) const {
+  printBlocked(printer, this);
+}
+
+static Attribute parseMma(AsmParser &parser, Type type) {
   if (parser.parseLess().failed())
     return {};
   DictionaryAttr dict;
@@ -126,6 +123,7 @@ TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
   SmallVector<unsigned, 2> shapePerTile;
   SmallVector<unsigned, 2> repetitions;
   SmallVector<unsigned, 2> contigPerThread;
+  SmallVector<unsigned, 2> broadcastAxis;
 
   for (const NamedAttribute &attr : dict) {
     if (attr.getName() == "fragmentPerWarp") {
@@ -159,18 +157,37 @@ TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
                                        warpPerTile,
                                        shapePerTile,
                                        repetitions,
-                                       contigPerThread);
+                                       contigPerThread,
+                                       broadcastAxis);
 }
 
+static void printMma(AsmPrinter &printer, auto *attr) {
+  printer << "<{"
+          << "fragmentPerWarp = [" << attr->getFragmentPerWarp() << "]"
+          << ", shapePerWarp = [" << attr->getShapePerWarp() << "]"
+          << ", warpPerTile = [" << attr->getWarpPerTile() << "]"
+          << ", shapePerTile = [" << attr->getShapePerTile() << "]"
+          << ", repetitions = [" << attr->getRepetitions() << "]"
+          << ", contigPerThread = [" << attr->getContigPerThread() << "]"
+          << "}>";
+}
+
+Attribute
+TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
+  return parseMma(parser, type);
+}
+
 void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
-  printer << "<{"
-          << "fragmentPerWarp = [" << getFragmentPerWarp() << "]"
-          << ", shapePerWarp = [" << getShapePerWarp() << "]"
-          << ", warpPerTile = [" << getWarpPerTile() << "]"
-          << ", shapePerTile = [" << getShapePerTile() << "]"
-          << ", repetitions = [" << getRepetitions() << "]"
-          << ", contigPerThread = [" << getContigPerThread() << "]"
-          << "}>";
+  printMma(printer, this);
 }
 
+Attribute
+TritonGPUMmaMulticastEncodingAttr::parse(AsmParser &parser, Type type) {
+  return parseMma(parser, type);
+}
+
+void TritonGPUMmaMulticastEncodingAttr::print(AsmPrinter &printer) const {
+  printMma(printer, this);
+}
+
 Attribute
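Taken together, the first file now routes both the plain and the multicast variant of each encoding through shared parseBlocked/printBlocked and parseMma/printMma helpers, and the blocked printer emits the new broadcastAxis field. A rough sketch of the printed attribute body (values are made up; the attribute mnemonics live in TritonGPUAttrDefs.td, which is not part of the hunks shown here):

  // printBlocked on an encoding that does not broadcast emits something like
  //   <{threadTileSize = [1, 4], warpTileSize = [4, 8], blockTileSize = [8, 8],
  //     order = [1, 0], broadcastAxis = []}>
  // whereas a multicast encoding produced by refineLayouts (second file below)
  // would carry the expanded dimensions, e.g. broadcastAxis = [1].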


@@ -1,9 +1,11 @@
 #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include <algorithm>
 
 using namespace mlir;
+using namespace mlir::triton::gpu;
 
 //
 // TypeConverter
@@ -41,6 +43,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
     llvm::SmallVector<unsigned> warpTileSize(rank, 1);
     llvm::SmallVector<unsigned> blockTileSize(rank);
     llvm::SmallVector<unsigned> order(rank);
+    llvm::SmallVector<unsigned> broadcastAxis;
     int remainingThreads = numThreads;
     int remainingLanes = /*warp size*/32;
     for (int64_t dim = 0; dim < rank; ++dim) {
@@ -53,7 +56,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
       // TODO: will we need repetition?
     }
 
     Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
-        context, threadTileSize, warpTileSize, blockTileSize, order);
+        context, threadTileSize, warpTileSize, blockTileSize, order, broadcastAxis);
     return RankedTensorType::get(shape, elementType, encoding);
   });
@@ -81,7 +84,6 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
   // NOTE: only for remapped values.
   addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
                                ValueRange inputs, Location loc) {
-    assert(inputs.size() == 1);
     llvm_unreachable("Not implemented");
     return llvm::None;
   });
@@ -124,3 +126,98 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
   });
 }
+
+// %dst = tt.broadcast %src
+//   =>
+// %newSrc = convert_layout %src
+// %bcst = tt.broadcast %newSrc
+// %dst = convert_layout %bcst
+LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
+                                                       int numThreads) {
+  // collect broadcasts
+  SmallVector<triton::BroadcastOp> broadcasts;
+  mod.walk([&](triton::BroadcastOp op) {
+    broadcasts.push_back(op);
+  });
+
+  BlockAndValueMapping mapping;
+  for (auto broadcast : broadcasts) {
+    OpBuilder builder(broadcast);
+    Value src = mapping.lookupOrDefault(broadcast.src());
+    Type originSrcType = src.getType();
+    Type originDstType = broadcast.getType();
+    auto originDstTensorType = originDstType.dyn_cast<RankedTensorType>();
+    unsigned dstRank = originDstTensorType.getRank();
+
+    // compute newSrcType & broadcastAxis
+    Type newSrcType;
+    SmallVector<unsigned> broadcastAxis;
+    bool isSrcScalar = false;
+    if (auto tensorType = originSrcType.dyn_cast<RankedTensorType>()) {
+      assert(tensorType.getRank() == dstRank &&
+             "src & dst should have same rank (verifier should catch this)");
+      for (unsigned ax = 0; ax < dstRank; ++ax)
+        if (tensorType.getShape()[ax] < originDstTensorType.getShape()[ax])
+          broadcastAxis.push_back(ax);
+
+      Attribute originSrcEnc = tensorType.getEncoding();
+      if (auto blockedEnc = originSrcEnc.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
+        auto newSrcEnc = TritonGPUBlockedMulticastEncodingAttr::get(
+          blockedEnc.getContext(),
+          blockedEnc.getThreadTileSize(),
+          blockedEnc.getWarpTileSize(),
+          blockedEnc.getBlockTileSize(),
+          blockedEnc.getOrder(),
+          broadcastAxis
+        );
+        newSrcType = RankedTensorType::get(
+          tensorType.getShape(),
+          tensorType.getElementType(),
+          newSrcEnc
+        );
+      } else
+        llvm_unreachable("src of broadcast should have blocked encoding");
+    } else {
+      for (unsigned ax = 0; ax < dstRank; ++ax)
+        broadcastAxis.push_back(ax);
+      newSrcType = originSrcType;
+      isSrcScalar = true;
+    }
+
+    // create new src
+    if (!isSrcScalar) // we don't need to convert layout for scalar values
+      src = builder.create<triton::gpu::ConvertLayoutOp>(
+        src.getLoc(), newSrcType, src
+      );
+
+    // create new broadcast
+    // compute new type (encoding)
+    auto originDstEnc = originDstTensorType.getEncoding()
+                          .dyn_cast<TritonGPUBlockedEncodingAttr>();
+    auto newEnc = TritonGPUBlockedEncodingAttr::get(
+      originDstEnc.getContext(),
+      originDstEnc.getThreadTileSize(),
+      originDstEnc.getWarpTileSize(),
+      originDstEnc.getBlockTileSize(),
+      originDstEnc.getOrder(),
+      broadcastAxis
+    );
+    auto newType = RankedTensorType::get(
+      originDstTensorType.getShape(),
+      originDstTensorType.getElementType(),
+      newEnc
+    );
+    Value newBroadcast = builder.create<triton::BroadcastOp>(
+      broadcast.getLoc(), newType, src
+    );
+    // we don't want to change the encoding of the result
+    Value newDst = builder.create<triton::gpu::ConvertLayoutOp>(
+      broadcast.getLoc(), originDstType, newBroadcast
+    );
+    broadcast.replaceAllUsesWith(newDst);
+    mapping.map(broadcast, newDst);
+  }
+  return success();
+}
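To make the comment at the top of refineLayouts concrete, here is a worked sketch of the rewrite for one broadcast; the shapes, encoding names, and op spellings are illustrative rather than copied from the dialect definitions:

  // before:
  //   %dst = tt.broadcast %src : tensor<128x1xf32, #blocked> -> tensor<128x128xf32, #blocked>
  //
  // after refineLayouts (dim 1 is the expanded one, so broadcastAxis = [1]):
  //   %newSrc = convert_layout %src  : tensor<128x1xf32, #blocked> -> tensor<128x1xf32, #blocked_multicast>
  //   %bcst   = tt.broadcast %newSrc : tensor<128x1xf32, #blocked_multicast> -> tensor<128x128xf32, #blocked_bcast>
  //   %dst    = convert_layout %bcst : tensor<128x128xf32, #blocked_bcast> -> tensor<128x128xf32, #blocked>
  //
  // #blocked_multicast reuses the source's blocked parameters plus broadcastAxis,
  // #blocked_bcast is the destination's blocked encoding with broadcastAxis set, and
  // the final convert_layout restores the original result type so users of %dst never
  // see a new encoding. A scalar source skips the first convert_layout and marks every
  // dimension, e.g. broadcastAxis = [0, 1] for a rank-2 result.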