[TritonGPU] Improved documentation and semantics of layout encodings (#30)
This commit is contained in:
@@ -30,13 +30,28 @@ static LogicalResult parseIntArrayAttr(AsmParser &parser,
|
||||
return success();
|
||||
};
|
||||
|
||||
/// Reads a non-negative integer from the dictionary entry `attr` into `value`.
///
/// Emits a parser diagnostic that mentions `desc` and returns failure() when
/// the entry is not an IntegerAttr, holds a negative number, or does not fit
/// in `unsigned`. `value` is only written on success.
static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr,
                               unsigned &value, StringRef desc) {
  auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
  if (!intAttr) {
    parser.emitError(parser.getNameLoc(), "expected an integer ") << desc;
    return failure();
  }

  // IntegerAttr::getUInt() asserts that the attribute type is explicitly
  // unsigned, while integers written as `key = 4` in a dictionary are
  // signless. Inspect the raw APInt so that every signedness is accepted.
  const llvm::APInt raw = intAttr.getValue();
  const bool isUnsignedType = intAttr.getType().isUnsignedInteger();
  if (!isUnsignedType && raw.isNegative()) {
    parser.emitError(parser.getNameLoc(), "expected a non-negative integer ")
        << desc;
    return failure();
  }
  if (raw.getActiveBits() > sizeof(unsigned) * 8) {
    parser.emitError(parser.getNameLoc(), "integer out of range for ") << desc;
    return failure();
  }

  value = static_cast<unsigned>(raw.getZExtValue());
  return success();
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Attribute methods
|
||||
//===----------------------------------------------------------------------===//
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
|
||||
|
||||
static Attribute parseBlocked(AsmParser &parser, Type type) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Blocked Encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
if (parser.parseLess().failed())
|
||||
return {};
|
||||
// Parse the data as a dictionary
|
||||
@@ -46,32 +61,30 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
|
||||
if (parser.parseGreater().failed())
|
||||
return {};
|
||||
|
||||
SmallVector<unsigned, 2> threadTileSize;
|
||||
SmallVector<unsigned, 2> warpTileSize;
|
||||
SmallVector<unsigned, 2> blockTileSize;
|
||||
SmallVector<unsigned, 2> sizePerThread;
|
||||
SmallVector<unsigned, 2> threadsPerWarp;
|
||||
SmallVector<unsigned, 2> warpsPerCTA;
|
||||
SmallVector<unsigned, 2> order;
|
||||
SmallVector<unsigned, 2> broadcastAxis;
|
||||
|
||||
for (const NamedAttribute &attr : dict) {
|
||||
if (attr.getName() == "threadTileSize") {
|
||||
if (parseIntArrayAttr(parser, attr, threadTileSize, "thread tile size")
|
||||
if (attr.getName() == "sizePerThread") {
|
||||
if (parseIntArrayAttr(parser, attr, sizePerThread,
|
||||
"number of elements per thread")
|
||||
.failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "warpTileSize") {
|
||||
if (parseIntArrayAttr(parser, attr, warpTileSize, "warp tile size")
|
||||
} else if (attr.getName() == "threadsPerWarp") {
|
||||
if (parseIntArrayAttr(parser, attr, threadsPerWarp,
|
||||
"number of threads per warp")
|
||||
.failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "blockTileSize") {
|
||||
if (parseIntArrayAttr(parser, attr, blockTileSize, "block tile size")
|
||||
} else if (attr.getName() == "warpsPerCTA") {
|
||||
if (parseIntArrayAttr(parser, attr, warpsPerCTA,
|
||||
"number of warps per CTA")
|
||||
.failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "order") {
|
||||
if (parseIntArrayAttr(parser, attr, order, "order").failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "broadcastAxis") {
|
||||
if (parseIntArrayAttr(parser, attr, broadcastAxis, "broadcastAxis")
|
||||
.failed())
|
||||
return {};
|
||||
} else {
|
||||
parser.emitError(parser.getNameLoc(), "unexpected key: ")
|
||||
<< attr.getName().strref();
|
||||
@@ -80,39 +93,23 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
|
||||
}
|
||||
|
||||
return parser.getChecked<TritonGPUBlockedEncodingAttr>(
|
||||
parser.getContext(), threadTileSize, warpTileSize, blockTileSize, order,
|
||||
broadcastAxis);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
static void printBlocked(AsmPrinter &printer, const T *attr) {
|
||||
printer << "<{"
|
||||
<< "threadTileSize = [" << attr->getThreadTileSize() << "]"
|
||||
<< ", warpTileSize = [" << attr->getWarpTileSize() << "]"
|
||||
<< ", blockTileSize = [" << attr->getBlockTileSize() << "]"
|
||||
<< ", order = [" << attr->getOrder() << "]"
|
||||
<< ", broadcastAxis = [" << attr->getBroadcastAxis() << "]"
|
||||
<< "}>";
|
||||
}
|
||||
|
||||
Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
return parseBlocked(parser, type);
|
||||
parser.getContext(), sizePerThread, threadsPerWarp, warpsPerCTA, order);
|
||||
}
|
||||
|
||||
void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
printBlocked(printer, this);
|
||||
printer << "<{"
|
||||
<< "sizePerThread = [" << getSizePerThread() << "]"
|
||||
<< ", threadsPerWarp = [" << getThreadsPerWarp() << "]"
|
||||
<< ", warpsPerCTA = [" << getWarpsPerCTA() << "]"
|
||||
<< ", order = [" << getOrder() << "]"
|
||||
<< "}>";
|
||||
}
|
||||
|
||||
/// Parses this attribute by delegating to parseBlocked and returning its
/// result unchanged (a null Attribute signals a parse failure).
Attribute TritonGPUBlockedMulticastEncodingAttr::parse(AsmParser &parser,
                                                       Type type) {
  Attribute parsed = parseBlocked(parser, type);
  return parsed;
}
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MMA encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Prints this attribute through the printBlocked helper.
void TritonGPUBlockedMulticastEncodingAttr::print(AsmPrinter &printer) const {
  const auto *self = this;
  printBlocked(printer, self);
}
|
||||
|
||||
static Attribute parseMma(AsmParser &parser, Type type) {
|
||||
Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
if (parser.parseLess().failed())
|
||||
return {};
|
||||
DictionaryAttr dict;
|
||||
@@ -121,76 +118,34 @@ static Attribute parseMma(AsmParser &parser, Type type) {
|
||||
if (parser.parseGreater().failed())
|
||||
return {};
|
||||
|
||||
SmallVector<unsigned, 2> fragmentPerWarp;
|
||||
SmallVector<unsigned, 2> shapePerWarp;
|
||||
SmallVector<unsigned, 2> warpPerTile;
|
||||
SmallVector<unsigned, 2> shapePerTile;
|
||||
SmallVector<unsigned, 2> repetitions;
|
||||
SmallVector<unsigned, 2> contigPerThread;
|
||||
SmallVector<unsigned, 2> broadcastAxis;
|
||||
unsigned version = 0;
|
||||
SmallVector<unsigned, 2> warpsPerCTA;
|
||||
|
||||
for (const NamedAttribute &attr : dict) {
|
||||
if (attr.getName() == "fragmentPerWarp") {
|
||||
if (parseIntArrayAttr(parser, attr, fragmentPerWarp, "fragmentPerWarp")
|
||||
.failed())
|
||||
if (attr.getName() == "version") {
|
||||
if (parseUInt(parser, attr, version, "version").failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "shapePerWarp") {
|
||||
if (parseIntArrayAttr(parser, attr, shapePerWarp, "shapePerWarp")
|
||||
.failed())
|
||||
}
|
||||
if (attr.getName() == "warpsPerCTA") {
|
||||
if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "warpPerTile") {
|
||||
if (parseIntArrayAttr(parser, attr, warpPerTile, "warpPerTile").failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "shapePerTile") {
|
||||
if (parseIntArrayAttr(parser, attr, shapePerTile, "shapePerTile")
|
||||
.failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "repetitions") {
|
||||
if (parseIntArrayAttr(parser, attr, repetitions, "repetitions").failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "contigPerThread") {
|
||||
if (parseIntArrayAttr(parser, attr, contigPerThread, "contigPerThread")
|
||||
.failed())
|
||||
return {};
|
||||
} else {
|
||||
parser.emitError(parser.getNameLoc(), "unexpected key: ")
|
||||
<< attr.getName().strref();
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
return parser.getChecked<TritonGPUMmaEncodingAttr>(
|
||||
parser.getContext(), fragmentPerWarp, shapePerWarp, warpPerTile,
|
||||
shapePerTile, repetitions, contigPerThread, broadcastAxis);
|
||||
}
|
||||
|
||||
/// Streams the MMA-layout parameters of `attr` to `printer` in the form
/// `<{fragmentPerWarp = [..], shapePerWarp = [..], warpPerTile = [..],
/// shapePerTile = [..], repetitions = [..], contigPerThread = [..]}>`.
///
/// `T` may be any attribute class that provides the six accessors used here.
template <class T> static void printMma(AsmPrinter &printer, T *attr) {
  printer << "<{";
  printer << "fragmentPerWarp = [" << attr->getFragmentPerWarp() << "]";
  printer << ", shapePerWarp = [" << attr->getShapePerWarp() << "]";
  printer << ", warpPerTile = [" << attr->getWarpPerTile() << "]";
  printer << ", shapePerTile = [" << attr->getShapePerTile() << "]";
  printer << ", repetitions = [" << attr->getRepetitions() << "]";
  printer << ", contigPerThread = [" << attr->getContigPerThread() << "]";
  printer << "}>";
}
|
||||
|
||||
Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
return parseMma(parser, type);
|
||||
return parser.getChecked<TritonGPUMmaEncodingAttr>(parser.getContext(),
|
||||
version, warpsPerCTA);
|
||||
}
|
||||
|
||||
void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
|
||||
printMma(printer, this);
|
||||
printer << "<{"
|
||||
<< "version = " << getVersion() << ", "
|
||||
<< "warpsPerCTA = [" << getWarpsPerCTA() << "]"
|
||||
<< "}>";
|
||||
}
|
||||
|
||||
/// Parses this attribute by delegating to parseMma and returning its result
/// unchanged (a null Attribute signals a parse failure).
Attribute TritonGPUMmaMulticastEncodingAttr::parse(AsmParser &parser,
                                                   Type type) {
  Attribute parsed = parseMma(parser, type);
  return parsed;
}
|
||||
|
||||
/// Prints this attribute through the printMma helper.
void TritonGPUMmaMulticastEncodingAttr::print(AsmPrinter &printer) const {
  const auto *self = this;
  printMma(printer, self);
}
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Shared encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
if (parser.parseLess().failed())
|
||||
@@ -207,26 +162,15 @@ Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
unsigned maxPhase = 0;
|
||||
SmallVector<unsigned, 2> order;
|
||||
|
||||
auto parseUInt = [&parser](const NamedAttribute &attr, unsigned &value,
|
||||
StringRef desc) -> LogicalResult {
|
||||
auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
|
||||
if (!intAttr) {
|
||||
parser.emitError(parser.getNameLoc(), "expected an integer ") << desc;
|
||||
return failure();
|
||||
}
|
||||
value = intAttr.getUInt();
|
||||
return success();
|
||||
};
|
||||
|
||||
for (const NamedAttribute &attr : dict) {
|
||||
if (attr.getName() == "vec") {
|
||||
if (parseUInt(attr, vec, "vec").failed())
|
||||
if (parseUInt(parser, attr, vec, "vec").failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "perPhase") {
|
||||
if (parseUInt(attr, perPhase, "perPhase").failed())
|
||||
if (parseUInt(parser, attr, perPhase, "perPhase").failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "maxPhase") {
|
||||
if (parseUInt(attr, maxPhase, "maxPhase").failed())
|
||||
if (parseUInt(parser, attr, maxPhase, "maxPhase").failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "order") {
|
||||
if (parseIntArrayAttr(parser, attr, order, "order").failed())
|
||||
@@ -250,6 +194,10 @@ void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const {
|
||||
<< "}>";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ASM Interface (i.e.: alias)
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
|
||||
public:
|
||||
using OpAsmDialectInterface::OpAsmDialectInterface;
|
||||
@@ -257,72 +205,18 @@ public:
|
||||
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
|
||||
if (auto mmaAttr = attr.dyn_cast<TritonGPUMmaEncodingAttr>()) {
|
||||
os << "mma";
|
||||
TritonGPUOpAsmInterface::printMma(mmaAttr, os);
|
||||
return AliasResult::FinalAlias;
|
||||
} else if (auto mmaMulticastAttr =
|
||||
attr.dyn_cast<TritonGPUMmaMulticastEncodingAttr>()) {
|
||||
os << "mma_multicast";
|
||||
TritonGPUOpAsmInterface::printMma(mmaAttr, os);
|
||||
return AliasResult::FinalAlias;
|
||||
} else if (auto sharedAttr = attr.dyn_cast<TritonGPUSharedEncodingAttr>()) {
|
||||
os << "shared";
|
||||
TritonGPUOpAsmInterface::printShared(sharedAttr, os);
|
||||
return AliasResult::FinalAlias;
|
||||
} else if (auto blockedAttr =
|
||||
attr.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
|
||||
os << "blocked";
|
||||
TritonGPUOpAsmInterface::printBlocked(blockedAttr, os);
|
||||
return AliasResult::FinalAlias;
|
||||
} else if (auto blockedMulticastAttr =
|
||||
attr.dyn_cast<TritonGPUBlockedMulticastEncodingAttr>()) {
|
||||
os << "blocked_multicast";
|
||||
TritonGPUOpAsmInterface::printBlocked(blockedMulticastAttr, os);
|
||||
return AliasResult::FinalAlias;
|
||||
}
|
||||
OpAsmDialectInterface::getAlias(attr, os);
|
||||
return AliasResult::FinalAlias;
|
||||
}
|
||||
|
||||
private:
|
||||
static void printMma(const TritonGPUMmaEncodingAttr &attr, raw_ostream &os) {
|
||||
TritonGPUOpAsmInterface::printArray(attr.getFragmentPerWarp(), os);
|
||||
TritonGPUOpAsmInterface::printArray(attr.getShapePerWarp(), os);
|
||||
TritonGPUOpAsmInterface::printArray(attr.getWarpPerTile(), os);
|
||||
TritonGPUOpAsmInterface::printArray(attr.getShapePerTile(), os);
|
||||
TritonGPUOpAsmInterface::printArray(attr.getRepetitions(), os);
|
||||
TritonGPUOpAsmInterface::printArray(attr.getContigPerThread(), os);
|
||||
}
|
||||
|
||||
static void printShared(const TritonGPUSharedEncodingAttr &attr,
|
||||
raw_ostream &os) {
|
||||
os << "_" << attr.getVec();
|
||||
os << "_" << attr.getPerPhase();
|
||||
os << "_" << attr.getMaxPhase();
|
||||
TritonGPUOpAsmInterface::printArray(attr.getOrder(), os);
|
||||
}
|
||||
|
||||
template <class T> static void printBlocked(const T &attr, raw_ostream &os) {
|
||||
TritonGPUOpAsmInterface::printArray(attr.getThreadTileSize(), os);
|
||||
TritonGPUOpAsmInterface::printArray(attr.getWarpTileSize(), os);
|
||||
TritonGPUOpAsmInterface::printArray(attr.getBlockTileSize(), os);
|
||||
TritonGPUOpAsmInterface::printArray(attr.getOrder(), os);
|
||||
TritonGPUOpAsmInterface::printArray(attr.getBroadcastAxis(), os);
|
||||
}
|
||||
|
||||
static void printArray(const ArrayRef<unsigned> &array, raw_ostream &os,
|
||||
const std::string &delimiter = "x") {
|
||||
os << "_";
|
||||
if (array.empty()) {
|
||||
os << "none";
|
||||
return;
|
||||
}
|
||||
for (unsigned i = 0; i < array.size(); i++) {
|
||||
os << array[i];
|
||||
if (i != array.size() - 1) {
|
||||
os << delimiter;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void TritonGPUDialect::initialize() {
|
||||
|
@@ -3,6 +3,7 @@
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton::gpu;
|
||||
@@ -11,54 +12,26 @@ using namespace mlir::triton::gpu;
|
||||
// TypeConverter
|
||||
//
|
||||
TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
int numThreads)
|
||||
: context(context), numThreads(numThreads) {
|
||||
int numWarps)
|
||||
: context(context), numWarps(numWarps) {
|
||||
// TODO: how does MLIR pick the right conversion?
|
||||
addConversion([](Type type) { return type; });
|
||||
addConversion([this](RankedTensorType tensorType) -> RankedTensorType {
|
||||
MLIRContext *context = this->context;
|
||||
int numThreads = this->numThreads;
|
||||
|
||||
llvm::ArrayRef<int64_t> shape = tensorType.getShape();
|
||||
Type elementType = tensorType.getElementType();
|
||||
int64_t rank = tensorType.getRank();
|
||||
int64_t numElements = tensorType.getNumElements();
|
||||
|
||||
// TODO: are there any better ways to raise this error?
|
||||
if (!(numElements >= numThreads)) {
|
||||
SmallVector<char> buffer;
|
||||
llvm::raw_svector_ostream os(buffer);
|
||||
os << tensorType << " has " << numElements << " numElements "
|
||||
<< " smaller than numThreads (" << numThreads << ")\n"
|
||||
<< "consider using smaller num-warps\n";
|
||||
llvm::report_fatal_error(os.str());
|
||||
}
|
||||
assert(numElements % numThreads == 0);
|
||||
|
||||
// or assert no encoding?
|
||||
|
||||
// Now we assume:
|
||||
// contiguous = 1, order = 0, 1, 2, ...,
|
||||
llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
|
||||
llvm::SmallVector<unsigned> warpTileSize(rank, 1);
|
||||
llvm::SmallVector<unsigned> blockTileSize(rank);
|
||||
// types with encoding are already in the right format
|
||||
// TODO: check for layout encodings specifically
|
||||
if (tensorType.getEncoding())
|
||||
return tensorType;
|
||||
// pessimistic values for attributes:
|
||||
// - 1 element per thread
|
||||
// - order = arange(rank)
|
||||
ArrayRef<int64_t> shape = tensorType.getShape();
|
||||
int rank = shape.size();
|
||||
llvm::SmallVector<unsigned> order(rank);
|
||||
llvm::SmallVector<unsigned> broadcastAxis;
|
||||
int remainingThreads = numThreads;
|
||||
int remainingLanes = /*warp size*/ 32;
|
||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||
blockTileSize[dim] = std::clamp(remainingThreads, 1, int(shape[dim]));
|
||||
warpTileSize[dim] = std::clamp(remainingLanes, 1, int(shape[dim]));
|
||||
order[dim] = dim;
|
||||
|
||||
remainingThreads /= blockTileSize[dim];
|
||||
remainingLanes /= warpTileSize[dim];
|
||||
// TODO: will we need repetition?
|
||||
}
|
||||
std::iota(order.begin(), order.end(), 0);
|
||||
llvm::SmallVector<unsigned> sizePerThread(rank, 1);
|
||||
Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
|
||||
context, threadTileSize, warpTileSize, blockTileSize, order,
|
||||
broadcastAxis);
|
||||
return RankedTensorType::get(shape, elementType, encoding);
|
||||
this->context, shape, sizePerThread, order, this->numWarps);
|
||||
return RankedTensorType::get(shape, tensorType.getElementType(), encoding);
|
||||
});
|
||||
|
||||
//
|
||||
@@ -86,8 +59,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
// NOTE: only for remapped values.
|
||||
addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
|
||||
ValueRange inputs, Location loc) {
|
||||
llvm_unreachable("Not implemented");
|
||||
return llvm::None;
|
||||
auto cast =
|
||||
builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType, inputs);
|
||||
return Optional<Value>(cast.getResult());
|
||||
// return Optional<Value>(cast.getResult(0));
|
||||
// llvm_unreachable("Not implemented");
|
||||
// return llvm::None;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -122,87 +99,6 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
|
||||
aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
|
||||
bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
|
||||
return true;
|
||||
// // TODO: we should delete this
|
||||
// if (this->typeConverter.isLegal(dotOp))
|
||||
// return true;
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
// %dst = tt.broadcast %src
|
||||
// =>
|
||||
// %newSrc = convert_layout %src
|
||||
// %bcst = tt.broadcast %newSrc
|
||||
// %dst = convert_layout %bcst
|
||||
/// Rewrites every tt.broadcast in `mod` so that the broadcast axes are
/// recorded in the layout encodings around it:
///
///   %dst = tt.broadcast %src
/// becomes
///   %newSrc = convert_layout %src     (omitted when %src is not a tensor)
///   %bcst   = tt.broadcast %newSrc
///   %dst    = convert_layout %bcst    (restores the original result type)
///
/// `numThreads` is not consulted by this rewrite.
///
/// Returns failure(), after emitting a diagnostic on the offending op, when a
/// broadcast result is not a ranked tensor with a blocked encoding or when a
/// tensor source lacks a blocked encoding. Broadcasts processed before the
/// failing one remain rewritten.
LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
                                                       int numThreads) {
  (void)numThreads;

  // Collect the broadcasts up front; the loop below creates new BroadcastOps
  // that must not be processed again.
  SmallVector<triton::BroadcastOp> broadcasts;
  mod.walk([&](triton::BroadcastOp op) { broadcasts.push_back(op); });

  BlockAndValueMapping mapping;
  for (auto broadcast : broadcasts) {
    OpBuilder builder(broadcast);
    Value src = mapping.lookupOrDefault(broadcast.src());
    Type originSrcType = src.getType();
    Type originDstType = broadcast.getType();

    // Validate the destination before creating any op so that a failure does
    // not leave half-built IR for this broadcast behind.
    auto originDstTensorType = originDstType.dyn_cast<RankedTensorType>();
    if (!originDstTensorType)
      return broadcast.emitError(
          "result of broadcast should be a ranked tensor");
    auto originDstEnc =
        originDstTensorType.getEncoding()
            .dyn_cast_or_null<TritonGPUBlockedEncodingAttr>();
    if (!originDstEnc)
      return broadcast.emitError(
          "result of broadcast should have blocked encoding");
    unsigned dstRank = originDstTensorType.getRank();

    // Compute the broadcast axes and, for tensor sources, convert the source
    // to the multicast layout that carries them.
    SmallVector<unsigned> broadcastAxis;
    if (auto tensorType = originSrcType.dyn_cast<RankedTensorType>()) {
      assert(tensorType.getRank() == dstRank &&
             "src & dst should have same rank (verifier should catch this)");
      for (unsigned ax = 0; ax < dstRank; ++ax)
        if (tensorType.getShape()[ax] < originDstTensorType.getShape()[ax])
          broadcastAxis.push_back(ax);

      // The encoding may be absent, so use the null-tolerant cast.
      auto blockedEnc = tensorType.getEncoding()
                            .dyn_cast_or_null<TritonGPUBlockedEncodingAttr>();
      if (!blockedEnc)
        return broadcast.emitError(
            "src of broadcast should have blocked encoding");

      auto newSrcEnc = TritonGPUBlockedMulticastEncodingAttr::get(
          blockedEnc.getContext(), blockedEnc.getThreadTileSize(),
          blockedEnc.getWarpTileSize(), blockedEnc.getBlockTileSize(),
          blockedEnc.getOrder(), broadcastAxis);
      Type newSrcType = RankedTensorType::get(
          tensorType.getShape(), tensorType.getElementType(), newSrcEnc);
      src = builder.create<triton::gpu::ConvertLayoutOp>(src.getLoc(),
                                                         newSrcType, src);
    } else {
      // Non-tensor source: every destination axis is a broadcast axis and no
      // layout conversion of the source is needed.
      for (unsigned ax = 0; ax < dstRank; ++ax)
        broadcastAxis.push_back(ax);
    }

    // Create the new broadcast whose result encoding records the axes.
    auto newEnc = TritonGPUBlockedEncodingAttr::get(
        originDstEnc.getContext(), originDstEnc.getThreadTileSize(),
        originDstEnc.getWarpTileSize(), originDstEnc.getBlockTileSize(),
        originDstEnc.getOrder(), broadcastAxis);
    auto newType =
        RankedTensorType::get(originDstTensorType.getShape(),
                              originDstTensorType.getElementType(), newEnc);
    Value newBroadcast =
        builder.create<triton::BroadcastOp>(broadcast.getLoc(), newType, src);

    // Users must keep seeing the original result type, so convert back.
    Value newDst = builder.create<triton::gpu::ConvertLayoutOp>(
        broadcast.getLoc(), originDstType, newBroadcast);

    // NOTE(review): the replaced op is not erased here; presumably it is
    // cleaned up by a later pass. Confirm.
    broadcast.replaceAllUsesWith(newDst);
    mapping.map(broadcast, newDst);
  }

  return success();
}
|
||||
}
|
Reference in New Issue
Block a user