[OPTIMIZER] Rewrite patterns for layout conversions (#64)
This commit is contained in:
@@ -23,9 +23,9 @@
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
using ::mlir::triton::gpu::TritonGPUBlockedEncodingAttr;
|
||||
using ::mlir::triton::gpu::TritonGPUMmaEncodingAttr;
|
||||
using ::mlir::triton::gpu::TritonGPUSharedEncodingAttr;
|
||||
using ::mlir::triton::gpu::BlockedEncodingAttr;
|
||||
using ::mlir::triton::gpu::MmaEncodingAttr;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
|
||||
namespace mlir {
|
||||
namespace LLVM {
|
||||
@@ -226,7 +226,7 @@ static int64_t getLinearIndex(std::vector<int64_t> multidim_index,
|
||||
return linear_index;
|
||||
}
|
||||
|
||||
static unsigned getElemsPerThread(TritonGPUBlockedEncodingAttr layout,
|
||||
static unsigned getElemsPerThread(BlockedEncodingAttr layout,
|
||||
ArrayRef<int64_t> shape) {
|
||||
size_t rank = shape.size();
|
||||
SmallVector<unsigned> elemsPerThreadPerDim(rank);
|
||||
@@ -368,10 +368,10 @@ public:
|
||||
// Emit indices calculation within each ConversionPattern
|
||||
// TODO: [goostavz] Double confirm the redundant indices calculations will
|
||||
// be eliminated in the consequent MLIR/LLVM optimization
|
||||
SmallVector<SmallVector<Value>> emitIndicesForBlockedLayout(
|
||||
Location loc, ConversionPatternRewriter &b,
|
||||
const TritonGPUBlockedEncodingAttr &blocked_layout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
SmallVector<SmallVector<Value>>
|
||||
emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &b,
|
||||
const BlockedEncodingAttr &blocked_layout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
auto llvmIndexTy = this->getTypeConverter()->getIndexType();
|
||||
auto cast = b.create<UnrealizedConversionCastOp>(
|
||||
loc, TypeRange{llvmIndexTy},
|
||||
@@ -483,7 +483,7 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
|
||||
ConversionPatternRewriter &rewriter, Location loc) {
|
||||
|
||||
auto tensorTy = resType.cast<RankedTensorType>();
|
||||
auto layout = tensorTy.getEncoding().cast<TritonGPUBlockedEncodingAttr>();
|
||||
auto layout = tensorTy.getEncoding().cast<BlockedEncodingAttr>();
|
||||
auto srcType = typeConverter->convertType(elemType);
|
||||
auto llSrc = rewriter.create<LLVM::BitcastOp>(loc, srcType, constVal);
|
||||
|
||||
@@ -594,9 +594,9 @@ struct StoreOpConversion
|
||||
MLIRContext *ctx = rewriter.getContext();
|
||||
auto loc = op->getLoc();
|
||||
|
||||
auto getLLVMElems = [&](Value value, Value llValue,
|
||||
const TritonGPUBlockedEncodingAttr &layout)
|
||||
-> SmallVector<Value, 4> {
|
||||
auto getLLVMElems =
|
||||
[&](Value value, Value llValue,
|
||||
const BlockedEncodingAttr &layout) -> SmallVector<Value, 4> {
|
||||
auto ty = value.getType().cast<RankedTensorType>();
|
||||
auto shape = ty.getShape();
|
||||
// Here, we assume that all inputs should have a blockedLayout
|
||||
@@ -613,11 +613,11 @@ struct StoreOpConversion
|
||||
};
|
||||
|
||||
auto getLayout =
|
||||
[&](Value val) -> std::tuple<TritonGPUBlockedEncodingAttr, unsigned> {
|
||||
[&](Value val) -> std::tuple<BlockedEncodingAttr, unsigned> {
|
||||
auto ty = val.getType().cast<RankedTensorType>();
|
||||
auto shape = ty.getShape();
|
||||
// Here, we assume that all inputs should have a blockedLayout
|
||||
auto layout = ty.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
|
||||
auto layout = ty.getEncoding().dyn_cast<BlockedEncodingAttr>();
|
||||
|
||||
unsigned valueElems = getElemsPerThread(layout, shape);
|
||||
|
||||
@@ -633,9 +633,8 @@ struct StoreOpConversion
|
||||
auto maskElems = getLLVMElems(mask, llMask, maskLayout);
|
||||
assert(valueElems.size() == maskElems.size());
|
||||
|
||||
auto getAlign =
|
||||
[this](Value val,
|
||||
const TritonGPUBlockedEncodingAttr &layout) -> unsigned {
|
||||
auto getAlign = [this](Value val,
|
||||
const BlockedEncodingAttr &layout) -> unsigned {
|
||||
auto axisInfo = getAxisInfo(val);
|
||||
assert(axisInfo.hasValue());
|
||||
|
||||
@@ -648,9 +647,9 @@ struct StoreOpConversion
|
||||
};
|
||||
|
||||
// get align
|
||||
auto getVec = [this, &getAlign](
|
||||
Value val,
|
||||
const TritonGPUBlockedEncodingAttr &layout) -> unsigned {
|
||||
auto getVec = [this,
|
||||
&getAlign](Value val,
|
||||
const BlockedEncodingAttr &layout) -> unsigned {
|
||||
auto axisInfo = getAxisInfo(val);
|
||||
auto contig = axisInfo->getContiguity();
|
||||
// Here order should be ordered by contiguous first, so the first element
|
||||
@@ -820,10 +819,8 @@ struct BroadcastOpConversion
|
||||
Value result = op.result();
|
||||
auto srcTy = op.src().getType().cast<RankedTensorType>();
|
||||
auto resultTy = result.getType().cast<RankedTensorType>();
|
||||
auto srcLayout =
|
||||
srcTy.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
|
||||
auto resultLayout =
|
||||
resultTy.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
|
||||
auto srcLayout = srcTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
|
||||
auto resultLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
|
||||
assert(srcLayout && (srcLayout == resultLayout) &&
|
||||
"Unexpected layout of BroadcastOp");
|
||||
auto srcShape = srcTy.getShape();
|
||||
@@ -894,8 +891,7 @@ struct ViewOpConversion
|
||||
// due to MLIR's restrictions
|
||||
Location loc = op->getLoc();
|
||||
auto resultTy = op.getType().cast<RankedTensorType>();
|
||||
auto resultLayout =
|
||||
resultTy.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
|
||||
auto resultLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
|
||||
auto resultShape = resultTy.getShape();
|
||||
unsigned elems = getElemsPerThread(resultLayout, resultShape);
|
||||
Type elemTy =
|
||||
@@ -921,7 +917,7 @@ struct MakeRangeOpConversion
|
||||
auto rankedTy = op.result().getType().dyn_cast<RankedTensorType>();
|
||||
auto shape = rankedTy.getShape();
|
||||
auto blocked_layout =
|
||||
rankedTy.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
|
||||
rankedTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
|
||||
auto elemTy = rankedTy.getElementType();
|
||||
assert(elemTy.isInteger(32));
|
||||
Value start = createIndexAttrConstant(rewriter, loc, elemTy, op.start());
|
||||
@@ -955,8 +951,7 @@ struct LoadOpConversion
|
||||
Value mask = adaptor.mask();
|
||||
Value other = adaptor.other();
|
||||
auto resultTy = op.result().getType().cast<RankedTensorType>();
|
||||
auto blockedLayout =
|
||||
resultTy.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
|
||||
auto blockedLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
|
||||
auto shape = resultTy.getShape();
|
||||
|
||||
// TODO: Handle AxisInfo
|
||||
@@ -1166,8 +1161,7 @@ struct GEPOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::GEPOp> {
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
|
||||
auto resultLayout =
|
||||
resultTy.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
|
||||
auto resultLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
|
||||
auto resultShape = resultTy.getShape();
|
||||
unsigned elems = getElemsPerThread(resultLayout, resultShape);
|
||||
Type elemTy =
|
||||
@@ -1206,8 +1200,8 @@ public:
|
||||
return failure();
|
||||
|
||||
Location loc = op->getLoc();
|
||||
auto resultLayout = resultTy.getEncoding()
|
||||
.template dyn_cast<TritonGPUBlockedEncodingAttr>();
|
||||
auto resultLayout =
|
||||
resultTy.getEncoding().template dyn_cast<BlockedEncodingAttr>();
|
||||
auto resultShape = resultTy.getShape();
|
||||
unsigned elems = getElemsPerThread(resultLayout, resultShape);
|
||||
Type elemTy =
|
||||
@@ -1250,17 +1244,16 @@ public:
|
||||
|
||||
llvm::Optional<Type> convertTritonTensorType(RankedTensorType type) {
|
||||
Attribute layout = type.getEncoding();
|
||||
if (auto blocked_layout = layout.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
|
||||
if (auto blocked_layout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
unsigned numElementsPerThread =
|
||||
getElemsPerThread(blocked_layout, type.getShape());
|
||||
SmallVector<Type, 4> types(numElementsPerThread,
|
||||
convertType(type.getElementType()));
|
||||
return LLVM::LLVMStructType::getLiteral(&getContext(), types);
|
||||
} else if (auto mma_layout = layout.dyn_cast<TritonGPUMmaEncodingAttr>()) {
|
||||
} else if (auto mma_layout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
// TODO: Not implemented
|
||||
return llvm::None;
|
||||
} else if (auto shared_layout =
|
||||
layout.dyn_cast<TritonGPUSharedEncodingAttr>()) {
|
||||
} else if (auto shared_layout = layout.dyn_cast<SharedEncodingAttr>()) {
|
||||
// TODO: Not implemented
|
||||
return llvm::None;
|
||||
}
|
||||
|
@@ -156,8 +156,7 @@ struct TritonExpandDimsPattern
|
||||
Attribute _argEncoding = argType.getEncoding();
|
||||
if (!_argEncoding)
|
||||
return failure();
|
||||
auto argEncoding =
|
||||
_argEncoding.cast<triton::gpu::TritonGPUBlockedEncodingAttr>();
|
||||
auto argEncoding = _argEncoding.cast<triton::gpu::BlockedEncodingAttr>();
|
||||
// return shape
|
||||
auto retShape = argType.getShape().vec();
|
||||
retShape.insert(retShape.begin() + op.axis(), 1);
|
||||
@@ -170,10 +169,10 @@ struct TritonExpandDimsPattern
|
||||
retWarpsPerCTA.insert(retWarpsPerCTA.begin() + op.axis(), 1);
|
||||
SmallVector<unsigned, 4> retOrder(retShape.size());
|
||||
std::iota(retOrder.begin(), retOrder.end(), 0);
|
||||
triton::gpu::TritonGPUBlockedEncodingAttr retEncoding =
|
||||
triton::gpu::TritonGPUBlockedEncodingAttr::get(
|
||||
getContext(), retSizePerThread, retThreadsPerWarp, retWarpsPerCTA,
|
||||
retOrder);
|
||||
triton::gpu::BlockedEncodingAttr retEncoding =
|
||||
triton::gpu::BlockedEncodingAttr::get(getContext(), retSizePerThread,
|
||||
retThreadsPerWarp, retWarpsPerCTA,
|
||||
retOrder);
|
||||
// return type
|
||||
RankedTensorType retType =
|
||||
RankedTensorType::get(retShape, argType.getElementType(), retEncoding);
|
||||
@@ -201,16 +200,16 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
Value a = adaptor.a();
|
||||
Value b = adaptor.b();
|
||||
SmallVector<unsigned, 2> order{1, 0};
|
||||
if (!aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
|
||||
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(
|
||||
getContext(), 1, 1, 1, order);
|
||||
if (!aEncoding.isa<triton::gpu::SharedEncodingAttr>()) {
|
||||
Attribute encoding =
|
||||
triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order);
|
||||
auto dstType = RankedTensorType::get(aType.getShape(),
|
||||
aType.getElementType(), encoding);
|
||||
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
|
||||
}
|
||||
if (!bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
|
||||
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(
|
||||
getContext(), 1, 1, 1, order);
|
||||
if (!bEncoding.isa<triton::gpu::SharedEncodingAttr>()) {
|
||||
Attribute encoding =
|
||||
triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order);
|
||||
auto dstType = RankedTensorType::get(bType.getShape(),
|
||||
bType.getElementType(), encoding);
|
||||
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
|
||||
|
Reference in New Issue
Block a user