special encoding for broadcast
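In short: the blocked and MMA tensor encodings gain a broadcastAxis parameter, two new source-side encodings (blocked_multicast_layout and mma_multicast_layout) are introduced, and TritonGPUConversionTarget::refineLayouts rewrites every tt.broadcast so the broadcast happens in a broadcast-aware layout and is then converted back. A minimal before/after sketch of that rewrite for the scalar-source case, based on the updated lit test at the end of this diff (here #blocked abbreviates the full #triton_gpu.blocked_layout<{...}> attribute and #blocked_bcast the same attribute with broadcastAxis = [0]):

    // before
    %dst = tt.broadcast %x : (i32) -> tensor<256xi32, #blocked>
    // after refineLayouts: broadcast into the broadcast-aware layout, then convert back
    %bcst = tt.broadcast %x : (i32) -> tensor<256xi32, #blocked_bcast>
    %dst = triton_gpu.convert_layout %bcst : (tensor<256xi32, #blocked_bcast>) -> tensor<256xi32, #blocked>

For tensor sources, refineLayouts additionally converts the source into the new multicast encoding before the broadcast.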
@@ -4,8 +4,9 @@
 include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
 // include "mlir/IR/TensorEncoding.td"
 
-class TritonGPU_Attr<string name, list<Trait> traits = []>
-  : AttrDef<TritonGPU_Dialect, name, traits>;
+class TritonGPU_Attr<string name, list<Trait> traits = [],
+                     string baseCppClass = "::mlir::Attribute">
+  : AttrDef<TritonGPU_Dialect, name, traits, baseCppClass>;
 
 def TritonGPUSharedEncodingAttr : TritonGPU_Attr<"TritonGPUSharedEncoding"> {
   let mnemonic = "shared_layout";
@@ -104,7 +105,8 @@ And the associated TritonGPU MLIR
     ArrayRefParameter<
      "unsigned",
      "order of axes by the rate of changing"
-    >:$order
+    >:$order,
+    ArrayRefParameter<"unsigned">:$broadcastAxis
     // "AffineMap":$threadOrdering,
     // "AffineMap":warpOrdering,
     // "AffineMap":$blockOrdering,
@@ -114,6 +116,28 @@ And the associated TritonGPU MLIR
   // let genVerifyDecl = 1;
 }
 
+def TritonGPUBlockedMulticastEncodingAttr
+    : TritonGPU_Attr<"TritonGPUBlockedMulticastEncoding"> {
+  let mnemonic = "blocked_multicast_layout";
+
+  let description = [{
+    to be broadcasted to blocked_layout
+  }];
+
+  // This needs to be synced with BlockedEncoding
+  let parameters = (
+    ins
+    ArrayRefParameter<"unsigned">:$threadTileSize,
+    ArrayRefParameter<"unsigned">:$warpTileSize,
+    ArrayRefParameter<"unsigned">:$blockTileSize,
+    ArrayRefParameter<"unsigned">:$order,
+    // unique to broadcasted layout
+    ArrayRefParameter<"unsigned">:$broadcastAxis
+  );
+
+  // let genVerifyDecl = 1;
+}
+
 def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
   let mnemonic = "mma_layout";
 
@@ -131,7 +155,8 @@ def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
     ArrayRefParameter<"unsigned">:$shapePerTile,
     // TODO: should Distributed layout also
     ArrayRefParameter<"unsigned">:$repetitions,
-    ArrayRefParameter<"unsigned">:$contigPerThread
+    ArrayRefParameter<"unsigned">:$contigPerThread,
+    ArrayRefParameter<"unsigned">:$broadcastAxis
     // "AffineMap":$warpOrdering,
     // "AffineMap":$blockOrdering
   );
@@ -139,4 +164,28 @@ def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
   // let genVerifyDecl = 1;
 }
 
+def TritonGPUMmaMulticastEncodingAttr
+    : TritonGPU_Attr<"TritonGPUMmaMulticastEncoding"> {
+  let mnemonic = "mma_multicast_layout";
+
+  let description = [{
+    To be broadcasted to mma.
+  }];
+
+  // This needs to be synced with MmaEncoding
+  let parameters = (
+    ins
+    ArrayRefParameter<"unsigned">:$fragmentPerWarp,
+    ArrayRefParameter<"unsigned">:$shapePerWarp,
+    ArrayRefParameter<"unsigned">:$warpPerTile,
+    ArrayRefParameter<"unsigned">:$shapePerTile,
+    ArrayRefParameter<"unsigned">:$repetitions,
+    ArrayRefParameter<"unsigned">:$contigPerThread,
+    // unique to broadcasted layout
+    ArrayRefParameter<"unsigned">:$broadcastAxis
+  );
+
+  // let genVerifyDecl = 1;
+}
+
 #endif
@@ -23,6 +23,9 @@ class TritonGPUConversionTarget : public ConversionTarget {
   TritonGPUTypeConverter &typeConverter;
 public:
   explicit TritonGPUConversionTarget(MLIRContext &ctx, TritonGPUTypeConverter &typeConverter);
+
+  /// update layouts & insert ConvertLayoutOp if necessary
+  LogicalResult refineLayouts(ModuleOp mod, int numThreads);
 };
 
 } // namespace mlir
@@ -342,9 +342,13 @@ public:
     // mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here?
     populateSCFPatterns(typeConverter, patterns);
 
-    if(failed(applyPartialConversion(mod, target,
-                                     std::move(patterns))))
+    if(failed(applyPartialConversion(mod, target, std::move(patterns))))
       return signalPassFailure();
+
+    // update layouts
+    // broadcast src => multicast, dst => broadcasted
+    if(failed(target.refineLayouts(mod, numWarps)))
+      return signalPassFailure();
   }
 };
 
@@ -10,7 +10,7 @@ using namespace mlir::triton::gpu;
 // parse an array of integers
 static LogicalResult parseIntArrayAttr(AsmParser &parser,
                                        const NamedAttribute &attr,
-                                       SmallVector<unsigned, 2> &res,
+                                       /*SmallVector<unsigned, 2>*/auto &res,
                                        StringRef desc) {
   auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
   if (!arrayAttr) {
@@ -36,8 +36,7 @@ static LogicalResult parseIntArrayAttr(AsmParser &parser,
 #define GET_ATTRDEF_CLASSES
 #include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
 
-Attribute
-TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
+static Attribute parseBlocked(AsmParser &parser, Type type) {
   if (parser.parseLess().failed())
     return {};
   // Parse the data as a dictionary
@@ -51,28 +50,7 @@ TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
   SmallVector<unsigned, 2> warpTileSize;
   SmallVector<unsigned, 2> blockTileSize;
   SmallVector<unsigned, 2> order;
-  // parse an array of integers
-  // auto parseIntArrayAttr = [&parser](const NamedAttribute &attr,
-  //                                    SmallVector<unsigned, 2> &res,
-  //                                    StringRef desc) -> LogicalResult {
-  //   auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
-  //   if (!arrayAttr) {
-  //     parser.emitError(parser.getNameLoc(), "expected an array for ")
-  //           << desc;
-  //     return failure();
-  //   }
-  //   for (Attribute i : arrayAttr) {
-  //     auto intAttr = i.dyn_cast<IntegerAttr>();
-  //     if (!intAttr) {
-  //       parser.emitError(parser.getNameLoc(), "expected an integer value in ")
-  //             << desc;
-  //       return failure();
-  //     }
-  //     res.push_back(intAttr.getUInt());
-  //   }
-  //   return success();
-  // };
+  SmallVector<unsigned, 2> broadcastAxis;
 
   for (const NamedAttribute &attr : dict) {
     if (attr.getName() == "threadTileSize") {
@@ -98,20 +76,39 @@ TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
                                            threadTileSize,
                                            warpTileSize,
                                            blockTileSize,
-                                           order);
+                                           order,
+                                           broadcastAxis);
 }
 
-void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
+static void printBlocked(AsmPrinter &printer, auto *attr) {
   printer << "<{"
-          << "threadTileSize = [" << getThreadTileSize() << "]"
-          << ", warpTileSize = [" << getWarpTileSize() << "]"
-          << ", blockTileSize = [" << getBlockTileSize() << "]"
-          << ", order = [" << getOrder() << "]"
+          << "threadTileSize = [" << attr->getThreadTileSize() << "]"
+          << ", warpTileSize = [" << attr->getWarpTileSize() << "]"
+          << ", blockTileSize = [" << attr->getBlockTileSize() << "]"
+          << ", order = [" << attr->getOrder() << "]"
+          << ", broadcastAxis = [" << attr->getBroadcastAxis() << "]"
           << "}>";
 }
 
 Attribute
-TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
+TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
+  parseBlocked(parser, type);
+}
+
+void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
+  printBlocked(printer, this);
+}
+
+Attribute
+TritonGPUBlockedMulticastEncodingAttr::parse(AsmParser &parser, Type type) {
+  parseBlocked(parser, type);
+}
+
+void TritonGPUBlockedMulticastEncodingAttr::print(AsmPrinter &printer) const {
+  printBlocked(printer, this);
+}
+
+static Attribute parseMma(AsmParser &parser, Type type) {
   if (parser.parseLess().failed())
     return {};
   DictionaryAttr dict;
@@ -126,6 +123,7 @@ TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
   SmallVector<unsigned, 2> shapePerTile;
   SmallVector<unsigned, 2> repetitions;
   SmallVector<unsigned, 2> contigPerThread;
+  SmallVector<unsigned, 2> broadcastAxis;
 
   for (const NamedAttribute &attr : dict) {
     if (attr.getName() == "fragmentPerWarp") {
@@ -159,18 +157,37 @@ TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
                                        warpPerTile,
                                        shapePerTile,
                                        repetitions,
-                                       contigPerThread);
+                                       contigPerThread,
+                                       broadcastAxis);
+}
+
+static void printMma(AsmPrinter &printer, auto *attr) {
+  printer << "<{"
+          << "fragmentPerWarp = [" << attr->getFragmentPerWarp() << "]"
+          << ", shapePerWarp = [" << attr->getShapePerWarp() << "]"
+          << ", warpPerTile = [" << attr->getWarpPerTile() << "]"
+          << ", shapePerTile = [" << attr->getShapePerTile() << "]"
+          << ", repetitions = [" << attr->getRepetitions() << "]"
+          << ", contigPerThread = [" << attr->getContigPerThread() << "]"
+          << "}>";
+}
+
+Attribute
+TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
+  return parseMma(parser, type);
 }
 
 void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
-  printer << "<{"
-          << "fragmentPerWarp = [" << getFragmentPerWarp() << "]"
-          << ", shapePerWarp = [" << getShapePerWarp() << "]"
-          << ", warpPerTile = [" << getWarpPerTile() << "]"
-          << ", shapePerTile = [" << getShapePerTile() << "]"
-          << ", repetitions = [" << getRepetitions() << "]"
-          << ", contigPerThread = [" << getContigPerThread() << "]"
-          << "}>";
+  printMma(printer, this);
+}
+
+Attribute
+TritonGPUMmaMulticastEncodingAttr::parse(AsmParser &parser, Type type) {
+  return parseMma(parser, type);
+}
+
+void TritonGPUMmaMulticastEncodingAttr::print(AsmPrinter &printer) const {
+  printMma(printer, this);
 }
 
 Attribute
@@ -1,9 +1,11 @@
 #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include <algorithm>
 
 using namespace mlir;
+using namespace mlir::triton::gpu;
 
 //
 // TypeConverter
@@ -41,6 +43,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
     llvm::SmallVector<unsigned> warpTileSize(rank, 1);
     llvm::SmallVector<unsigned> blockTileSize(rank);
     llvm::SmallVector<unsigned> order(rank);
+    llvm::SmallVector<unsigned> broadcastAxis;
    int remainingThreads = numThreads;
    int remainingLanes = /*warp size*/32;
    for (int64_t dim = 0; dim < rank; ++dim) {
@@ -53,7 +56,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
       // TODO: will we need repetition?
     }
     Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
-        context, threadTileSize, warpTileSize, blockTileSize, order);
+        context, threadTileSize, warpTileSize, blockTileSize, order, broadcastAxis);
     return RankedTensorType::get(shape, elementType, encoding);
   });
 
@@ -81,7 +84,6 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
   // NOTE: only for remapped values.
   addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
                                ValueRange inputs, Location loc) {
-    assert(inputs.size() == 1);
     llvm_unreachable("Not implemented");
     return llvm::None;
   });
@@ -124,3 +126,98 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
   });
 
 }
+
+// %dst = tt.broadcast %src
+//   =>
+// %newSrc = convert_layout %src
+// %bcst = tt.broadcast %newSrc
+// %dst = convert_layout %bcst
+LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
+                                                       int numThreads) {
+  // collect broadcasts
+  SmallVector<triton::BroadcastOp> broadcasts;
+  mod.walk([&](triton::BroadcastOp op) {
+    broadcasts.push_back(op);
+  });
+
+  BlockAndValueMapping mapping;
+  for (auto broadcast : broadcasts) {
+    OpBuilder builder(broadcast);
+    Value src = mapping.lookupOrDefault(broadcast.src());
+    Type originSrcType = src.getType();
+    Type originDstType = broadcast.getType();
+    auto originDstTensorType = originDstType.dyn_cast<RankedTensorType>();
+    unsigned dstRank = originDstTensorType.getRank();
+
+    // compute newSrcType & broadcastAxis
+    Type newSrcType;
+    SmallVector<unsigned> broadcastAxis;
+    bool isSrcScalar = false;
+    if (auto tensorType = originSrcType.dyn_cast<RankedTensorType>()) {
+      assert(tensorType.getRank() == dstRank &&
+             "src & dst should have same rank (verifier should catch this)");
+      for (unsigned ax = 0; ax < dstRank; ++ax)
+        if (tensorType.getShape()[ax] < originDstTensorType.getShape()[ax])
+          broadcastAxis.push_back(ax);
+
+      Attribute originSrcEnc = tensorType.getEncoding();
+      if (auto blockedEnc = originSrcEnc.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
+        auto newSrcEnc = TritonGPUBlockedMulticastEncodingAttr::get(
+          blockedEnc.getContext(),
+          blockedEnc.getThreadTileSize(),
+          blockedEnc.getWarpTileSize(),
+          blockedEnc.getBlockTileSize(),
+          blockedEnc.getOrder(),
+          broadcastAxis
+        );
+        newSrcType = RankedTensorType::get(
+          tensorType.getShape(),
+          tensorType.getElementType(),
+          newSrcEnc
+        );
+      } else
+        llvm_unreachable("src of broadcast should have blocked encoding");
+    } else {
+      for (unsigned ax = 0; ax < dstRank; ++ax)
+        broadcastAxis.push_back(ax);
+      newSrcType = originSrcType;
+      isSrcScalar = true;
+    }
+
+    // create new src
+    if (!isSrcScalar) // we don't need to convert layout for scalar values
+      src = builder.create<triton::gpu::ConvertLayoutOp>(
+        src.getLoc(), newSrcType, src
+      );
+
+    // create new broadcast
+    // compute new type (encoding)
+    auto originDstEnc = originDstTensorType.getEncoding()
+                            .dyn_cast<TritonGPUBlockedEncodingAttr>();
+    auto newEnc = TritonGPUBlockedEncodingAttr::get(
+      originDstEnc.getContext(),
+      originDstEnc.getThreadTileSize(),
+      originDstEnc.getWarpTileSize(),
+      originDstEnc.getBlockTileSize(),
+      originDstEnc.getOrder(),
+      broadcastAxis
+    );
+    auto newType = RankedTensorType::get(
+      originDstTensorType.getShape(),
+      originDstTensorType.getElementType(),
+      newEnc
+    );
+    Value newBroadcast = builder.create<triton::BroadcastOp>(
+      broadcast.getLoc(), newType, src
+    );
+    // we don't want to change the encoding of the result
+    Value newDst = builder.create<triton::gpu::ConvertLayoutOp>(
+      broadcast.getLoc(), originDstType, newBroadcast
+    );
+
+    broadcast.replaceAllUsesWith(newDst);
+    mapping.map(broadcast, newDst);
+  }
+
+  return success();
+}
@@ -45,32 +45,38 @@ module {
     %c32 = arith.constant 32 : index
     %c0 = arith.constant 0 : index
     %c256_i32 = arith.constant 256 : i32
-    %cst = arith.constant dense<0.000000e+00> : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
+    %cst = arith.constant dense<0.000000e+00> : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
     %0 = tt.get_program_id {axis = 0 : i32} : i32
     %1 = arith.muli %0, %c256_i32 : i32
-    %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-    %3 = tt.broadcast %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-    %4 = arith.addi %3, %2 : tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-    %5 = tt.broadcast %arg3 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-    %6 = "triton_gpu.cmpi"(%4, %5) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>) -> tensor<256xi1, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-    %7 = tt.broadcast %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-    %8 = tt.getelementptr %7, %4 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-    %9 = tt.broadcast %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-    %10 = tt.getelementptr %9, %4 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-    %11 = arith.index_cast %arg4 : i32 to index
-    %12:3 = scf.for %arg6 = %c0 to %11 step %c32 iter_args(%arg7 = %cst, %arg8 = %8, %arg9 = %10) -> (tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>) {
-      %15 = tt.load %arg8, %6, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-      %16 = tt.load %arg9, %6, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-      %17 = arith.addf %15, %16 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-      %18 = arith.addf %arg7, %17 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-      %19 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-      %20 = tt.getelementptr %arg8, %19 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-      %21 = tt.getelementptr %arg9, %19 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-      scf.yield %18, %20, %21 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
+    %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+    %3 = tt.broadcast %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
+    %4 = triton_gpu.convert_layout %3 : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+    %5 = arith.addi %4, %2 : tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+    %6 = tt.broadcast %arg3 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
+    %7 = triton_gpu.convert_layout %6 : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+    %8 = "triton_gpu.cmpi"(%5, %7) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>) -> tensor<256xi1, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+    %9 = tt.broadcast %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
+    %10 = triton_gpu.convert_layout %9 : (tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+    %11 = tt.getelementptr %10, %5 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+    %12 = tt.broadcast %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
+    %13 = triton_gpu.convert_layout %12 : (tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+    %14 = tt.getelementptr %13, %5 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+    %15 = arith.index_cast %arg4 : i32 to index
+    %16:3 = scf.for %arg6 = %c0 to %15 step %c32 iter_args(%arg7 = %cst, %arg8 = %11, %arg9 = %14) -> (tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>) {
+      %20 = tt.load %arg8, %8, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+      %21 = tt.load %arg9, %8, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
      %22 = arith.addf %20, %21 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+      %23 = arith.addf %arg7, %22 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+      %24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
+      %25 = triton_gpu.convert_layout %24 : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+      %26 = tt.getelementptr %arg8, %25 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+      %27 = tt.getelementptr %arg9, %25 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+      scf.yield %23, %26, %27 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
     }
-    %13 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-    %14 = tt.getelementptr %13, %4 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-    tt.store %14, %12#0, %6, : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
+    %17 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
+    %18 = triton_gpu.convert_layout %17 : (tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+    %19 = tt.getelementptr %18, %5 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+    tt.store %19, %16#0, %8, : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
     return
   }
 }