[Triton-MLIR][BACKEND] some code clean on the backend (#978)

This commit is contained in:
Yan Chunwei
2022-12-12 17:46:16 +08:00
committed by GitHub
parent e5cfa0f633
commit 0cfe909df8
4 changed files with 97 additions and 137 deletions

View File

@@ -105,11 +105,9 @@ struct DotOpMmaV1ConversionHelper {
}
// Get the number of fp16x2 elements for $a.
// \param shapeTransed: the shape or reordered shape if transpose needed.
// \param shapeTransed: A's shape or reordered shape if transpose needed.
// \param orderTransed: the order or reordered order if transpose needed.
unsigned getNumM(ArrayRef<int64_t> shapeTransed,
ArrayRef<unsigned> orderTransed) const {
bool isARow = orderTransed[0] != 0;
unsigned getNumM(ArrayRef<int64_t> shapeTransed, bool isARow) const {
AParam param(isARow);
unsigned numM = param.rep[0] * shapeTransed[0] / (param.spw[0] * wpt[0]);
@@ -117,11 +115,9 @@ struct DotOpMmaV1ConversionHelper {
}
// Get the number of fp16x2 elements for $b.
// \param shapeTransed: the shape or reordered shape if transpose needed.
// \param shapeTransed: B' shape or reordered shape if transpose needed.
// \param orderTransed: the order or reordered order if transpose needed.
unsigned getNumN(ArrayRef<int64_t> shapeTransed,
ArrayRef<unsigned> orderTransed) const {
bool isBRow = orderTransed[0] != 0;
unsigned getNumN(ArrayRef<int64_t> shapeTransed, bool isBRow) const {
BParam param(isBRow);
unsigned numN = param.rep[1] * shapeTransed[1] / (param.spw[1] * wpt[1]);
@@ -130,7 +126,7 @@ struct DotOpMmaV1ConversionHelper {
int numElemsPerThreadA(ArrayRef<int64_t> shapeTransed,
ArrayRef<unsigned> orderTransed) const {
int numM = getNumM(shapeTransed, orderTransed);
int numM = getNumM(shapeTransed, orderTransed[0] == 1);
int NK = shapeTransed[1];
// NOTE: We couldn't get the vec from the shared layout.
@@ -143,7 +139,7 @@ struct DotOpMmaV1ConversionHelper {
int numElemsPerThreadB(ArrayRef<int64_t> shapeTransed,
ArrayRef<unsigned> orderTransed) const {
unsigned numN = getNumN(shapeTransed, orderTransed);
unsigned numN = getNumN(shapeTransed, orderTransed[0] == 1);
int NK = shapeTransed[0];
// NOTE: We couldn't get the vec from the shared layout.
// int vecB = sharedLayout.getVec();
@@ -1451,7 +1447,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
}
};
unsigned numM = getNumM(shape, order);
unsigned numM = getNumM(shape, order[0] == 1);
for (unsigned k = 0; k < NK; k += 4)
for (unsigned m = 0; m < numM / 2; ++m)
loadA(m, k);
@@ -1563,7 +1559,7 @@ Value DotOpMmaV1ConversionHelper::loadB(
}
};
unsigned numN = getNumN(shape, order);
unsigned numN = getNumN(shape, order[0] == 1);
for (unsigned k = 0; k < NK; k += 4)
for (unsigned n = 0; n < numN / 2; ++n) {
if (!hbs.count({n, k}))

View File

@@ -739,7 +739,6 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
auto tensorTy = resType.cast<RankedTensorType>();
if (tensorTy.getEncoding().isa<BlockedEncodingAttr>() ||
tensorTy.getEncoding().isa<SliceEncodingAttr>()) {
auto tensorTy = resType.cast<RankedTensorType>();
auto srcType = typeConverter->convertType(elemType);
auto llSrc = bitcast(constVal, srcType);
size_t elemsPerThread = getElemsPerThread(tensorTy);
@@ -981,7 +980,7 @@ struct LoadOpConversion
size_t size = width / valueElemNbits;
auto vecTy = LLVM::getFixedVectorType(valueElemTy, size);
Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
Value v = undef(vecTy);
for (size_t s = 0; s < size; ++s) {
Value falseVal = otherElems[vecStart + ii * size + s];
Value sVal = createIndexAttrConstant(
@@ -1118,7 +1117,7 @@ struct StoreOpConversion
SmallVector<std::pair<Value, std::string>> asmArgs;
for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
// llWord is a width-len composition
Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
Value llWord = undef(wordTy);
// Insert each value element to the composition
for (size_t elemIdx = 0; elemIdx < wordNElems; ++elemIdx) {
const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx;
@@ -1129,10 +1128,7 @@ struct StoreOpConversion
elem = bitcast(elem, valueElemTy);
Type u32Ty = typeConverter->convertType(type::u32Ty(ctx));
llWord =
insert_element(wordTy, llWord, elem,
rewriter.create<LLVM::ConstantOp>(
loc, u32Ty, IntegerAttr::get(u32Ty, elemIdx)));
llWord = insert_element(wordTy, llWord, elem, i32_val(elemIdx));
}
llWord = bitcast(llWord, valArgTy);
std::string constraint =
@@ -3570,43 +3566,27 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
auto ATensorTy = A.getType().cast<RankedTensorType>();
auto BTensorTy = B.getType().cast<RankedTensorType>();
auto DTensorTy = D.getType().cast<RankedTensorType>();
SmallVector<int> AShape(ATensorTy.getShape().begin(),
ATensorTy.getShape().end());
SmallVector<int> BShape(BTensorTy.getShape().begin(),
BTensorTy.getShape().end());
auto AShape = ATensorTy.getShape();
auto BShape = BTensorTy.getShape();
auto DShape = DTensorTy.getShape();
auto wpt = mmaLayout.getWarpsPerCTA();
bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
bool isBVec4 = isBRow && BShape[isBRow] <= 16;
// TODO[Superjomn]: ld.v4 is not supported.
isAVec4 = true;
isBVec4 = true;
int packSize0 = (isARow || isAVec4) ? 1 : 2;
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
SmallVector<int> fpw({2, 2, 1});
SmallVector<int> rep({2 * packSize0, 2 * packSize1, 1});
SmallVector<int> spw({fpw[0] * 4 * rep[0], fpw[1] * 4 * rep[1], 1});
Value loadedA = adaptor.a();
Value loadedB = adaptor.b();
Value loadedC = adaptor.c();
DotOpMmaV1ConversionHelper helper(mmaLayout);
unsigned numM = rep[0] * DShape[0] / (spw[0] * wpt[0]);
unsigned numN = rep[1] * DShape[1] / (spw[1] * wpt[1]);
unsigned numM = helper.getNumM(AShape, isARow);
unsigned numN = helper.getNumN(BShape, isBRow);
unsigned NK = AShape[1];
auto has = helper.extractLoadedOperand(loadedA, NK, rewriter);
auto hbs = helper.extractLoadedOperand(loadedB, NK, rewriter);
auto has = helper.extractLoadedOperand(adaptor.a(), NK, rewriter);
auto hbs = helper.extractLoadedOperand(adaptor.b(), NK, rewriter);
// Initialize accumulators with external values, the acc holds the accumulator
// value that is shared between the MMA instructions inside a DotOp, we can
// call the order of the values the accumulator-internal order.
SmallVector<Value> acc = getElementsFromStruct(loc, loadedC, rewriter);
SmallVector<Value> acc = getElementsFromStruct(loc, adaptor.c(), rewriter);
size_t resSize = acc.size();
// The resVals holds the final result of the DotOp.
@@ -3719,38 +3699,19 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
auto bShape = bTensorTy.getShape();
auto cShape = cTensorTy.getShape();
ValueTable has, hbs;
int mShapePerCTA{-1}, nShapePerCTA{-1};
int mSizePerThread{-1}, nSizePerThread{-1};
ArrayRef<unsigned> aOrder, bOrder;
Value llA, llB;
BlockedEncodingAttr dLayout =
dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
auto order = dLayout.getOrder();
auto cc = getElementsFromStruct(loc, adaptor.c(), rewriter);
DotOpFMAConversionHelper helper(dLayout);
if (auto aDotOpLayout =
aTensorTy.getEncoding()
.dyn_cast<DotOperandEncodingAttr>()) { // get input from
// convert_layout
auto bDotOpLayout =
bTensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>();
auto aLayout = aDotOpLayout.getParent().cast<BlockedEncodingAttr>();
auto bLayout = bDotOpLayout.getParent().cast<BlockedEncodingAttr>();
auto aDotOpLayout = aTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
auto bDotOpLayout = bTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
auto aLayout = aDotOpLayout.getParent().cast<BlockedEncodingAttr>();
auto bLayout = bDotOpLayout.getParent().cast<BlockedEncodingAttr>();
assert(bLayout);
llA = adaptor.a();
llB = adaptor.b();
} else if (auto aLayout =
aTensorTy.getEncoding()
.dyn_cast<SharedEncodingAttr>()) { // load input from smem
auto bLayout = bTensorTy.getEncoding().dyn_cast<SharedEncodingAttr>();
assert(bLayout);
Value thread = getThreadId(rewriter, loc);
llA = helper.loadA(A, adaptor.a(), dLayout, thread, loc, rewriter);
llB = helper.loadB(B, adaptor.b(), dLayout, thread, loc, rewriter);
}
Value llA = adaptor.a();
Value llB = adaptor.b();
auto sizePerThread = getSizePerThread(dLayout);
auto shapePerCTA = getShapePerCTA(dLayout);
@@ -3759,17 +3720,19 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
int M = aShape[0];
int N = bShape[1];
mShapePerCTA = order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
mSizePerThread =
int mShapePerCTA =
order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
int mSizePerThread =
order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]];
nShapePerCTA = order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
nSizePerThread =
int nShapePerCTA =
order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
int nSizePerThread =
order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]];
has = helper.getValueTableFromStruct(llA, K, M, mShapePerCTA, mSizePerThread,
rewriter, loc);
hbs = helper.getValueTableFromStruct(llB, K, N, nShapePerCTA, nSizePerThread,
rewriter, loc);
auto has = helper.getValueTableFromStruct(llA, K, M, mShapePerCTA,
mSizePerThread, rewriter, loc);
auto hbs = helper.getValueTableFromStruct(llB, K, N, nShapePerCTA,
nSizePerThread, rewriter, loc);
SmallVector<Value> ret = cc;
for (unsigned k = 0; k < K; k++) {
@@ -3780,7 +3743,6 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
for (unsigned nn = 0; nn < nSizePerThread; ++nn) {
ret[z] = rewriter.create<LLVM::FMulAddOp>(loc, has[{m + mm, k}],
hbs[{n + nn, k}], ret[z]);
++z;
}
}
@@ -4310,9 +4272,10 @@ struct ExpOpConversionApprox
// For FP64 input, call __nv_expf for higher-precision calculation
if (elemTy.getIntOrFloatBitWidth() == 64)
return {};
const double log2e = 1.4426950408889634;
Value prod =
rewriter.create<LLVM::FMulOp>(loc, f32_ty, operands[0], f32_val(log2e));
Value prod = fmul(f32_ty, operands[0], f32_val(log2e));
PTXBuilder ptxBuilder;
auto &exp2 = ptxBuilder.create<PTXInstr>("ex2")->o("approx").o("f32");
auto output = ptxBuilder.newOperand("=f");

View File

@@ -31,6 +31,7 @@
#include <numeric>
// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
// Operators
#define inttoptr(...) rewriter.create<LLVM::IntToPtrOp>(loc, __VA_ARGS__)
#define ptrtoint(...) rewriter.create<LLVM::PtrToIntOp>(loc, __VA_ARGS__)
#define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
@@ -40,6 +41,7 @@
#define sub(...) rewriter.create<LLVM::SubOp>(loc, __VA_ARGS__)
#define fadd(...) rewriter.create<LLVM::FAddOp>(loc, __VA_ARGS__)
#define mul(...) rewriter.create<LLVM::MulOp>(loc, __VA_ARGS__)
#define fmul(...) rewriter.create<LLVM::FMulOp>(loc, __VA_ARGS__)
#define smax(...) rewriter.create<LLVM::SMaxOp>(loc, __VA_ARGS__)
#define umax(...) rewriter.create<LLVM::UMaxOp>(loc, __VA_ARGS__)
#define fmax(...) rewriter.create<LLVM::MaxNumOp>(loc, __VA_ARGS__)
@@ -90,6 +92,8 @@
#define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
// Types
#define i32_ty rewriter.getIntegerType(32)
#define ui32_ty rewriter.getIntegerType(32, false)
#define f16_ty rewriter.getF16Type()
@@ -102,8 +106,9 @@
#define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__)
#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
#define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__)
#define array_ty(elemTy, count) LLVM::LLVMArrayType::get(elemTy, count)
// Creator for constant
// Constants
#define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__)
#define int_val(width, val) \
LLVM::createLLVMIntegerConstant(rewriter, loc, width, val)

View File

@@ -36,7 +36,7 @@ namespace {
class DecomposeDotOperand : public mlir::RewritePattern {
public:
DecomposeDotOperand(mlir::MLIRContext *context)
explicit DecomposeDotOperand(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
1, context) {}
@@ -84,7 +84,7 @@ public:
// IIUC they are therefore not handled by DRR right now
class SimplifyConversion : public mlir::RewritePattern {
public:
SimplifyConversion(mlir::MLIRContext *context)
explicit SimplifyConversion(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
4, context) {}
@@ -219,8 +219,8 @@ public:
//
// -----------------------------------------------------------------------------
static LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
Attribute &ret) {
LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
Attribute &ret) {
ret = targetEncoding;
if (auto expand_dims = dyn_cast<triton::ExpandDimsOp>(op)) {
ret = triton::gpu::SliceEncodingAttr::get(
@@ -246,7 +246,7 @@ inline bool expensive_to_remat(Operation *op) {
if (isa<scf::YieldOp, scf::ForOp>(op))
return true;
return false;
};
}
Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
BlockAndValueMapping &mapping) {
@@ -276,7 +276,7 @@ Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
// are reachable from it without passing through any memory operation.
class RematerializeBackward : public mlir::RewritePattern {
public:
RematerializeBackward(mlir::MLIRContext *context)
explicit RematerializeBackward(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
2, context) {}
@@ -303,7 +303,7 @@ public:
SetVector<Attribute> layout;
llvm::MapVector<Value, Attribute> toConvert;
std::vector<std::pair<Operation *, Attribute>> queue;
queue.push_back({cvt, targetType.getEncoding()});
queue.emplace_back(cvt, targetType.getEncoding());
int numCvts = 1;
while (!queue.empty()) {
Operation *currOp;
@@ -341,7 +341,7 @@ public:
continue;
// we add one expensive conversion for the current operand
numCvts += 1;
queue.push_back({opArgI, newEncoding});
queue.emplace_back(opArgI, newEncoding);
}
}
// if rematerialization would add more conversions than it removes
@@ -351,8 +351,8 @@ public:
SmallVector<Value, 4> sortedValues;
SetVector<Operation *> tmp;
for (auto it = toConvert.begin(); it != toConvert.end(); ++it) {
Value v = it->first;
for (auto &item : toConvert) {
Value v = item.first;
if (v.getDefiningOp())
tmp.insert(v.getDefiningOp());
else
@@ -393,7 +393,7 @@ public:
class MoveConvertOutOfLoop : public mlir::RewritePattern {
public:
MoveConvertOutOfLoop(mlir::MLIRContext *context)
explicit MoveConvertOutOfLoop(mlir::MLIRContext *context)
: mlir::RewritePattern(scf::ForOp::getOperationName(), 1, context) {}
SmallVector<Value, 4>
@@ -406,7 +406,7 @@ public:
newInitArgs[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
newInitArgs[i].getLoc(), newType, newInitArgs[i]);
// Clone for loop
scf::ForOp newForOp = rewriter.create<scf::ForOp>(
auto newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newInitArgs);
newForOp->moveBefore(forOp);
@@ -455,7 +455,7 @@ public:
mlir::PatternRewriter &rewriter) const override {
auto forOp = cast<scf::ForOp>(op);
auto iterArgs = forOp.getRegionIterArgs();
for (auto iterArg : llvm::enumerate(iterArgs)) {
for (const auto &iterArg : llvm::enumerate(iterArgs)) {
// if (iterArg.index() != 1)
// continue;
// skip non-tensor types
@@ -517,7 +517,7 @@ public:
class RematerializeForward : public mlir::RewritePattern {
public:
RematerializeForward(mlir::MLIRContext *context)
explicit RematerializeForward(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
2, context) {}
@@ -584,7 +584,7 @@ public:
//
// -----------------------------------------------------------------------------
namespace {
static int computeCapabilityToMMAVersion(int computeCapability) {
int computeCapabilityToMMAVersion(int computeCapability) {
if (computeCapability < 80) {
return 1;
} else if (computeCapability < 90) {
@@ -595,9 +595,7 @@ static int computeCapabilityToMMAVersion(int computeCapability) {
}
}
static SmallVector<int64_t, 2>
mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
int numWarps) {
SmallVector<int64_t, 2> mmaVersionToShapePerWarp(int version) {
if (version == 1)
return {16, 16};
else if (version == 2)
@@ -608,12 +606,11 @@ mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
}
}
SmallVector<unsigned, 2> warpsPerTileV1(triton::DotOp dotOp,
const ArrayRef<int64_t> shape,
SmallVector<unsigned, 2> warpsPerTileV1(const ArrayRef<int64_t> shape,
int numWarps) {
SmallVector<unsigned, 2> ret = {1, 1};
SmallVector<int64_t, 2> shapePerWarp =
mmaVersionToShapePerWarp(1, shape, numWarps);
mmaVersionToShapePerWarp(1 /*version*/);
bool changed = false;
do {
changed = false;
@@ -669,7 +666,7 @@ SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
class OptimizeBlockedToShared : public mlir::RewritePattern {
public:
OptimizeBlockedToShared(mlir::MLIRContext *context)
explicit OptimizeBlockedToShared(mlir::MLIRContext *context)
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
context) {}
@@ -717,7 +714,7 @@ public:
class OptimizeConvertToDotOperand : public mlir::RewritePattern {
public:
OptimizeConvertToDotOperand(mlir::MLIRContext *context)
explicit OptimizeConvertToDotOperand(mlir::MLIRContext *context)
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
context) {}
@@ -729,11 +726,12 @@ public:
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
// order
ArrayRef<unsigned> order;
if(auto srcBlockedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
if (auto srcBlockedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
order = srcBlockedLayout.getOrder();
else if(auto srcSharedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>())
else if (auto srcSharedLayout =
srcType.getEncoding()
.dyn_cast<triton::gpu::SharedEncodingAttr>())
order = srcSharedLayout.getOrder();
else
return failure();
@@ -742,20 +740,18 @@ public:
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (!dstDotOperandLayout)
return failure();
unsigned opIdx = dstDotOperandLayout.getOpIdx();
if(!dstDotOperandLayout.getIsMMAv1Row())
if (!dstDotOperandLayout.getIsMMAv1Row())
return failure();
bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
if((order[0] == 1 && isMMAv1Row) ||
(order[0] == 0 && !isMMAv1Row))
bool isMMAv1Row =
dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row))
return failure();
auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
op->getContext(), dstDotOperandLayout.getOpIdx(), dstDotOperandLayout.getParent(),
newIsRow);
op->getContext(), dstDotOperandLayout.getOpIdx(),
dstDotOperandLayout.getParent(), newIsRow);
auto newDstType = RankedTensorType::get(
dstType.getShape(),
dstType.getElementType(), newDstEncoding);
dstType.getShape(), dstType.getElementType(), newDstEncoding);
auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newDstType, cvt.getOperand());
rewriter.replaceOp(op, newCvt.getResult());
@@ -763,7 +759,6 @@ public:
}
};
class BlockedToMMA : public mlir::RewritePattern {
int computeCapability;
@@ -777,7 +772,7 @@ public:
int version, int numWarps) {
switch (version) {
case 1:
return warpsPerTileV1(dotOp, shape, numWarps);
return warpsPerTileV1(shape, numWarps);
case 2:
return warpsPerTileV2(dotOp, shape, numWarps);
default:
@@ -821,27 +816,31 @@ public:
Value b = dotOp.b();
auto oldAType = a.getType().cast<RankedTensorType>();
auto oldBType = b.getType().cast<RankedTensorType>();
auto oldAOrder = oldAType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
.getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
auto oldBOrder = oldBType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
.getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
auto oldAOrder = oldAType.getEncoding()
.cast<triton::gpu::DotOperandEncodingAttr>()
.getParent()
.cast<triton::gpu::BlockedEncodingAttr>()
.getOrder();
auto oldBOrder = oldBType.getEncoding()
.cast<triton::gpu::DotOperandEncodingAttr>()
.getParent()
.cast<triton::gpu::BlockedEncodingAttr>()
.getOrder();
Attribute isMMAv1RowA;
Attribute isMMAv1RowB;
if(version == 1){
if (version == 1) {
isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
}
auto newAType = RankedTensorType::get(
oldAType.getShape(), oldAType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
newRetType.getEncoding(),
isMMAv1RowA));
triton::gpu::DotOperandEncodingAttr::get(
oldAType.getContext(), 0, newRetType.getEncoding(), isMMAv1RowA));
auto newBType = RankedTensorType::get(
oldBType.getShape(), oldBType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
newRetType.getEncoding(),
isMMAv1RowB));
triton::gpu::DotOperandEncodingAttr::get(
oldBType.getContext(), 1, newRetType.getEncoding(), isMMAv1RowB));
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
@@ -857,9 +856,8 @@ public:
class FixupLoop : public mlir::RewritePattern {
public:
FixupLoop(mlir::MLIRContext *context)
: mlir::RewritePattern(scf::ForOp::getOperationName(), 2,
context) {}
explicit FixupLoop(mlir::MLIRContext *context)
: mlir::RewritePattern(scf::ForOp::getOperationName(), 2, context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
@@ -869,15 +867,15 @@ public:
// Rewrite init argument
SmallVector<Value, 4> newInitArgs = forOp.getInitArgs();
bool shouldRematerialize = false;
for(size_t i = 0; i < newInitArgs.size(); i++){
for (size_t i = 0; i < newInitArgs.size(); i++) {
auto initArg = newInitArgs[i];
auto regionArg = forOp.getRegionIterArgs()[i];
if(newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType()){
if (newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType()) {
shouldRematerialize = true;
break;
}
}
if(!shouldRematerialize)
if (!shouldRematerialize)
return failure();
scf::ForOp newForOp = rewriter.create<scf::ForOp>(
@@ -894,8 +892,6 @@ public:
}
rewriter.replaceOp(forOp, newForOp.getResults());
return success();
}
};