special encoding for broadcast
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include <algorithm>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton::gpu;
|
||||
|
||||
//
|
||||
// TypeConverter
|
||||
@@ -41,6 +43,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
llvm::SmallVector<unsigned> warpTileSize(rank, 1);
|
||||
llvm::SmallVector<unsigned> blockTileSize(rank);
|
||||
llvm::SmallVector<unsigned> order(rank);
|
||||
llvm::SmallVector<unsigned> broadcastAxis;
|
||||
int remainingThreads = numThreads;
|
||||
int remainingLanes = /*warp size*/32;
|
||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||
@@ -53,7 +56,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
// TODO: will we need repetition?
|
||||
}
|
||||
Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
|
||||
context, threadTileSize, warpTileSize, blockTileSize, order);
|
||||
context, threadTileSize, warpTileSize, blockTileSize, order, broadcastAxis);
|
||||
return RankedTensorType::get(shape, elementType, encoding);
|
||||
});
|
||||
|
||||
@@ -81,7 +84,6 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
// NOTE: only for remapped values.
|
||||
addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
|
||||
ValueRange inputs, Location loc) {
|
||||
assert(inputs.size() == 1);
|
||||
llvm_unreachable("Not implemented");
|
||||
return llvm::None;
|
||||
});
|
||||
@@ -124,3 +126,98 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
// %dst = tt.broadcast %src
|
||||
// =>
|
||||
// %newSrc = convert_layout %src
|
||||
// %bcst = tt.broadcast %newSrc
|
||||
// %dst = convert_layout %bcst
|
||||
LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
|
||||
int numThreads) {
|
||||
// collect broadcasts
|
||||
SmallVector<triton::BroadcastOp> broadcasts;
|
||||
mod.walk([&](triton::BroadcastOp op) {
|
||||
broadcasts.push_back(op);
|
||||
});
|
||||
|
||||
BlockAndValueMapping mapping;
|
||||
for (auto broadcast : broadcasts) {
|
||||
OpBuilder builder(broadcast);
|
||||
Value src = mapping.lookupOrDefault(broadcast.src());
|
||||
Type originSrcType = src.getType();
|
||||
Type originDstType = broadcast.getType();
|
||||
auto originDstTensorType = originDstType.dyn_cast<RankedTensorType>();
|
||||
unsigned dstRank = originDstTensorType.getRank();
|
||||
|
||||
// compute newSrcType & broadcastAxis
|
||||
Type newSrcType;
|
||||
SmallVector<unsigned> broadcastAxis;
|
||||
bool isSrcScalar = false;
|
||||
if (auto tensorType = originSrcType.dyn_cast<RankedTensorType>()) {
|
||||
assert(tensorType.getRank() == dstRank &&
|
||||
"src & dst should have same rank (verifier should catch this)");
|
||||
for (unsigned ax = 0; ax < dstRank; ++ax)
|
||||
if (tensorType.getShape()[ax] < originDstTensorType.getShape()[ax])
|
||||
broadcastAxis.push_back(ax);
|
||||
|
||||
Attribute originSrcEnc = tensorType.getEncoding();
|
||||
if (auto blockedEnc = originSrcEnc.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
|
||||
auto newSrcEnc = TritonGPUBlockedMulticastEncodingAttr::get(
|
||||
blockedEnc.getContext(),
|
||||
blockedEnc.getThreadTileSize(),
|
||||
blockedEnc.getWarpTileSize(),
|
||||
blockedEnc.getBlockTileSize(),
|
||||
blockedEnc.getOrder(),
|
||||
broadcastAxis
|
||||
);
|
||||
newSrcType = RankedTensorType::get(
|
||||
tensorType.getShape(),
|
||||
tensorType.getElementType(),
|
||||
newSrcEnc
|
||||
);
|
||||
} else
|
||||
llvm_unreachable("src of broadcast should have blocked encoding");
|
||||
} else {
|
||||
for (unsigned ax = 0; ax < dstRank; ++ax)
|
||||
broadcastAxis.push_back(ax);
|
||||
newSrcType = originSrcType;
|
||||
isSrcScalar = true;
|
||||
}
|
||||
|
||||
// create new src
|
||||
if (!isSrcScalar) // we don't need to convert layout for scalar values
|
||||
src = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
src.getLoc(), newSrcType, src
|
||||
);
|
||||
|
||||
// create new broadcast
|
||||
// compute new type (encoding)
|
||||
auto originDstEnc = originDstTensorType.getEncoding()
|
||||
.dyn_cast<TritonGPUBlockedEncodingAttr>();
|
||||
auto newEnc = TritonGPUBlockedEncodingAttr::get(
|
||||
originDstEnc.getContext(),
|
||||
originDstEnc.getThreadTileSize(),
|
||||
originDstEnc.getWarpTileSize(),
|
||||
originDstEnc.getBlockTileSize(),
|
||||
originDstEnc.getOrder(),
|
||||
broadcastAxis
|
||||
);
|
||||
auto newType = RankedTensorType::get(
|
||||
originDstTensorType.getShape(),
|
||||
originDstTensorType.getElementType(),
|
||||
newEnc
|
||||
);
|
||||
Value newBroadcast = builder.create<triton::BroadcastOp>(
|
||||
broadcast.getLoc(), newType, src
|
||||
);
|
||||
// we don't want to change the encoding of the result
|
||||
Value newDst = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
broadcast.getLoc(), originDstType, newBroadcast
|
||||
);
|
||||
|
||||
broadcast.replaceAllUsesWith(newDst);
|
||||
mapping.map(broadcast, newDst);
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
Reference in New Issue
Block a user