[Triton-MLIR][BACKEND] Some code cleanup on the backend (#978)

Yan Chunwei
2022-12-12 17:46:16 +08:00
committed by GitHub
parent e5cfa0f633
commit 0cfe909df8
4 changed files with 97 additions and 137 deletions

View File

@@ -105,11 +105,9 @@ struct DotOpMmaV1ConversionHelper {
   }

   // Get the number of fp16x2 elements for $a.
-  // \param shapeTransed: the shape or reordered shape if transpose needed.
+  // \param shapeTransed: A's shape or reordered shape if transpose needed.
   // \param orderTransed: the order or reordered order if transpose needed.
-  unsigned getNumM(ArrayRef<int64_t> shapeTransed,
-                   ArrayRef<unsigned> orderTransed) const {
-    bool isARow = orderTransed[0] != 0;
+  unsigned getNumM(ArrayRef<int64_t> shapeTransed, bool isARow) const {
     AParam param(isARow);

     unsigned numM = param.rep[0] * shapeTransed[0] / (param.spw[0] * wpt[0]);
@@ -117,11 +115,9 @@ struct DotOpMmaV1ConversionHelper {
   }

   // Get the number of fp16x2 elements for $b.
-  // \param shapeTransed: the shape or reordered shape if transpose needed.
+  // \param shapeTransed: B's shape or reordered shape if transpose needed.
   // \param orderTransed: the order or reordered order if transpose needed.
-  unsigned getNumN(ArrayRef<int64_t> shapeTransed,
-                   ArrayRef<unsigned> orderTransed) const {
-    bool isBRow = orderTransed[0] != 0;
+  unsigned getNumN(ArrayRef<int64_t> shapeTransed, bool isBRow) const {
     BParam param(isBRow);

     unsigned numN = param.rep[1] * shapeTransed[1] / (param.spw[1] * wpt[1]);
@@ -130,7 +126,7 @@ struct DotOpMmaV1ConversionHelper {
   int numElemsPerThreadA(ArrayRef<int64_t> shapeTransed,
                          ArrayRef<unsigned> orderTransed) const {
-    int numM = getNumM(shapeTransed, orderTransed);
+    int numM = getNumM(shapeTransed, orderTransed[0] == 1);
     int NK = shapeTransed[1];

     // NOTE: We couldn't get the vec from the shared layout.
@@ -143,7 +139,7 @@ struct DotOpMmaV1ConversionHelper {
   int numElemsPerThreadB(ArrayRef<int64_t> shapeTransed,
                          ArrayRef<unsigned> orderTransed) const {
-    unsigned numN = getNumN(shapeTransed, orderTransed);
+    unsigned numN = getNumN(shapeTransed, orderTransed[0] == 1);
     int NK = shapeTransed[0];

     // NOTE: We couldn't get the vec from the shared layout.
     // int vecB = sharedLayout.getVec();
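A note on the new bool parameter: for a rank-2 tensor the order array is a permutation of {0, 1}, so the old internal derivation (orderTransed[0] != 0) and the new call-site expression (orderTransed[0] == 1) always agree. A self-contained sketch of that equivalence, with illustrative order values not taken from this commit:

#include <array>
#include <cassert>

int main() {
  std::array<unsigned, 2> rowMajorOrder{1, 0}; // contiguous along dim 1
  std::array<unsigned, 2> colMajorOrder{0, 1}; // contiguous along dim 0
  // Old derivation inside getNumM/getNumN vs. the new call-site check:
  assert((rowMajorOrder[0] != 0) == (rowMajorOrder[0] == 1)); // both true
  assert((colMajorOrder[0] != 0) == (colMajorOrder[0] == 1)); // both false
  return 0;
}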
@@ -1451,7 +1447,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
    }
  };

-  unsigned numM = getNumM(shape, order);
+  unsigned numM = getNumM(shape, order[0] == 1);
  for (unsigned k = 0; k < NK; k += 4)
    for (unsigned m = 0; m < numM / 2; ++m)
      loadA(m, k);
@@ -1563,7 +1559,7 @@ Value DotOpMmaV1ConversionHelper::loadB(
    }
  };

-  unsigned numN = getNumN(shape, order);
+  unsigned numN = getNumN(shape, order[0] == 1);
  for (unsigned k = 0; k < NK; k += 4)
    for (unsigned n = 0; n < numN / 2; ++n) {
      if (!hbs.count({n, k}))

View File

@@ -739,7 +739,6 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
   auto tensorTy = resType.cast<RankedTensorType>();
   if (tensorTy.getEncoding().isa<BlockedEncodingAttr>() ||
       tensorTy.getEncoding().isa<SliceEncodingAttr>()) {
-    auto tensorTy = resType.cast<RankedTensorType>();
     auto srcType = typeConverter->convertType(elemType);
     auto llSrc = bitcast(constVal, srcType);
     size_t elemsPerThread = getElemsPerThread(tensorTy);
@@ -981,7 +980,7 @@ struct LoadOpConversion
         size_t size = width / valueElemNbits;
         auto vecTy = LLVM::getFixedVectorType(valueElemTy, size);
-        Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
+        Value v = undef(vecTy);
         for (size_t s = 0; s < size; ++s) {
           Value falseVal = otherElems[vecStart + ii * size + s];
           Value sVal = createIndexAttrConstant(
@@ -1118,7 +1117,7 @@ struct StoreOpConversion
       SmallVector<std::pair<Value, std::string>> asmArgs;
       for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
         // llWord is a width-len composition
-        Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
+        Value llWord = undef(wordTy);
         // Insert each value element to the composition
         for (size_t elemIdx = 0; elemIdx < wordNElems; ++elemIdx) {
           const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx;
@@ -1129,10 +1128,7 @@ struct StoreOpConversion
           elem = bitcast(elem, valueElemTy);

           Type u32Ty = typeConverter->convertType(type::u32Ty(ctx));
-          llWord =
-              insert_element(wordTy, llWord, elem,
-                             rewriter.create<LLVM::ConstantOp>(
-                                 loc, u32Ty, IntegerAttr::get(u32Ty, elemIdx)));
+          llWord = insert_element(wordTy, llWord, elem, i32_val(elemIdx));
         }
         llWord = bitcast(llWord, valArgTy);
         std::string constraint =
@@ -3570,43 +3566,27 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
   auto ATensorTy = A.getType().cast<RankedTensorType>();
   auto BTensorTy = B.getType().cast<RankedTensorType>();
   auto DTensorTy = D.getType().cast<RankedTensorType>();
-  SmallVector<int> AShape(ATensorTy.getShape().begin(),
-                          ATensorTy.getShape().end());
-  SmallVector<int> BShape(BTensorTy.getShape().begin(),
-                          BTensorTy.getShape().end());
+  auto AShape = ATensorTy.getShape();
+  auto BShape = BTensorTy.getShape();
   auto DShape = DTensorTy.getShape();
   auto wpt = mmaLayout.getWarpsPerCTA();

   bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
   bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
-  bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
-  bool isBVec4 = isBRow && BShape[isBRow] <= 16;
-  // TODO[Superjomn]: ld.v4 is not supported.
-  isAVec4 = true;
-  isBVec4 = true;
-  int packSize0 = (isARow || isAVec4) ? 1 : 2;
-  int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
-  SmallVector<int> fpw({2, 2, 1});
-  SmallVector<int> rep({2 * packSize0, 2 * packSize1, 1});
-  SmallVector<int> spw({fpw[0] * 4 * rep[0], fpw[1] * 4 * rep[1], 1});
-  Value loadedA = adaptor.a();
-  Value loadedB = adaptor.b();
-  Value loadedC = adaptor.c();

   DotOpMmaV1ConversionHelper helper(mmaLayout);

-  unsigned numM = rep[0] * DShape[0] / (spw[0] * wpt[0]);
-  unsigned numN = rep[1] * DShape[1] / (spw[1] * wpt[1]);
+  unsigned numM = helper.getNumM(AShape, isARow);
+  unsigned numN = helper.getNumN(BShape, isBRow);
   unsigned NK = AShape[1];

-  auto has = helper.extractLoadedOperand(loadedA, NK, rewriter);
-  auto hbs = helper.extractLoadedOperand(loadedB, NK, rewriter);
+  auto has = helper.extractLoadedOperand(adaptor.a(), NK, rewriter);
+  auto hbs = helper.extractLoadedOperand(adaptor.b(), NK, rewriter);

   // Initialize accumulators with external values, the acc holds the accumulator
   // value that is shared between the MMA instructions inside a DotOp, we can
   // call the order of the values the accumulator-internal order.
-  SmallVector<Value> acc = getElementsFromStruct(loc, loadedC, rewriter);
+  SmallVector<Value> acc = getElementsFromStruct(loc, adaptor.c(), rewriter);
   size_t resSize = acc.size();

   // The resVals holds the final result of the DotOp.
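The deleted block pinned isAVec4/isBVec4 to true, which forces packSize0 = 1, hence rep[0] = 2 and spw[0] = fpw[0] * 4 * rep[0] = 16; those are exactly the constants getNumM now reproduces through AParam. A worked example, where AShape[0] = 128 and wpt[0] = 4 are assumed values for illustration only:

// numM = rep[0] * AShape[0] / (spw[0] * wpt[0])
//      = 2 * 128 / (16 * 4)
//      = 4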
@@ -3719,38 +3699,19 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
   auto bShape = bTensorTy.getShape();
   auto cShape = cTensorTy.getShape();

-  ValueTable has, hbs;
-  int mShapePerCTA{-1}, nShapePerCTA{-1};
-  int mSizePerThread{-1}, nSizePerThread{-1};
-  ArrayRef<unsigned> aOrder, bOrder;
-  Value llA, llB;
   BlockedEncodingAttr dLayout =
       dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
   auto order = dLayout.getOrder();
   auto cc = getElementsFromStruct(loc, adaptor.c(), rewriter);

   DotOpFMAConversionHelper helper(dLayout);
-  if (auto aDotOpLayout =
-          aTensorTy.getEncoding()
-              .dyn_cast<DotOperandEncodingAttr>()) { // get input from
-                                                     // convert_layout
-    auto bDotOpLayout =
-        bTensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>();
-    auto aLayout = aDotOpLayout.getParent().cast<BlockedEncodingAttr>();
-    auto bLayout = bDotOpLayout.getParent().cast<BlockedEncodingAttr>();
-    assert(bLayout);
-    llA = adaptor.a();
-    llB = adaptor.b();
-  } else if (auto aLayout =
-                 aTensorTy.getEncoding()
-                     .dyn_cast<SharedEncodingAttr>()) { // load input from smem
-    auto bLayout = bTensorTy.getEncoding().dyn_cast<SharedEncodingAttr>();
-    assert(bLayout);
-    Value thread = getThreadId(rewriter, loc);
-    llA = helper.loadA(A, adaptor.a(), dLayout, thread, loc, rewriter);
-    llB = helper.loadB(B, adaptor.b(), dLayout, thread, loc, rewriter);
-  }
+  auto aDotOpLayout = aTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
+  auto bDotOpLayout = bTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
+  auto aLayout = aDotOpLayout.getParent().cast<BlockedEncodingAttr>();
+  auto bLayout = bDotOpLayout.getParent().cast<BlockedEncodingAttr>();
+  Value llA = adaptor.a();
+  Value llB = adaptor.b();

   auto sizePerThread = getSizePerThread(dLayout);
   auto shapePerCTA = getShapePerCTA(dLayout);
@@ -3759,17 +3720,19 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
   int M = aShape[0];
   int N = bShape[1];

-  mShapePerCTA = order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
-  mSizePerThread =
+  int mShapePerCTA =
+      order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
+  int mSizePerThread =
       order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]];
-  nShapePerCTA = order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
-  nSizePerThread =
+  int nShapePerCTA =
+      order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
+  int nSizePerThread =
       order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]];

-  has = helper.getValueTableFromStruct(llA, K, M, mShapePerCTA, mSizePerThread,
-                                       rewriter, loc);
-  hbs = helper.getValueTableFromStruct(llB, K, N, nShapePerCTA, nSizePerThread,
-                                       rewriter, loc);
+  auto has = helper.getValueTableFromStruct(llA, K, M, mShapePerCTA,
+                                            mSizePerThread, rewriter, loc);
+  auto hbs = helper.getValueTableFromStruct(llB, K, N, nShapePerCTA,
+                                            nSizePerThread, rewriter, loc);

   SmallVector<Value> ret = cc;
   for (unsigned k = 0; k < K; k++) {
@@ -3780,7 +3743,6 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
       for (unsigned nn = 0; nn < nSizePerThread; ++nn) {
         ret[z] = rewriter.create<LLVM::FMulAddOp>(loc, has[{m + mm, k}],
                                                   hbs[{n + nn, k}], ret[z]);
-
         ++z;
       }
     }
@@ -4310,9 +4272,10 @@ struct ExpOpConversionApprox
     // For FP64 input, call __nv_expf for higher-precision calculation
     if (elemTy.getIntOrFloatBitWidth() == 64)
       return {};

     const double log2e = 1.4426950408889634;
-    Value prod =
-        rewriter.create<LLVM::FMulOp>(loc, f32_ty, operands[0], f32_val(log2e));
+    Value prod = fmul(f32_ty, operands[0], f32_val(log2e));

     PTXBuilder ptxBuilder;
     auto &exp2 = ptxBuilder.create<PTXInstr>("ex2")->o("approx").o("f32");
     auto output = ptxBuilder.newOperand("=f");
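For reference, the math this path relies on: ex2.approx.f32 computes 2^x, and exp(x) = 2^(x * log2(e)) with log2(e) = 1.4426950408889634, so scaling the input by log2e before the ex2 instruction yields exp(x). The new fmul shortcut expresses that scaling in a single line.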

View File

@@ -31,6 +31,7 @@
 #include <numeric>

 // Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
+// Operators
 #define inttoptr(...) rewriter.create<LLVM::IntToPtrOp>(loc, __VA_ARGS__)
 #define ptrtoint(...) rewriter.create<LLVM::PtrToIntOp>(loc, __VA_ARGS__)
 #define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
@@ -40,6 +41,7 @@
 #define sub(...) rewriter.create<LLVM::SubOp>(loc, __VA_ARGS__)
 #define fadd(...) rewriter.create<LLVM::FAddOp>(loc, __VA_ARGS__)
 #define mul(...) rewriter.create<LLVM::MulOp>(loc, __VA_ARGS__)
+#define fmul(...) rewriter.create<LLVM::FMulOp>(loc, __VA_ARGS__)
 #define smax(...) rewriter.create<LLVM::SMaxOp>(loc, __VA_ARGS__)
 #define umax(...) rewriter.create<LLVM::UMaxOp>(loc, __VA_ARGS__)
 #define fmax(...) rewriter.create<LLVM::MaxNumOp>(loc, __VA_ARGS__)
@@ -90,6 +92,8 @@
 #define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
 #define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
 #define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
+
+// Types
 #define i32_ty rewriter.getIntegerType(32)
 #define ui32_ty rewriter.getIntegerType(32, false)
 #define f16_ty rewriter.getF16Type()
@@ -102,8 +106,9 @@
 #define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__)
 #define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
 #define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__)
+#define array_ty(elemTy, count) LLVM::LLVMArrayType::get(elemTy, count)

-// Creator for constant
+// Constants
 #define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__)
 #define int_val(width, val) \
   LLVM::createLLVMIntegerConstant(rewriter, loc, width, val)
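All of these shortcuts assume a rewriter and a loc are in scope at the expansion site. A minimal usage sketch; the helper below is hypothetical and not part of this commit, it only shows how the macros read in practice:

// Hypothetical helper, for illustration only. fmul/fadd expand to
// rewriter.create<LLVM::FMulOp/FAddOp>(loc, ...), and f32_val builds an
// f32 constant via LLVM::createConstantF32.
static Value scaleAndAdd(Value x, Value y, float factor,
                         ConversionPatternRewriter &rewriter, Location loc) {
  Value c = f32_val(factor);         // constant `factor` as f32
  Value scaled = fmul(f32_ty, x, c); // x * factor
  return fadd(f32_ty, scaled, y);    // x * factor + y
}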

View File

@@ -36,7 +36,7 @@ namespace {
 class DecomposeDotOperand : public mlir::RewritePattern {
 public:
-  DecomposeDotOperand(mlir::MLIRContext *context)
+  explicit DecomposeDotOperand(mlir::MLIRContext *context)
       : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                              1, context) {}
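Marking these single-argument constructors explicit (the same change recurs for the other pattern classes below) rules out silent implicit conversions from a raw MLIRContext*. A self-contained sketch of what the keyword forbids, using hypothetical stand-in types:

struct Context {};
struct Pattern {
  explicit Pattern(Context *ctx) {}
};
void run(const Pattern &p) {}

int main() {
  Context ctx;
  // run(&ctx);       // ill-formed once the constructor is explicit
  run(Pattern(&ctx)); // the conversion must now be spelled out
  return 0;
}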
@@ -84,7 +84,7 @@ public:
 // IIUC they are therefore not handled by DRR right now
 class SimplifyConversion : public mlir::RewritePattern {
 public:
-  SimplifyConversion(mlir::MLIRContext *context)
+  explicit SimplifyConversion(mlir::MLIRContext *context)
       : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                              4, context) {}
@@ -219,8 +219,8 @@ public:
 //
 // -----------------------------------------------------------------------------

-static LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
-                                    Attribute &ret) {
+LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
+                             Attribute &ret) {
   ret = targetEncoding;
   if (auto expand_dims = dyn_cast<triton::ExpandDimsOp>(op)) {
     ret = triton::gpu::SliceEncodingAttr::get(
@@ -246,7 +246,7 @@ inline bool expensive_to_remat(Operation *op) {
   if (isa<scf::YieldOp, scf::ForOp>(op))
     return true;
   return false;
-};
+}

 Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
                               BlockAndValueMapping &mapping) {
@@ -276,7 +276,7 @@ Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
 // are reachable from it without passing through any memory operation.
 class RematerializeBackward : public mlir::RewritePattern {
 public:
-  RematerializeBackward(mlir::MLIRContext *context)
+  explicit RematerializeBackward(mlir::MLIRContext *context)
       : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                              2, context) {}
@@ -303,7 +303,7 @@ public:
     SetVector<Attribute> layout;
     llvm::MapVector<Value, Attribute> toConvert;
     std::vector<std::pair<Operation *, Attribute>> queue;
-    queue.push_back({cvt, targetType.getEncoding()});
+    queue.emplace_back(cvt, targetType.getEncoding());
     int numCvts = 1;
     while (!queue.empty()) {
       Operation *currOp;
@@ -341,7 +341,7 @@ public:
         continue;
       // we add one expensive conversion for the current operand
       numCvts += 1;
-      queue.push_back({opArgI, newEncoding});
+      queue.emplace_back(opArgI, newEncoding);
     }
   }
   // if rematerialization would add more conversions than it removes
@@ -351,8 +351,8 @@ public:
     SmallVector<Value, 4> sortedValues;
     SetVector<Operation *> tmp;
-    for (auto it = toConvert.begin(); it != toConvert.end(); ++it) {
-      Value v = it->first;
+    for (auto &item : toConvert) {
+      Value v = item.first;
       if (v.getDefiningOp())
         tmp.insert(v.getDefiningOp());
       else
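Both container cleanups in this file are behavior-preserving: emplace_back constructs the pair directly inside the vector instead of building a temporary for push_back, and the range-for replaces manual iterator plumbing. A self-contained sketch with hypothetical element types:

#include <utility>
#include <vector>

int main() {
  std::vector<std::pair<int, const char *>> queue;
  queue.push_back({1, "cvt"});     // creates a temporary pair, then moves it
  queue.emplace_back(2, "expand"); // constructs the pair in place
  for (auto &item : queue)         // range-for instead of explicit iterators
    (void)item.first;
  return 0;
}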
@@ -393,7 +393,7 @@ public:
 class MoveConvertOutOfLoop : public mlir::RewritePattern {
 public:
-  MoveConvertOutOfLoop(mlir::MLIRContext *context)
+  explicit MoveConvertOutOfLoop(mlir::MLIRContext *context)
       : mlir::RewritePattern(scf::ForOp::getOperationName(), 1, context) {}

   SmallVector<Value, 4>
@@ -406,7 +406,7 @@ public:
       newInitArgs[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
           newInitArgs[i].getLoc(), newType, newInitArgs[i]);
     // Clone for loop
-    scf::ForOp newForOp = rewriter.create<scf::ForOp>(
+    auto newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
         forOp.getStep(), newInitArgs);
     newForOp->moveBefore(forOp);
@@ -455,7 +455,7 @@ public:
                   mlir::PatternRewriter &rewriter) const override {
     auto forOp = cast<scf::ForOp>(op);
     auto iterArgs = forOp.getRegionIterArgs();
-    for (auto iterArg : llvm::enumerate(iterArgs)) {
+    for (const auto &iterArg : llvm::enumerate(iterArgs)) {
       // if (iterArg.index() != 1)
       //   continue;
       // skip non-tensor types
@@ -517,7 +517,7 @@ public:
 class RematerializeForward : public mlir::RewritePattern {
 public:
-  RematerializeForward(mlir::MLIRContext *context)
+  explicit RematerializeForward(mlir::MLIRContext *context)
       : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                              2, context) {}
@@ -584,7 +584,7 @@ public:
 //
 // -----------------------------------------------------------------------------
 namespace {
-static int computeCapabilityToMMAVersion(int computeCapability) {
+int computeCapabilityToMMAVersion(int computeCapability) {
   if (computeCapability < 80) {
     return 1;
   } else if (computeCapability < 90) {
@@ -595,9 +595,7 @@ static int computeCapabilityToMMAVersion(int computeCapability) {
   }
 }

-static SmallVector<int64_t, 2>
-mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
-                         int numWarps) {
+SmallVector<int64_t, 2> mmaVersionToShapePerWarp(int version) {
   if (version == 1)
     return {16, 16};
   else if (version == 2)
@@ -608,12 +606,11 @@ mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
   }
 }

-SmallVector<unsigned, 2> warpsPerTileV1(triton::DotOp dotOp,
-                                        const ArrayRef<int64_t> shape,
+SmallVector<unsigned, 2> warpsPerTileV1(const ArrayRef<int64_t> shape,
                                         int numWarps) {
   SmallVector<unsigned, 2> ret = {1, 1};
   SmallVector<int64_t, 2> shapePerWarp =
-      mmaVersionToShapePerWarp(1, shape, numWarps);
+      mmaVersionToShapePerWarp(1 /*version*/);
   bool changed = false;
   do {
     changed = false;
@@ -669,7 +666,7 @@ SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
 class OptimizeBlockedToShared : public mlir::RewritePattern {
 public:
-  OptimizeBlockedToShared(mlir::MLIRContext *context)
+  explicit OptimizeBlockedToShared(mlir::MLIRContext *context)
       : RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
                        context) {}
@@ -717,7 +714,7 @@ public:
 class OptimizeConvertToDotOperand : public mlir::RewritePattern {
 public:
-  OptimizeConvertToDotOperand(mlir::MLIRContext *context)
+  explicit OptimizeConvertToDotOperand(mlir::MLIRContext *context)
       : RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
                        context) {}
@@ -729,11 +726,12 @@ public:
     auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
     // order
     ArrayRef<unsigned> order;
-    if(auto srcBlockedLayout =
+    if (auto srcBlockedLayout =
           srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
       order = srcBlockedLayout.getOrder();
-    else if(auto srcSharedLayout =
-          srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>())
+    else if (auto srcSharedLayout =
+                 srcType.getEncoding()
+                     .dyn_cast<triton::gpu::SharedEncodingAttr>())
       order = srcSharedLayout.getOrder();
     else
       return failure();
@@ -742,20 +740,18 @@ public:
         dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
     if (!dstDotOperandLayout)
       return failure();
-    unsigned opIdx = dstDotOperandLayout.getOpIdx();
-    if(!dstDotOperandLayout.getIsMMAv1Row())
+    if (!dstDotOperandLayout.getIsMMAv1Row())
       return failure();
-    bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
-    if((order[0] == 1 && isMMAv1Row) ||
-       (order[0] == 0 && !isMMAv1Row))
+    bool isMMAv1Row =
+        dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
+    if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row))
       return failure();
     auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
     auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
-        op->getContext(), dstDotOperandLayout.getOpIdx(), dstDotOperandLayout.getParent(),
-        newIsRow);
+        op->getContext(), dstDotOperandLayout.getOpIdx(),
+        dstDotOperandLayout.getParent(), newIsRow);
     auto newDstType = RankedTensorType::get(
-        dstType.getShape(),
-        dstType.getElementType(), newDstEncoding);
+        dstType.getShape(), dstType.getElementType(), newDstEncoding);
     auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
         op->getLoc(), newDstType, cvt.getOperand());
     rewriter.replaceOp(op, newCvt.getResult());
@@ -763,7 +759,6 @@ public:
   }
 };

-
 class BlockedToMMA : public mlir::RewritePattern {
   int computeCapability;
@@ -777,7 +772,7 @@ public:
                                                int version, int numWarps) {
     switch (version) {
     case 1:
-      return warpsPerTileV1(dotOp, shape, numWarps);
+      return warpsPerTileV1(shape, numWarps);
     case 2:
       return warpsPerTileV2(dotOp, shape, numWarps);
     default:
@@ -821,27 +816,31 @@ public:
     Value b = dotOp.b();
     auto oldAType = a.getType().cast<RankedTensorType>();
     auto oldBType = b.getType().cast<RankedTensorType>();
-    auto oldAOrder = oldAType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
-        .getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
-    auto oldBOrder = oldBType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
-        .getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
+    auto oldAOrder = oldAType.getEncoding()
+                         .cast<triton::gpu::DotOperandEncodingAttr>()
+                         .getParent()
+                         .cast<triton::gpu::BlockedEncodingAttr>()
+                         .getOrder();
+    auto oldBOrder = oldBType.getEncoding()
+                         .cast<triton::gpu::DotOperandEncodingAttr>()
+                         .getParent()
+                         .cast<triton::gpu::BlockedEncodingAttr>()
+                         .getOrder();

     Attribute isMMAv1RowA;
     Attribute isMMAv1RowB;
-    if(version == 1){
+    if (version == 1) {
       isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
       isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
     }

     auto newAType = RankedTensorType::get(
         oldAType.getShape(), oldAType.getElementType(),
-        triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
-                                                 newRetType.getEncoding(),
-                                                 isMMAv1RowA));
+        triton::gpu::DotOperandEncodingAttr::get(
+            oldAType.getContext(), 0, newRetType.getEncoding(), isMMAv1RowA));
     auto newBType = RankedTensorType::get(
         oldBType.getShape(), oldBType.getElementType(),
-        triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
-                                                 newRetType.getEncoding(),
-                                                 isMMAv1RowB));
+        triton::gpu::DotOperandEncodingAttr::get(
+            oldBType.getContext(), 1, newRetType.getEncoding(), isMMAv1RowB));

     a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
     b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
@@ -857,9 +856,8 @@ public:
 class FixupLoop : public mlir::RewritePattern {
 public:
-  FixupLoop(mlir::MLIRContext *context)
-      : mlir::RewritePattern(scf::ForOp::getOperationName(), 2,
-                             context) {}
+  explicit FixupLoop(mlir::MLIRContext *context)
+      : mlir::RewritePattern(scf::ForOp::getOperationName(), 2, context) {}

   mlir::LogicalResult
   matchAndRewrite(mlir::Operation *op,
@@ -869,17 +867,17 @@ public:
     // Rewrite init argument
     SmallVector<Value, 4> newInitArgs = forOp.getInitArgs();
     bool shouldRematerialize = false;
-    for(size_t i = 0; i < newInitArgs.size(); i++){
+    for (size_t i = 0; i < newInitArgs.size(); i++) {
       auto initArg = newInitArgs[i];
       auto regionArg = forOp.getRegionIterArgs()[i];
-      if(newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType()){
+      if (newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType()) {
         shouldRematerialize = true;
         break;
       }
     }
-    if(!shouldRematerialize)
+    if (!shouldRematerialize)
       return failure();

     scf::ForOp newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
         forOp.getStep(), newInitArgs);
@@ -894,8 +892,6 @@ public:
     }
     rewriter.replaceOp(forOp, newForOp.getResults());
     return success();
   }
 };