[Triton-MLIR][BACKEND] some code clean on the backend (#978)
This commit is contained in:
@@ -105,11 +105,9 @@ struct DotOpMmaV1ConversionHelper {
|
||||
}
|
||||
|
||||
// Get the number of fp16x2 elements for $a.
|
||||
// \param shapeTransed: the shape or reordered shape if transpose needed.
|
||||
// \param shapeTransed: A's shape or reordered shape if transpose needed.
|
||||
// \param orderTransed: the order or reordered order if transpose needed.
|
||||
unsigned getNumM(ArrayRef<int64_t> shapeTransed,
|
||||
ArrayRef<unsigned> orderTransed) const {
|
||||
bool isARow = orderTransed[0] != 0;
|
||||
unsigned getNumM(ArrayRef<int64_t> shapeTransed, bool isARow) const {
|
||||
AParam param(isARow);
|
||||
|
||||
unsigned numM = param.rep[0] * shapeTransed[0] / (param.spw[0] * wpt[0]);
|
||||
@@ -117,11 +115,9 @@ struct DotOpMmaV1ConversionHelper {
|
||||
}
|
||||
|
||||
// Get the number of fp16x2 elements for $b.
|
||||
// \param shapeTransed: the shape or reordered shape if transpose needed.
|
||||
// \param shapeTransed: B' shape or reordered shape if transpose needed.
|
||||
// \param orderTransed: the order or reordered order if transpose needed.
|
||||
unsigned getNumN(ArrayRef<int64_t> shapeTransed,
|
||||
ArrayRef<unsigned> orderTransed) const {
|
||||
bool isBRow = orderTransed[0] != 0;
|
||||
unsigned getNumN(ArrayRef<int64_t> shapeTransed, bool isBRow) const {
|
||||
BParam param(isBRow);
|
||||
|
||||
unsigned numN = param.rep[1] * shapeTransed[1] / (param.spw[1] * wpt[1]);
|
||||
@@ -130,7 +126,7 @@ struct DotOpMmaV1ConversionHelper {
|
||||
|
||||
int numElemsPerThreadA(ArrayRef<int64_t> shapeTransed,
|
||||
ArrayRef<unsigned> orderTransed) const {
|
||||
int numM = getNumM(shapeTransed, orderTransed);
|
||||
int numM = getNumM(shapeTransed, orderTransed[0] == 1);
|
||||
int NK = shapeTransed[1];
|
||||
|
||||
// NOTE: We couldn't get the vec from the shared layout.
|
||||
@@ -143,7 +139,7 @@ struct DotOpMmaV1ConversionHelper {
|
||||
|
||||
int numElemsPerThreadB(ArrayRef<int64_t> shapeTransed,
|
||||
ArrayRef<unsigned> orderTransed) const {
|
||||
unsigned numN = getNumN(shapeTransed, orderTransed);
|
||||
unsigned numN = getNumN(shapeTransed, orderTransed[0] == 1);
|
||||
int NK = shapeTransed[0];
|
||||
// NOTE: We couldn't get the vec from the shared layout.
|
||||
// int vecB = sharedLayout.getVec();
|
||||
@@ -1451,7 +1447,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
||||
}
|
||||
};
|
||||
|
||||
unsigned numM = getNumM(shape, order);
|
||||
unsigned numM = getNumM(shape, order[0] == 1);
|
||||
for (unsigned k = 0; k < NK; k += 4)
|
||||
for (unsigned m = 0; m < numM / 2; ++m)
|
||||
loadA(m, k);
|
||||
@@ -1563,7 +1559,7 @@ Value DotOpMmaV1ConversionHelper::loadB(
|
||||
}
|
||||
};
|
||||
|
||||
unsigned numN = getNumN(shape, order);
|
||||
unsigned numN = getNumN(shape, order[0] == 1);
|
||||
for (unsigned k = 0; k < NK; k += 4)
|
||||
for (unsigned n = 0; n < numN / 2; ++n) {
|
||||
if (!hbs.count({n, k}))
|
||||
|
@@ -739,7 +739,6 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
|
||||
auto tensorTy = resType.cast<RankedTensorType>();
|
||||
if (tensorTy.getEncoding().isa<BlockedEncodingAttr>() ||
|
||||
tensorTy.getEncoding().isa<SliceEncodingAttr>()) {
|
||||
auto tensorTy = resType.cast<RankedTensorType>();
|
||||
auto srcType = typeConverter->convertType(elemType);
|
||||
auto llSrc = bitcast(constVal, srcType);
|
||||
size_t elemsPerThread = getElemsPerThread(tensorTy);
|
||||
@@ -981,7 +980,7 @@ struct LoadOpConversion
|
||||
size_t size = width / valueElemNbits;
|
||||
|
||||
auto vecTy = LLVM::getFixedVectorType(valueElemTy, size);
|
||||
Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
|
||||
Value v = undef(vecTy);
|
||||
for (size_t s = 0; s < size; ++s) {
|
||||
Value falseVal = otherElems[vecStart + ii * size + s];
|
||||
Value sVal = createIndexAttrConstant(
|
||||
@@ -1118,7 +1117,7 @@ struct StoreOpConversion
|
||||
SmallVector<std::pair<Value, std::string>> asmArgs;
|
||||
for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
|
||||
// llWord is a width-len composition
|
||||
Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
|
||||
Value llWord = undef(wordTy);
|
||||
// Insert each value element to the composition
|
||||
for (size_t elemIdx = 0; elemIdx < wordNElems; ++elemIdx) {
|
||||
const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx;
|
||||
@@ -1129,10 +1128,7 @@ struct StoreOpConversion
|
||||
elem = bitcast(elem, valueElemTy);
|
||||
|
||||
Type u32Ty = typeConverter->convertType(type::u32Ty(ctx));
|
||||
llWord =
|
||||
insert_element(wordTy, llWord, elem,
|
||||
rewriter.create<LLVM::ConstantOp>(
|
||||
loc, u32Ty, IntegerAttr::get(u32Ty, elemIdx)));
|
||||
llWord = insert_element(wordTy, llWord, elem, i32_val(elemIdx));
|
||||
}
|
||||
llWord = bitcast(llWord, valArgTy);
|
||||
std::string constraint =
|
||||
@@ -3570,43 +3566,27 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
|
||||
auto ATensorTy = A.getType().cast<RankedTensorType>();
|
||||
auto BTensorTy = B.getType().cast<RankedTensorType>();
|
||||
auto DTensorTy = D.getType().cast<RankedTensorType>();
|
||||
SmallVector<int> AShape(ATensorTy.getShape().begin(),
|
||||
ATensorTy.getShape().end());
|
||||
SmallVector<int> BShape(BTensorTy.getShape().begin(),
|
||||
BTensorTy.getShape().end());
|
||||
auto AShape = ATensorTy.getShape();
|
||||
auto BShape = BTensorTy.getShape();
|
||||
auto DShape = DTensorTy.getShape();
|
||||
auto wpt = mmaLayout.getWarpsPerCTA();
|
||||
|
||||
bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
|
||||
bool isBVec4 = isBRow && BShape[isBRow] <= 16;
|
||||
// TODO[Superjomn]: ld.v4 is not supported.
|
||||
isAVec4 = true;
|
||||
isBVec4 = true;
|
||||
|
||||
int packSize0 = (isARow || isAVec4) ? 1 : 2;
|
||||
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
|
||||
SmallVector<int> fpw({2, 2, 1});
|
||||
SmallVector<int> rep({2 * packSize0, 2 * packSize1, 1});
|
||||
SmallVector<int> spw({fpw[0] * 4 * rep[0], fpw[1] * 4 * rep[1], 1});
|
||||
|
||||
Value loadedA = adaptor.a();
|
||||
Value loadedB = adaptor.b();
|
||||
Value loadedC = adaptor.c();
|
||||
DotOpMmaV1ConversionHelper helper(mmaLayout);
|
||||
|
||||
unsigned numM = rep[0] * DShape[0] / (spw[0] * wpt[0]);
|
||||
unsigned numN = rep[1] * DShape[1] / (spw[1] * wpt[1]);
|
||||
unsigned numM = helper.getNumM(AShape, isARow);
|
||||
unsigned numN = helper.getNumN(BShape, isBRow);
|
||||
unsigned NK = AShape[1];
|
||||
|
||||
auto has = helper.extractLoadedOperand(loadedA, NK, rewriter);
|
||||
auto hbs = helper.extractLoadedOperand(loadedB, NK, rewriter);
|
||||
auto has = helper.extractLoadedOperand(adaptor.a(), NK, rewriter);
|
||||
auto hbs = helper.extractLoadedOperand(adaptor.b(), NK, rewriter);
|
||||
|
||||
// Initialize accumulators with external values, the acc holds the accumulator
|
||||
// value that is shared between the MMA instructions inside a DotOp, we can
|
||||
// call the order of the values the accumulator-internal order.
|
||||
SmallVector<Value> acc = getElementsFromStruct(loc, loadedC, rewriter);
|
||||
SmallVector<Value> acc = getElementsFromStruct(loc, adaptor.c(), rewriter);
|
||||
size_t resSize = acc.size();
|
||||
|
||||
// The resVals holds the final result of the DotOp.
|
||||
@@ -3719,38 +3699,19 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
|
||||
auto bShape = bTensorTy.getShape();
|
||||
auto cShape = cTensorTy.getShape();
|
||||
|
||||
ValueTable has, hbs;
|
||||
int mShapePerCTA{-1}, nShapePerCTA{-1};
|
||||
int mSizePerThread{-1}, nSizePerThread{-1};
|
||||
ArrayRef<unsigned> aOrder, bOrder;
|
||||
Value llA, llB;
|
||||
BlockedEncodingAttr dLayout =
|
||||
dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
|
||||
auto order = dLayout.getOrder();
|
||||
auto cc = getElementsFromStruct(loc, adaptor.c(), rewriter);
|
||||
|
||||
DotOpFMAConversionHelper helper(dLayout);
|
||||
if (auto aDotOpLayout =
|
||||
aTensorTy.getEncoding()
|
||||
.dyn_cast<DotOperandEncodingAttr>()) { // get input from
|
||||
// convert_layout
|
||||
auto bDotOpLayout =
|
||||
bTensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>();
|
||||
auto aLayout = aDotOpLayout.getParent().cast<BlockedEncodingAttr>();
|
||||
auto bLayout = bDotOpLayout.getParent().cast<BlockedEncodingAttr>();
|
||||
auto aDotOpLayout = aTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
|
||||
auto bDotOpLayout = bTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
|
||||
auto aLayout = aDotOpLayout.getParent().cast<BlockedEncodingAttr>();
|
||||
auto bLayout = bDotOpLayout.getParent().cast<BlockedEncodingAttr>();
|
||||
|
||||
assert(bLayout);
|
||||
llA = adaptor.a();
|
||||
llB = adaptor.b();
|
||||
} else if (auto aLayout =
|
||||
aTensorTy.getEncoding()
|
||||
.dyn_cast<SharedEncodingAttr>()) { // load input from smem
|
||||
auto bLayout = bTensorTy.getEncoding().dyn_cast<SharedEncodingAttr>();
|
||||
assert(bLayout);
|
||||
Value thread = getThreadId(rewriter, loc);
|
||||
llA = helper.loadA(A, adaptor.a(), dLayout, thread, loc, rewriter);
|
||||
llB = helper.loadB(B, adaptor.b(), dLayout, thread, loc, rewriter);
|
||||
}
|
||||
Value llA = adaptor.a();
|
||||
Value llB = adaptor.b();
|
||||
|
||||
auto sizePerThread = getSizePerThread(dLayout);
|
||||
auto shapePerCTA = getShapePerCTA(dLayout);
|
||||
@@ -3759,17 +3720,19 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
|
||||
int M = aShape[0];
|
||||
int N = bShape[1];
|
||||
|
||||
mShapePerCTA = order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
|
||||
mSizePerThread =
|
||||
int mShapePerCTA =
|
||||
order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
|
||||
int mSizePerThread =
|
||||
order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]];
|
||||
nShapePerCTA = order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
|
||||
nSizePerThread =
|
||||
int nShapePerCTA =
|
||||
order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
|
||||
int nSizePerThread =
|
||||
order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]];
|
||||
|
||||
has = helper.getValueTableFromStruct(llA, K, M, mShapePerCTA, mSizePerThread,
|
||||
rewriter, loc);
|
||||
hbs = helper.getValueTableFromStruct(llB, K, N, nShapePerCTA, nSizePerThread,
|
||||
rewriter, loc);
|
||||
auto has = helper.getValueTableFromStruct(llA, K, M, mShapePerCTA,
|
||||
mSizePerThread, rewriter, loc);
|
||||
auto hbs = helper.getValueTableFromStruct(llB, K, N, nShapePerCTA,
|
||||
nSizePerThread, rewriter, loc);
|
||||
|
||||
SmallVector<Value> ret = cc;
|
||||
for (unsigned k = 0; k < K; k++) {
|
||||
@@ -3780,7 +3743,6 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
|
||||
for (unsigned nn = 0; nn < nSizePerThread; ++nn) {
|
||||
ret[z] = rewriter.create<LLVM::FMulAddOp>(loc, has[{m + mm, k}],
|
||||
hbs[{n + nn, k}], ret[z]);
|
||||
|
||||
++z;
|
||||
}
|
||||
}
|
||||
@@ -4310,9 +4272,10 @@ struct ExpOpConversionApprox
|
||||
// For FP64 input, call __nv_expf for higher-precision calculation
|
||||
if (elemTy.getIntOrFloatBitWidth() == 64)
|
||||
return {};
|
||||
|
||||
const double log2e = 1.4426950408889634;
|
||||
Value prod =
|
||||
rewriter.create<LLVM::FMulOp>(loc, f32_ty, operands[0], f32_val(log2e));
|
||||
Value prod = fmul(f32_ty, operands[0], f32_val(log2e));
|
||||
|
||||
PTXBuilder ptxBuilder;
|
||||
auto &exp2 = ptxBuilder.create<PTXInstr>("ex2")->o("approx").o("f32");
|
||||
auto output = ptxBuilder.newOperand("=f");
|
||||
|
@@ -31,6 +31,7 @@
|
||||
#include <numeric>
|
||||
|
||||
// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
|
||||
// Operators
|
||||
#define inttoptr(...) rewriter.create<LLVM::IntToPtrOp>(loc, __VA_ARGS__)
|
||||
#define ptrtoint(...) rewriter.create<LLVM::PtrToIntOp>(loc, __VA_ARGS__)
|
||||
#define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
|
||||
@@ -40,6 +41,7 @@
|
||||
#define sub(...) rewriter.create<LLVM::SubOp>(loc, __VA_ARGS__)
|
||||
#define fadd(...) rewriter.create<LLVM::FAddOp>(loc, __VA_ARGS__)
|
||||
#define mul(...) rewriter.create<LLVM::MulOp>(loc, __VA_ARGS__)
|
||||
#define fmul(...) rewriter.create<LLVM::FMulOp>(loc, __VA_ARGS__)
|
||||
#define smax(...) rewriter.create<LLVM::SMaxOp>(loc, __VA_ARGS__)
|
||||
#define umax(...) rewriter.create<LLVM::UMaxOp>(loc, __VA_ARGS__)
|
||||
#define fmax(...) rewriter.create<LLVM::MaxNumOp>(loc, __VA_ARGS__)
|
||||
@@ -90,6 +92,8 @@
|
||||
#define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
|
||||
#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
|
||||
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
|
||||
|
||||
// Types
|
||||
#define i32_ty rewriter.getIntegerType(32)
|
||||
#define ui32_ty rewriter.getIntegerType(32, false)
|
||||
#define f16_ty rewriter.getF16Type()
|
||||
@@ -102,8 +106,9 @@
|
||||
#define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__)
|
||||
#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
|
||||
#define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__)
|
||||
#define array_ty(elemTy, count) LLVM::LLVMArrayType::get(elemTy, count)
|
||||
|
||||
// Creator for constant
|
||||
// Constants
|
||||
#define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__)
|
||||
#define int_val(width, val) \
|
||||
LLVM::createLLVMIntegerConstant(rewriter, loc, width, val)
|
||||
|
@@ -36,7 +36,7 @@ namespace {
|
||||
class DecomposeDotOperand : public mlir::RewritePattern {
|
||||
|
||||
public:
|
||||
DecomposeDotOperand(mlir::MLIRContext *context)
|
||||
explicit DecomposeDotOperand(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
|
||||
1, context) {}
|
||||
|
||||
@@ -84,7 +84,7 @@ public:
|
||||
// IIUC they are therefore not handled by DRR right now
|
||||
class SimplifyConversion : public mlir::RewritePattern {
|
||||
public:
|
||||
SimplifyConversion(mlir::MLIRContext *context)
|
||||
explicit SimplifyConversion(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
|
||||
4, context) {}
|
||||
|
||||
@@ -219,8 +219,8 @@ public:
|
||||
//
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
static LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
|
||||
Attribute &ret) {
|
||||
LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
|
||||
Attribute &ret) {
|
||||
ret = targetEncoding;
|
||||
if (auto expand_dims = dyn_cast<triton::ExpandDimsOp>(op)) {
|
||||
ret = triton::gpu::SliceEncodingAttr::get(
|
||||
@@ -246,7 +246,7 @@ inline bool expensive_to_remat(Operation *op) {
|
||||
if (isa<scf::YieldOp, scf::ForOp>(op))
|
||||
return true;
|
||||
return false;
|
||||
};
|
||||
}
|
||||
|
||||
Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
|
||||
BlockAndValueMapping &mapping) {
|
||||
@@ -276,7 +276,7 @@ Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
|
||||
// are reachable from it without passing through any memory operation.
|
||||
class RematerializeBackward : public mlir::RewritePattern {
|
||||
public:
|
||||
RematerializeBackward(mlir::MLIRContext *context)
|
||||
explicit RematerializeBackward(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
|
||||
2, context) {}
|
||||
|
||||
@@ -303,7 +303,7 @@ public:
|
||||
SetVector<Attribute> layout;
|
||||
llvm::MapVector<Value, Attribute> toConvert;
|
||||
std::vector<std::pair<Operation *, Attribute>> queue;
|
||||
queue.push_back({cvt, targetType.getEncoding()});
|
||||
queue.emplace_back(cvt, targetType.getEncoding());
|
||||
int numCvts = 1;
|
||||
while (!queue.empty()) {
|
||||
Operation *currOp;
|
||||
@@ -341,7 +341,7 @@ public:
|
||||
continue;
|
||||
// we add one expensive conversion for the current operand
|
||||
numCvts += 1;
|
||||
queue.push_back({opArgI, newEncoding});
|
||||
queue.emplace_back(opArgI, newEncoding);
|
||||
}
|
||||
}
|
||||
// if rematerialization would add more conversions than it removes
|
||||
@@ -351,8 +351,8 @@ public:
|
||||
|
||||
SmallVector<Value, 4> sortedValues;
|
||||
SetVector<Operation *> tmp;
|
||||
for (auto it = toConvert.begin(); it != toConvert.end(); ++it) {
|
||||
Value v = it->first;
|
||||
for (auto &item : toConvert) {
|
||||
Value v = item.first;
|
||||
if (v.getDefiningOp())
|
||||
tmp.insert(v.getDefiningOp());
|
||||
else
|
||||
@@ -393,7 +393,7 @@ public:
|
||||
|
||||
class MoveConvertOutOfLoop : public mlir::RewritePattern {
|
||||
public:
|
||||
MoveConvertOutOfLoop(mlir::MLIRContext *context)
|
||||
explicit MoveConvertOutOfLoop(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(scf::ForOp::getOperationName(), 1, context) {}
|
||||
|
||||
SmallVector<Value, 4>
|
||||
@@ -406,7 +406,7 @@ public:
|
||||
newInitArgs[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
newInitArgs[i].getLoc(), newType, newInitArgs[i]);
|
||||
// Clone for loop
|
||||
scf::ForOp newForOp = rewriter.create<scf::ForOp>(
|
||||
auto newForOp = rewriter.create<scf::ForOp>(
|
||||
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
|
||||
forOp.getStep(), newInitArgs);
|
||||
newForOp->moveBefore(forOp);
|
||||
@@ -455,7 +455,7 @@ public:
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto forOp = cast<scf::ForOp>(op);
|
||||
auto iterArgs = forOp.getRegionIterArgs();
|
||||
for (auto iterArg : llvm::enumerate(iterArgs)) {
|
||||
for (const auto &iterArg : llvm::enumerate(iterArgs)) {
|
||||
// if (iterArg.index() != 1)
|
||||
// continue;
|
||||
// skip non-tensor types
|
||||
@@ -517,7 +517,7 @@ public:
|
||||
|
||||
class RematerializeForward : public mlir::RewritePattern {
|
||||
public:
|
||||
RematerializeForward(mlir::MLIRContext *context)
|
||||
explicit RematerializeForward(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
|
||||
2, context) {}
|
||||
|
||||
@@ -584,7 +584,7 @@ public:
|
||||
//
|
||||
// -----------------------------------------------------------------------------
|
||||
namespace {
|
||||
static int computeCapabilityToMMAVersion(int computeCapability) {
|
||||
int computeCapabilityToMMAVersion(int computeCapability) {
|
||||
if (computeCapability < 80) {
|
||||
return 1;
|
||||
} else if (computeCapability < 90) {
|
||||
@@ -595,9 +595,7 @@ static int computeCapabilityToMMAVersion(int computeCapability) {
|
||||
}
|
||||
}
|
||||
|
||||
static SmallVector<int64_t, 2>
|
||||
mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
|
||||
int numWarps) {
|
||||
SmallVector<int64_t, 2> mmaVersionToShapePerWarp(int version) {
|
||||
if (version == 1)
|
||||
return {16, 16};
|
||||
else if (version == 2)
|
||||
@@ -608,12 +606,11 @@ mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<unsigned, 2> warpsPerTileV1(triton::DotOp dotOp,
|
||||
const ArrayRef<int64_t> shape,
|
||||
SmallVector<unsigned, 2> warpsPerTileV1(const ArrayRef<int64_t> shape,
|
||||
int numWarps) {
|
||||
SmallVector<unsigned, 2> ret = {1, 1};
|
||||
SmallVector<int64_t, 2> shapePerWarp =
|
||||
mmaVersionToShapePerWarp(1, shape, numWarps);
|
||||
mmaVersionToShapePerWarp(1 /*version*/);
|
||||
bool changed = false;
|
||||
do {
|
||||
changed = false;
|
||||
@@ -669,7 +666,7 @@ SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
|
||||
|
||||
class OptimizeBlockedToShared : public mlir::RewritePattern {
|
||||
public:
|
||||
OptimizeBlockedToShared(mlir::MLIRContext *context)
|
||||
explicit OptimizeBlockedToShared(mlir::MLIRContext *context)
|
||||
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
@@ -717,7 +714,7 @@ public:
|
||||
|
||||
class OptimizeConvertToDotOperand : public mlir::RewritePattern {
|
||||
public:
|
||||
OptimizeConvertToDotOperand(mlir::MLIRContext *context)
|
||||
explicit OptimizeConvertToDotOperand(mlir::MLIRContext *context)
|
||||
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
@@ -729,11 +726,12 @@ public:
|
||||
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
|
||||
// order
|
||||
ArrayRef<unsigned> order;
|
||||
if(auto srcBlockedLayout =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
|
||||
if (auto srcBlockedLayout =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
|
||||
order = srcBlockedLayout.getOrder();
|
||||
else if(auto srcSharedLayout =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>())
|
||||
else if (auto srcSharedLayout =
|
||||
srcType.getEncoding()
|
||||
.dyn_cast<triton::gpu::SharedEncodingAttr>())
|
||||
order = srcSharedLayout.getOrder();
|
||||
else
|
||||
return failure();
|
||||
@@ -742,20 +740,18 @@ public:
|
||||
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
if (!dstDotOperandLayout)
|
||||
return failure();
|
||||
unsigned opIdx = dstDotOperandLayout.getOpIdx();
|
||||
if(!dstDotOperandLayout.getIsMMAv1Row())
|
||||
if (!dstDotOperandLayout.getIsMMAv1Row())
|
||||
return failure();
|
||||
bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
if((order[0] == 1 && isMMAv1Row) ||
|
||||
(order[0] == 0 && !isMMAv1Row))
|
||||
bool isMMAv1Row =
|
||||
dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row))
|
||||
return failure();
|
||||
auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
|
||||
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
|
||||
op->getContext(), dstDotOperandLayout.getOpIdx(), dstDotOperandLayout.getParent(),
|
||||
newIsRow);
|
||||
op->getContext(), dstDotOperandLayout.getOpIdx(),
|
||||
dstDotOperandLayout.getParent(), newIsRow);
|
||||
auto newDstType = RankedTensorType::get(
|
||||
dstType.getShape(),
|
||||
dstType.getElementType(), newDstEncoding);
|
||||
dstType.getShape(), dstType.getElementType(), newDstEncoding);
|
||||
auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
op->getLoc(), newDstType, cvt.getOperand());
|
||||
rewriter.replaceOp(op, newCvt.getResult());
|
||||
@@ -763,7 +759,6 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
class BlockedToMMA : public mlir::RewritePattern {
|
||||
int computeCapability;
|
||||
|
||||
@@ -777,7 +772,7 @@ public:
|
||||
int version, int numWarps) {
|
||||
switch (version) {
|
||||
case 1:
|
||||
return warpsPerTileV1(dotOp, shape, numWarps);
|
||||
return warpsPerTileV1(shape, numWarps);
|
||||
case 2:
|
||||
return warpsPerTileV2(dotOp, shape, numWarps);
|
||||
default:
|
||||
@@ -821,27 +816,31 @@ public:
|
||||
Value b = dotOp.b();
|
||||
auto oldAType = a.getType().cast<RankedTensorType>();
|
||||
auto oldBType = b.getType().cast<RankedTensorType>();
|
||||
auto oldAOrder = oldAType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
|
||||
.getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
|
||||
auto oldBOrder = oldBType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
|
||||
.getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
|
||||
auto oldAOrder = oldAType.getEncoding()
|
||||
.cast<triton::gpu::DotOperandEncodingAttr>()
|
||||
.getParent()
|
||||
.cast<triton::gpu::BlockedEncodingAttr>()
|
||||
.getOrder();
|
||||
auto oldBOrder = oldBType.getEncoding()
|
||||
.cast<triton::gpu::DotOperandEncodingAttr>()
|
||||
.getParent()
|
||||
.cast<triton::gpu::BlockedEncodingAttr>()
|
||||
.getOrder();
|
||||
Attribute isMMAv1RowA;
|
||||
Attribute isMMAv1RowB;
|
||||
if(version == 1){
|
||||
if (version == 1) {
|
||||
isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
|
||||
isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
|
||||
}
|
||||
|
||||
auto newAType = RankedTensorType::get(
|
||||
oldAType.getShape(), oldAType.getElementType(),
|
||||
triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
|
||||
newRetType.getEncoding(),
|
||||
isMMAv1RowA));
|
||||
triton::gpu::DotOperandEncodingAttr::get(
|
||||
oldAType.getContext(), 0, newRetType.getEncoding(), isMMAv1RowA));
|
||||
auto newBType = RankedTensorType::get(
|
||||
oldBType.getShape(), oldBType.getElementType(),
|
||||
triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
|
||||
newRetType.getEncoding(),
|
||||
isMMAv1RowB));
|
||||
triton::gpu::DotOperandEncodingAttr::get(
|
||||
oldBType.getContext(), 1, newRetType.getEncoding(), isMMAv1RowB));
|
||||
|
||||
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
|
||||
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
|
||||
@@ -857,9 +856,8 @@ public:
|
||||
class FixupLoop : public mlir::RewritePattern {
|
||||
|
||||
public:
|
||||
FixupLoop(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(scf::ForOp::getOperationName(), 2,
|
||||
context) {}
|
||||
explicit FixupLoop(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(scf::ForOp::getOperationName(), 2, context) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
@@ -869,15 +867,15 @@ public:
|
||||
// Rewrite init argument
|
||||
SmallVector<Value, 4> newInitArgs = forOp.getInitArgs();
|
||||
bool shouldRematerialize = false;
|
||||
for(size_t i = 0; i < newInitArgs.size(); i++){
|
||||
for (size_t i = 0; i < newInitArgs.size(); i++) {
|
||||
auto initArg = newInitArgs[i];
|
||||
auto regionArg = forOp.getRegionIterArgs()[i];
|
||||
if(newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType()){
|
||||
if (newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType()) {
|
||||
shouldRematerialize = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if(!shouldRematerialize)
|
||||
if (!shouldRematerialize)
|
||||
return failure();
|
||||
|
||||
scf::ForOp newForOp = rewriter.create<scf::ForOp>(
|
||||
@@ -894,8 +892,6 @@ public:
|
||||
}
|
||||
rewriter.replaceOp(forOp, newForOp.getResults());
|
||||
return success();
|
||||
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
|
Reference in New Issue
Block a user