[Triton-MLIR][BACKEND] some code clean on the backend (#978)
@@ -36,7 +36,7 @@ namespace {
 class DecomposeDotOperand : public mlir::RewritePattern {

 public:
-  DecomposeDotOperand(mlir::MLIRContext *context)
+  explicit DecomposeDotOperand(mlir::MLIRContext *context)
       : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                              1, context) {}
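Note: the recurring change in this commit marks single-argument pattern constructors `explicit`, so an `mlir::MLIRContext *` can no longer convert into a pattern object implicitly. A minimal standalone sketch of the idiom (the `Pattern` type below is illustrative, not from this diff):

```cpp
struct Pattern {
  // Without 'explicit', `Pattern p = &ctx;` would compile through an
  // implicit int* -> Pattern conversion.
  explicit Pattern(int *context) : context(context) {}
  int *context;
};

int main() {
  int ctx = 0;
  Pattern p(&ctx);     // OK: direct initialization
  // Pattern q = &ctx; // error: constructor is explicit
  return 0;
}
```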
@@ -84,7 +84,7 @@ public:
 // IIUC they are therefore not handled by DRR right now
 class SimplifyConversion : public mlir::RewritePattern {
 public:
-  SimplifyConversion(mlir::MLIRContext *context)
+  explicit SimplifyConversion(mlir::MLIRContext *context)
       : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                              4, context) {}
@@ -219,8 +219,8 @@ public:
 //
 // -----------------------------------------------------------------------------

-static LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
-                                    Attribute &ret) {
+LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
+                             Attribute &ret) {
   ret = targetEncoding;
   if (auto expand_dims = dyn_cast<triton::ExpandDimsOp>(op)) {
     ret = triton::gpu::SliceEncodingAttr::get(
@@ -246,7 +246,7 @@ inline bool expensive_to_remat(Operation *op) {
   if (isa<scf::YieldOp, scf::ForOp>(op))
     return true;
   return false;
-};
+}

 Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
                               BlockAndValueMapping &mapping) {
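The only change in this hunk drops a stray semicolon after the function body; `};` after a free function is an empty declaration that compilers may flag under `-Wextra-semi`. Illustrative sketch (the function name is a placeholder):

```cpp
// A trailing ';' after the closing brace would be a harmless but
// warning-prone empty declaration.
inline bool expensiveToRemat() { return false; }

int main() { return expensiveToRemat() ? 1 : 0; }
```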
@@ -276,7 +276,7 @@ Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
 // are reachable from it without passing through any memory operation.
 class RematerializeBackward : public mlir::RewritePattern {
 public:
-  RematerializeBackward(mlir::MLIRContext *context)
+  explicit RematerializeBackward(mlir::MLIRContext *context)
       : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                              2, context) {}
@@ -303,7 +303,7 @@ public:
     SetVector<Attribute> layout;
     llvm::MapVector<Value, Attribute> toConvert;
     std::vector<std::pair<Operation *, Attribute>> queue;
-    queue.push_back({cvt, targetType.getEncoding()});
+    queue.emplace_back(cvt, targetType.getEncoding());
     int numCvts = 1;
     while (!queue.empty()) {
       Operation *currOp;
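This hunk and the next replace `push_back({...})` with `emplace_back(...)`: the arguments are forwarded to the `std::pair` constructor, so the element is built in place rather than via a temporary that is then moved in. A minimal sketch with placeholder element types:

```cpp
#include <utility>
#include <vector>

int main() {
  std::vector<std::pair<int, const char *>> queue;
  queue.push_back({1, "a"});   // constructs a temporary pair, then moves it in
  queue.emplace_back(2, "b");  // constructs the pair in place
  return queue.size() == 2 ? 0 : 1;
}
```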
@@ -341,7 +341,7 @@ public:
           continue;
         // we add one expensive conversion for the current operand
         numCvts += 1;
-        queue.push_back({opArgI, newEncoding});
+        queue.emplace_back(opArgI, newEncoding);
       }
     }
     // if rematerialization would add more conversions than it removes
@@ -351,8 +351,8 @@ public:

     SmallVector<Value, 4> sortedValues;
     SetVector<Operation *> tmp;
-    for (auto it = toConvert.begin(); it != toConvert.end(); ++it) {
-      Value v = it->first;
+    for (auto &item : toConvert) {
+      Value v = item.first;
       if (v.getDefiningOp())
         tmp.insert(v.getDefiningOp());
       else
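The explicit iterator loop over `toConvert` becomes a range-based for, which drops the iterator bookkeeping without changing behavior. A sketch over a standard map (`llvm::MapVector` iterates key-value pairs the same way):

```cpp
#include <iostream>
#include <map>

int main() {
  std::map<int, const char *> toConvert{{1, "a"}, {2, "b"}};
  // Replaced form: for (auto it = toConvert.begin(); it != toConvert.end(); ++it)
  for (auto &item : toConvert)
    std::cout << item.first << " -> " << item.second << "\n";
  return 0;
}
```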
@@ -393,7 +393,7 @@ public:

 class MoveConvertOutOfLoop : public mlir::RewritePattern {
 public:
-  MoveConvertOutOfLoop(mlir::MLIRContext *context)
+  explicit MoveConvertOutOfLoop(mlir::MLIRContext *context)
       : mlir::RewritePattern(scf::ForOp::getOperationName(), 1, context) {}

   SmallVector<Value, 4>
@@ -406,7 +406,7 @@ public:
       newInitArgs[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
           newInitArgs[i].getLoc(), newType, newInitArgs[i]);
     // Clone for loop
-    scf::ForOp newForOp = rewriter.create<scf::ForOp>(
+    auto newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
         forOp.getStep(), newInitArgs);
     newForOp->moveBefore(forOp);
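Since `rewriter.create<scf::ForOp>(...)` already names the result type, the variable declaration switches to `auto` to avoid spelling it twice. The same guideline in plain C++ (illustrative, not MLIR API):

```cpp
#include <memory>

int main() {
  std::unique_ptr<int> p1 = std::make_unique<int>(21); // type written twice
  auto p2 = std::make_unique<int>(21);                 // type written once
  return (*p1 + *p2 == 42) ? 0 : 1;
}
```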
@@ -455,7 +455,7 @@ public:
                   mlir::PatternRewriter &rewriter) const override {
     auto forOp = cast<scf::ForOp>(op);
     auto iterArgs = forOp.getRegionIterArgs();
-    for (auto iterArg : llvm::enumerate(iterArgs)) {
+    for (const auto &iterArg : llvm::enumerate(iterArgs)) {
       // if (iterArg.index() != 1)
       //   continue;
       // skip non-tensor types
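Binding the element as `const auto &` avoids copying what `llvm::enumerate` yields on each iteration and documents that the body does not mutate it. The same point with a standard container (illustrative):

```cpp
#include <string>
#include <vector>

int main() {
  std::vector<std::string> iterArgs{"arg0", "arg1"};
  // `for (auto s : iterArgs)` would copy each string;
  // a const reference observes the element without copying it.
  for (const auto &s : iterArgs)
    (void)s;
  return 0;
}
```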
@@ -517,7 +517,7 @@ public:

 class RematerializeForward : public mlir::RewritePattern {
 public:
-  RematerializeForward(mlir::MLIRContext *context)
+  explicit RematerializeForward(mlir::MLIRContext *context)
       : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                              2, context) {}
@@ -584,7 +584,7 @@ public:
 //
 // -----------------------------------------------------------------------------
 namespace {
-static int computeCapabilityToMMAVersion(int computeCapability) {
+int computeCapabilityToMMAVersion(int computeCapability) {
   if (computeCapability < 80) {
     return 1;
   } else if (computeCapability < 90) {
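Because these helpers sit inside an unnamed namespace, they already have internal linkage, so the `static` keyword is redundant and is dropped (as it was for `invertEncoding` above). Sketch:

```cpp
namespace {
// Symbols in an unnamed namespace already have internal linkage;
// adding 'static' here would change nothing.
int helper() { return 42; }
} // namespace

int main() { return helper() == 42 ? 0 : 1; }
```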
@@ -595,9 +595,7 @@ static int computeCapabilityToMMAVersion(int computeCapability) {
   }
 }

-static SmallVector<int64_t, 2>
-mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
-                         int numWarps) {
+SmallVector<int64_t, 2> mmaVersionToShapePerWarp(int version) {
   if (version == 1)
     return {16, 16};
   else if (version == 2)
@@ -608,12 +606,11 @@ mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
   }
 }

-SmallVector<unsigned, 2> warpsPerTileV1(triton::DotOp dotOp,
-                                        const ArrayRef<int64_t> shape,
+SmallVector<unsigned, 2> warpsPerTileV1(const ArrayRef<int64_t> shape,
                                         int numWarps) {
   SmallVector<unsigned, 2> ret = {1, 1};
   SmallVector<int64_t, 2> shapePerWarp =
-      mmaVersionToShapePerWarp(1, shape, numWarps);
+      mmaVersionToShapePerWarp(1 /*version*/);
   bool changed = false;
   do {
     changed = false;
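These two hunks drop parameters the helpers never used (`shape` and `numWarps` from `mmaVersionToShapePerWarp`, `dotOp` from `warpsPerTileV1`) and annotate the remaining literal argument as `1 /*version*/` so the call site stays self-describing. The annotation idiom in isolation (function and values are illustrative):

```cpp
// The inline comment names the parameter the bare literal binds to.
int shapePerWarp(int version) { return version == 1 ? 16 : 8; }

int main() { return shapePerWarp(1 /*version*/) == 16 ? 0 : 1; }
```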
@@ -669,7 +666,7 @@ SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,

 class OptimizeBlockedToShared : public mlir::RewritePattern {
 public:
-  OptimizeBlockedToShared(mlir::MLIRContext *context)
+  explicit OptimizeBlockedToShared(mlir::MLIRContext *context)
       : RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
                        context) {}
@@ -717,7 +714,7 @@ public:

 class OptimizeConvertToDotOperand : public mlir::RewritePattern {
 public:
-  OptimizeConvertToDotOperand(mlir::MLIRContext *context)
+  explicit OptimizeConvertToDotOperand(mlir::MLIRContext *context)
       : RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
                        context) {}
@@ -729,11 +726,12 @@ public:
     auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
     // order
     ArrayRef<unsigned> order;
-    if(auto srcBlockedLayout =
-        srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
+    if (auto srcBlockedLayout =
+            srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
       order = srcBlockedLayout.getOrder();
-    else if(auto srcSharedLayout =
-        srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>())
+    else if (auto srcSharedLayout =
+                 srcType.getEncoding()
+                     .dyn_cast<triton::gpu::SharedEncodingAttr>())
       order = srcSharedLayout.getOrder();
     else
       return failure();
@@ -742,20 +740,18 @@ public:
         dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
     if (!dstDotOperandLayout)
       return failure();
-    unsigned opIdx = dstDotOperandLayout.getOpIdx();
-    if(!dstDotOperandLayout.getIsMMAv1Row())
+    if (!dstDotOperandLayout.getIsMMAv1Row())
       return failure();
-    bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
-    if((order[0] == 1 && isMMAv1Row) ||
-       (order[0] == 0 && !isMMAv1Row))
+    bool isMMAv1Row =
+        dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
+    if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row))
       return failure();
     auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
     auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
-        op->getContext(), dstDotOperandLayout.getOpIdx(), dstDotOperandLayout.getParent(),
-        newIsRow);
+        op->getContext(), dstDotOperandLayout.getOpIdx(),
+        dstDotOperandLayout.getParent(), newIsRow);
     auto newDstType = RankedTensorType::get(
-        dstType.getShape(),
-        dstType.getElementType(), newDstEncoding);
+        dstType.getShape(), dstType.getElementType(), newDstEncoding);
     auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
         op->getLoc(), newDstType, cvt.getOperand());
     rewriter.replaceOp(op, newCvt.getResult());
@@ -763,7 +759,6 @@ public:
   }
 };

-
 class BlockedToMMA : public mlir::RewritePattern {
   int computeCapability;

@@ -777,7 +772,7 @@ public:
                                                  int version, int numWarps) {
     switch (version) {
     case 1:
-      return warpsPerTileV1(dotOp, shape, numWarps);
+      return warpsPerTileV1(shape, numWarps);
     case 2:
       return warpsPerTileV2(dotOp, shape, numWarps);
     default:
@@ -821,27 +816,31 @@ public:
     Value b = dotOp.b();
     auto oldAType = a.getType().cast<RankedTensorType>();
     auto oldBType = b.getType().cast<RankedTensorType>();
-    auto oldAOrder = oldAType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
-        .getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
-    auto oldBOrder = oldBType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
-        .getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
+    auto oldAOrder = oldAType.getEncoding()
+                         .cast<triton::gpu::DotOperandEncodingAttr>()
+                         .getParent()
+                         .cast<triton::gpu::BlockedEncodingAttr>()
+                         .getOrder();
+    auto oldBOrder = oldBType.getEncoding()
+                         .cast<triton::gpu::DotOperandEncodingAttr>()
+                         .getParent()
+                         .cast<triton::gpu::BlockedEncodingAttr>()
+                         .getOrder();
     Attribute isMMAv1RowA;
     Attribute isMMAv1RowB;
-    if(version == 1){
+    if (version == 1) {
       isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
       isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
     }

     auto newAType = RankedTensorType::get(
         oldAType.getShape(), oldAType.getElementType(),
-        triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
-                                                 newRetType.getEncoding(),
-                                                 isMMAv1RowA));
+        triton::gpu::DotOperandEncodingAttr::get(
+            oldAType.getContext(), 0, newRetType.getEncoding(), isMMAv1RowA));
     auto newBType = RankedTensorType::get(
         oldBType.getShape(), oldBType.getElementType(),
-        triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
-                                                 newRetType.getEncoding(),
-                                                 isMMAv1RowB));
+        triton::gpu::DotOperandEncodingAttr::get(
+            oldBType.getContext(), 1, newRetType.getEncoding(), isMMAv1RowB));

     a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
     b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
@@ -857,9 +856,8 @@ public:
 class FixupLoop : public mlir::RewritePattern {

 public:
-  FixupLoop(mlir::MLIRContext *context)
-      : mlir::RewritePattern(scf::ForOp::getOperationName(), 2,
-                             context) {}
+  explicit FixupLoop(mlir::MLIRContext *context)
+      : mlir::RewritePattern(scf::ForOp::getOperationName(), 2, context) {}

   mlir::LogicalResult
   matchAndRewrite(mlir::Operation *op,
@@ -869,17 +867,17 @@ public:
     // Rewrite init argument
     SmallVector<Value, 4> newInitArgs = forOp.getInitArgs();
     bool shouldRematerialize = false;
-    for(size_t i = 0; i < newInitArgs.size(); i++){
+    for (size_t i = 0; i < newInitArgs.size(); i++) {
       auto initArg = newInitArgs[i];
       auto regionArg = forOp.getRegionIterArgs()[i];
-      if(newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType()){
+      if (newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType()) {
         shouldRematerialize = true;
         break;
       }
     }
-    if(!shouldRematerialize)
+    if (!shouldRematerialize)
       return failure();

     scf::ForOp newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
         forOp.getStep(), newInitArgs);
@@ -894,8 +892,6 @@ public:
     }
     rewriter.replaceOp(forOp, newForOp.getResults());
     return success();
-
-
   }
 };
