[Triton-MLIR][BACKEND] some code clean on the backend (#978)
@@ -105,11 +105,9 @@ struct DotOpMmaV1ConversionHelper {
   }

   // Get the number of fp16x2 elements for $a.
-  // \param shapeTransed: the shape or reordered shape if transpose needed.
+  // \param shapeTransed: A's shape or reordered shape if transpose needed.
   // \param orderTransed: the order or reordered order if transpose needed.
-  unsigned getNumM(ArrayRef<int64_t> shapeTransed,
-                   ArrayRef<unsigned> orderTransed) const {
-    bool isARow = orderTransed[0] != 0;
+  unsigned getNumM(ArrayRef<int64_t> shapeTransed, bool isARow) const {
     AParam param(isARow);

     unsigned numM = param.rep[0] * shapeTransed[0] / (param.spw[0] * wpt[0]);
@@ -117,11 +115,9 @@ struct DotOpMmaV1ConversionHelper {
   }

   // Get the number of fp16x2 elements for $b.
-  // \param shapeTransed: the shape or reordered shape if transpose needed.
+  // \param shapeTransed: B's shape or reordered shape if transpose needed.
   // \param orderTransed: the order or reordered order if transpose needed.
-  unsigned getNumN(ArrayRef<int64_t> shapeTransed,
-                   ArrayRef<unsigned> orderTransed) const {
-    bool isBRow = orderTransed[0] != 0;
+  unsigned getNumN(ArrayRef<int64_t> shapeTransed, bool isBRow) const {
     BParam param(isBRow);

     unsigned numN = param.rep[1] * shapeTransed[1] / (param.spw[1] * wpt[1]);
@@ -130,7 +126,7 @@ struct DotOpMmaV1ConversionHelper {

   int numElemsPerThreadA(ArrayRef<int64_t> shapeTransed,
                          ArrayRef<unsigned> orderTransed) const {
-    int numM = getNumM(shapeTransed, orderTransed);
+    int numM = getNumM(shapeTransed, orderTransed[0] == 1);
     int NK = shapeTransed[1];

     // NOTE: We couldn't get the vec from the shared layout.
@@ -143,7 +139,7 @@ struct DotOpMmaV1ConversionHelper {

   int numElemsPerThreadB(ArrayRef<int64_t> shapeTransed,
                          ArrayRef<unsigned> orderTransed) const {
-    unsigned numN = getNumN(shapeTransed, orderTransed);
+    unsigned numN = getNumN(shapeTransed, orderTransed[0] == 1);
     int NK = shapeTransed[0];
     // NOTE: We couldn't get the vec from the shared layout.
     // int vecB = sharedLayout.getVec();
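
Note (background, not part of the patch): in Triton's layout encodings, `order` lists dimensions from fastest- to slowest-varying, so for a 2-D operand `order[0] == 1` means the innermost dimension is the column index, i.e. the operand is row-major. A minimal sketch of the convention the callers above now encode as a bool, with hypothetical values:

    // Hypothetical illustration of the order convention.
    #include <array>
    #include <cassert>

    int main() {
      std::array<unsigned, 2> rowMajor{1, 0}; // dim 1 varies fastest
      std::array<unsigned, 2> colMajor{0, 1}; // dim 0 varies fastest
      assert(rowMajor[0] == 1);  // what getNumM now receives as isARow == true
      assert(colMajor[0] == 0);  // isARow == false
      return 0;
    }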
@@ -1451,7 +1447,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
    }
  };

-  unsigned numM = getNumM(shape, order);
+  unsigned numM = getNumM(shape, order[0] == 1);
  for (unsigned k = 0; k < NK; k += 4)
    for (unsigned m = 0; m < numM / 2; ++m)
      loadA(m, k);
@@ -1563,7 +1559,7 @@ Value DotOpMmaV1ConversionHelper::loadB(
    }
  };

-  unsigned numN = getNumN(shape, order);
+  unsigned numN = getNumN(shape, order[0] == 1);
  for (unsigned k = 0; k < NK; k += 4)
    for (unsigned n = 0; n < numN / 2; ++n) {
      if (!hbs.count({n, k}))
@@ -739,7 +739,6 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
  auto tensorTy = resType.cast<RankedTensorType>();
  if (tensorTy.getEncoding().isa<BlockedEncodingAttr>() ||
      tensorTy.getEncoding().isa<SliceEncodingAttr>()) {
-    auto tensorTy = resType.cast<RankedTensorType>();
    auto srcType = typeConverter->convertType(elemType);
    auto llSrc = bitcast(constVal, srcType);
    size_t elemsPerThread = getElemsPerThread(tensorTy);
@@ -981,7 +980,7 @@ struct LoadOpConversion
      size_t size = width / valueElemNbits;

      auto vecTy = LLVM::getFixedVectorType(valueElemTy, size);
-      Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
+      Value v = undef(vecTy);
      for (size_t s = 0; s < size; ++s) {
        Value falseVal = otherElems[vecStart + ii * size + s];
        Value sVal = createIndexAttrConstant(
@@ -1118,7 +1117,7 @@ struct StoreOpConversion
      SmallVector<std::pair<Value, std::string>> asmArgs;
      for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
        // llWord is a width-len composition
-        Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
+        Value llWord = undef(wordTy);
        // Insert each value element to the composition
        for (size_t elemIdx = 0; elemIdx < wordNElems; ++elemIdx) {
          const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx;
@@ -1129,10 +1128,7 @@ struct StoreOpConversion
          elem = bitcast(elem, valueElemTy);

          Type u32Ty = typeConverter->convertType(type::u32Ty(ctx));
-          llWord =
-              insert_element(wordTy, llWord, elem,
-                             rewriter.create<LLVM::ConstantOp>(
-                                 loc, u32Ty, IntegerAttr::get(u32Ty, elemIdx)));
+          llWord = insert_element(wordTy, llWord, elem, i32_val(elemIdx));
        }
        llWord = bitcast(llWord, valArgTy);
        std::string constraint =
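
Note: `undef(...)` and `i32_val(...)` are macro shortcuts (see the header hunks below) that expand to `rewriter.create<...>` calls and assume `rewriter` and `loc` are in scope. A rough sketch of what the new one-liner expands to, assuming `insert_element` wraps `LLVM::InsertElementOp` the same way the other shortcuts wrap their ops:

    // Approximate expansion (sketch; insert_element's definition is assumed).
    Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
    llWord = rewriter.create<LLVM::InsertElementOp>(
        loc, wordTy, llWord, elem,
        LLVM::createConstantI32(loc, rewriter, elemIdx)); // i32_val(elemIdx)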
@@ -3570,43 +3566,27 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
  auto ATensorTy = A.getType().cast<RankedTensorType>();
  auto BTensorTy = B.getType().cast<RankedTensorType>();
  auto DTensorTy = D.getType().cast<RankedTensorType>();
-  SmallVector<int> AShape(ATensorTy.getShape().begin(),
-                          ATensorTy.getShape().end());
-  SmallVector<int> BShape(BTensorTy.getShape().begin(),
-                          BTensorTy.getShape().end());
+  auto AShape = ATensorTy.getShape();
+  auto BShape = BTensorTy.getShape();
  auto DShape = DTensorTy.getShape();
  auto wpt = mmaLayout.getWarpsPerCTA();

  bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
  bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
-  bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
-  bool isBVec4 = isBRow && BShape[isBRow] <= 16;
-  // TODO[Superjomn]: ld.v4 is not supported.
-  isAVec4 = true;
-  isBVec4 = true;
-
-  int packSize0 = (isARow || isAVec4) ? 1 : 2;
-  int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
-  SmallVector<int> fpw({2, 2, 1});
-  SmallVector<int> rep({2 * packSize0, 2 * packSize1, 1});
-  SmallVector<int> spw({fpw[0] * 4 * rep[0], fpw[1] * 4 * rep[1], 1});

-  Value loadedA = adaptor.a();
-  Value loadedB = adaptor.b();
-  Value loadedC = adaptor.c();
  DotOpMmaV1ConversionHelper helper(mmaLayout);

-  unsigned numM = rep[0] * DShape[0] / (spw[0] * wpt[0]);
-  unsigned numN = rep[1] * DShape[1] / (spw[1] * wpt[1]);
+  unsigned numM = helper.getNumM(AShape, isARow);
+  unsigned numN = helper.getNumN(BShape, isBRow);
  unsigned NK = AShape[1];

-  auto has = helper.extractLoadedOperand(loadedA, NK, rewriter);
-  auto hbs = helper.extractLoadedOperand(loadedB, NK, rewriter);
+  auto has = helper.extractLoadedOperand(adaptor.a(), NK, rewriter);
+  auto hbs = helper.extractLoadedOperand(adaptor.b(), NK, rewriter);

  // Initialize accumulators with external values, the acc holds the accumulator
  // value that is shared between the MMA instructions inside a DotOp, we can
  // call the order of the values the accumulator-internal order.
-  SmallVector<Value> acc = getElementsFromStruct(loc, loadedC, rewriter);
+  SmallVector<Value> acc = getElementsFromStruct(loc, adaptor.c(), rewriter);
  size_t resSize = acc.size();

  // The resVals holds the final result of the DotOp.
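
Note: the two helper calls reproduce the deleted inline arithmetic. With `ld.v4` unsupported both pack sizes were forced to 1, giving rep = {2, 2, 1} and spw = {16, 16, 1}, and AShape[0]/BShape[1] are the same M and N dimensions that DShape[0]/DShape[1] provided. A worked example with hypothetical sizes, assuming AParam encodes those same constants:

    numM = rep[0] * AShape[0] / (spw[0] * wpt[0])
         = 2 * 128 / (16 * 2)   // e.g. M = 128, wpt[0] = 2 (example values)
         = 8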
@@ -3719,38 +3699,19 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
  auto bShape = bTensorTy.getShape();
  auto cShape = cTensorTy.getShape();

-  ValueTable has, hbs;
-  int mShapePerCTA{-1}, nShapePerCTA{-1};
-  int mSizePerThread{-1}, nSizePerThread{-1};
-  ArrayRef<unsigned> aOrder, bOrder;
-  Value llA, llB;
  BlockedEncodingAttr dLayout =
      dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
  auto order = dLayout.getOrder();
  auto cc = getElementsFromStruct(loc, adaptor.c(), rewriter);

  DotOpFMAConversionHelper helper(dLayout);
-  if (auto aDotOpLayout =
-          aTensorTy.getEncoding()
-              .dyn_cast<DotOperandEncodingAttr>()) { // get input from
-                                                     // convert_layout
-    auto bDotOpLayout =
-        bTensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>();
-    auto aLayout = aDotOpLayout.getParent().cast<BlockedEncodingAttr>();
-    auto bLayout = bDotOpLayout.getParent().cast<BlockedEncodingAttr>();
+  auto aDotOpLayout = aTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
+  auto bDotOpLayout = bTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
+  auto aLayout = aDotOpLayout.getParent().cast<BlockedEncodingAttr>();
+  auto bLayout = bDotOpLayout.getParent().cast<BlockedEncodingAttr>();

-    assert(bLayout);
-    llA = adaptor.a();
-    llB = adaptor.b();
-  } else if (auto aLayout =
-                 aTensorTy.getEncoding()
-                     .dyn_cast<SharedEncodingAttr>()) { // load input from smem
-    auto bLayout = bTensorTy.getEncoding().dyn_cast<SharedEncodingAttr>();
-    assert(bLayout);
-    Value thread = getThreadId(rewriter, loc);
-    llA = helper.loadA(A, adaptor.a(), dLayout, thread, loc, rewriter);
-    llB = helper.loadB(B, adaptor.b(), dLayout, thread, loc, rewriter);
-  }
+  Value llA = adaptor.a();
+  Value llB = adaptor.b();

  auto sizePerThread = getSizePerThread(dLayout);
  auto shapePerCTA = getShapePerCTA(dLayout);
@@ -3759,17 +3720,19 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
  int M = aShape[0];
  int N = bShape[1];

-  mShapePerCTA = order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
-  mSizePerThread =
+  int mShapePerCTA =
+      order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
+  int mSizePerThread =
      order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]];
-  nShapePerCTA = order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
-  nSizePerThread =
+  int nShapePerCTA =
+      order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
+  int nSizePerThread =
      order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]];

-  has = helper.getValueTableFromStruct(llA, K, M, mShapePerCTA, mSizePerThread,
-                                       rewriter, loc);
-  hbs = helper.getValueTableFromStruct(llB, K, N, nShapePerCTA, nSizePerThread,
-                                       rewriter, loc);
+  auto has = helper.getValueTableFromStruct(llA, K, M, mShapePerCTA,
+                                            mSizePerThread, rewriter, loc);
+  auto hbs = helper.getValueTableFromStruct(llB, K, N, nShapePerCTA,
+                                            nSizePerThread, rewriter, loc);

  SmallVector<Value> ret = cc;
  for (unsigned k = 0; k < K; k++) {
@@ -3780,7 +3743,6 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
        for (unsigned nn = 0; nn < nSizePerThread; ++nn) {
          ret[z] = rewriter.create<LLVM::FMulAddOp>(loc, has[{m + mm, k}],
                                                    hbs[{n + nn, k}], ret[z]);
-
          ++z;
        }
      }
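
Note: `LLVM::FMulAddOp` maps to the `llvm.fmuladd` intrinsic, so each iteration fuses the multiply and the accumulate. In scalar terms the inner statement computes the following (illustrative only):

    // ret[z] = has[{m + mm, k}] * hbs[{n + nn, k}] + ret[z]
    float fmuladd(float a, float b, float c) { return a * b + c; }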
@@ -4310,9 +4272,10 @@ struct ExpOpConversionApprox
    // For FP64 input, call __nv_expf for higher-precision calculation
    if (elemTy.getIntOrFloatBitWidth() == 64)
      return {};
+
    const double log2e = 1.4426950408889634;
-    Value prod =
-        rewriter.create<LLVM::FMulOp>(loc, f32_ty, operands[0], f32_val(log2e));
+    Value prod = fmul(f32_ty, operands[0], f32_val(log2e));
+
    PTXBuilder ptxBuilder;
    auto &exp2 = ptxBuilder.create<PTXInstr>("ex2")->o("approx").o("f32");
    auto output = ptxBuilder.newOperand("=f");
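
Note: the approximation rests on the identity exp(x) = 2^(x * log2(e)); the multiply is emitted via the new `fmul` macro and the power of two via PTX `ex2.approx.f32`. A standalone numeric check of the identity (not from the patch):

    #include <cmath>
    #include <cstdio>

    int main() {
      const double log2e = 1.4426950408889634; // log2(e), as in the patch
      double x = 1.5;
      // exp(x) and exp2(x * log2(e)) agree; ex2.approx computes exp2 on GPU.
      std::printf("%.6f %.6f\n", std::exp(x), std::exp2(x * log2e));
      return 0;
    }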
@@ -31,6 +31,7 @@
 #include <numeric>

 // Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
+// Operators
 #define inttoptr(...) rewriter.create<LLVM::IntToPtrOp>(loc, __VA_ARGS__)
 #define ptrtoint(...) rewriter.create<LLVM::PtrToIntOp>(loc, __VA_ARGS__)
 #define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
@@ -40,6 +41,7 @@
 #define sub(...) rewriter.create<LLVM::SubOp>(loc, __VA_ARGS__)
 #define fadd(...) rewriter.create<LLVM::FAddOp>(loc, __VA_ARGS__)
 #define mul(...) rewriter.create<LLVM::MulOp>(loc, __VA_ARGS__)
+#define fmul(...) rewriter.create<LLVM::FMulOp>(loc, __VA_ARGS__)
 #define smax(...) rewriter.create<LLVM::SMaxOp>(loc, __VA_ARGS__)
 #define umax(...) rewriter.create<LLVM::UMaxOp>(loc, __VA_ARGS__)
 #define fmax(...) rewriter.create<LLVM::MaxNumOp>(loc, __VA_ARGS__)
@@ -90,6 +92,8 @@
 #define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
 #define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
 #define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
+
+// Types
 #define i32_ty rewriter.getIntegerType(32)
 #define ui32_ty rewriter.getIntegerType(32, false)
 #define f16_ty rewriter.getF16Type()
@@ -102,8 +106,9 @@
 #define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__)
 #define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
 #define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__)
+#define array_ty(elemTy, count) LLVM::LLVMArrayType::get(elemTy, count)

-// Creator for constant
+// Constants
 #define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__)
 #define int_val(width, val) \
   LLVM::createLLVMIntegerConstant(rewriter, loc, width, val)
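
Note: all of these shortcuts assume a `rewriter`, a `loc`, and (for `struct_ty`) a `ctx` visible at the expansion site. A sketch of how the two additions are used, mirroring the ExpOpConversionApprox change above; the surrounding pattern code is assumed:

    // fmul(...) expands to rewriter.create<LLVM::FMulOp>(loc, ...)
    Value prod = fmul(f32_ty, operands[0], f32_val(log2e));
    // array_ty(elemTy, count) expands to LLVM::LLVMArrayType::get(elemTy, count)
    Type arr = array_ty(f16_ty, 8);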
@@ -36,7 +36,7 @@ namespace {
 class DecomposeDotOperand : public mlir::RewritePattern {

 public:
-  DecomposeDotOperand(mlir::MLIRContext *context)
+  explicit DecomposeDotOperand(mlir::MLIRContext *context)
       : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                              1, context) {}

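
Note: marking these single-argument constructors `explicit` (here and in the patterns below) stops an `MLIRContext *` from converting implicitly into a pattern object. A minimal illustration with hypothetical types:

    struct Pattern {
      explicit Pattern(int *ctx) {}
    };
    void take(Pattern p);
    void demo(int *ctx) {
      take(Pattern(ctx)); // fine: explicit construction
      // take(ctx);       // error with `explicit`; compiled silently without it
    }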
@@ -84,7 +84,7 @@ public:
 // IIUC they are therefore not handled by DRR right now
 class SimplifyConversion : public mlir::RewritePattern {
 public:
-  SimplifyConversion(mlir::MLIRContext *context)
+  explicit SimplifyConversion(mlir::MLIRContext *context)
       : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                              4, context) {}

@@ -219,8 +219,8 @@ public:
 //
 // -----------------------------------------------------------------------------

-static LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
-                                    Attribute &ret) {
+LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
+                             Attribute &ret) {
   ret = targetEncoding;
   if (auto expand_dims = dyn_cast<triton::ExpandDimsOp>(op)) {
     ret = triton::gpu::SliceEncodingAttr::get(
@@ -246,7 +246,7 @@ inline bool expensive_to_remat(Operation *op) {
   if (isa<scf::YieldOp, scf::ForOp>(op))
     return true;
   return false;
-};
+}

 Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
                               BlockAndValueMapping &mapping) {
@@ -276,7 +276,7 @@ Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
 // are reachable from it without passing through any memory operation.
 class RematerializeBackward : public mlir::RewritePattern {
 public:
-  RematerializeBackward(mlir::MLIRContext *context)
+  explicit RematerializeBackward(mlir::MLIRContext *context)
       : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                              2, context) {}

@@ -303,7 +303,7 @@ public:
     SetVector<Attribute> layout;
     llvm::MapVector<Value, Attribute> toConvert;
     std::vector<std::pair<Operation *, Attribute>> queue;
-    queue.push_back({cvt, targetType.getEncoding()});
+    queue.emplace_back(cvt, targetType.getEncoding());
     int numCvts = 1;
     while (!queue.empty()) {
       Operation *currOp;
@@ -341,7 +341,7 @@ public:
           continue;
         // we add one expensive conversion for the current operand
         numCvts += 1;
-        queue.push_back({opArgI, newEncoding});
+        queue.emplace_back(opArgI, newEncoding);
       }
     }
     // if rematerialization would add more conversions than it removes
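
Note: `emplace_back` forwards its arguments to the element's constructor, building the `std::pair` in place instead of constructing a temporary with `push_back({...})` and moving it. A small equivalent:

    #include <utility>
    #include <vector>

    int main() {
      std::vector<std::pair<int, const char *>> queue;
      queue.push_back({1, "a"});  // temporary pair, then moved into the vector
      queue.emplace_back(2, "b"); // pair constructed directly in place
      return 0;
    }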
@@ -351,8 +351,8 @@ public:

     SmallVector<Value, 4> sortedValues;
     SetVector<Operation *> tmp;
-    for (auto it = toConvert.begin(); it != toConvert.end(); ++it) {
-      Value v = it->first;
+    for (auto &item : toConvert) {
+      Value v = item.first;
       if (v.getDefiningOp())
         tmp.insert(v.getDefiningOp());
       else
@@ -393,7 +393,7 @@ public:

 class MoveConvertOutOfLoop : public mlir::RewritePattern {
 public:
-  MoveConvertOutOfLoop(mlir::MLIRContext *context)
+  explicit MoveConvertOutOfLoop(mlir::MLIRContext *context)
       : mlir::RewritePattern(scf::ForOp::getOperationName(), 1, context) {}

   SmallVector<Value, 4>
@@ -406,7 +406,7 @@ public:
       newInitArgs[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
           newInitArgs[i].getLoc(), newType, newInitArgs[i]);
     // Clone for loop
-    scf::ForOp newForOp = rewriter.create<scf::ForOp>(
+    auto newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
         forOp.getStep(), newInitArgs);
     newForOp->moveBefore(forOp);
@@ -455,7 +455,7 @@ public:
                   mlir::PatternRewriter &rewriter) const override {
     auto forOp = cast<scf::ForOp>(op);
     auto iterArgs = forOp.getRegionIterArgs();
-    for (auto iterArg : llvm::enumerate(iterArgs)) {
+    for (const auto &iterArg : llvm::enumerate(iterArgs)) {
       // if (iterArg.index() != 1)
       //   continue;
       // skip non-tensor types
@@ -517,7 +517,7 @@ public:

 class RematerializeForward : public mlir::RewritePattern {
 public:
-  RematerializeForward(mlir::MLIRContext *context)
+  explicit RematerializeForward(mlir::MLIRContext *context)
       : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                              2, context) {}

@@ -584,7 +584,7 @@ public:
 //
 // -----------------------------------------------------------------------------
 namespace {
-static int computeCapabilityToMMAVersion(int computeCapability) {
+int computeCapabilityToMMAVersion(int computeCapability) {
   if (computeCapability < 80) {
     return 1;
   } else if (computeCapability < 90) {
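
Note: dropping `static` here is safe because the function sits in an anonymous namespace, which already gives it internal linkage; the keyword was redundant. For example:

    namespace {
    int f() { return 1; }        // internal linkage by virtue of the namespace
    static int g() { return 2; } // `static` adds nothing here
    } // namespace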
@@ -595,9 +595,7 @@ static int computeCapabilityToMMAVersion(int computeCapability) {
   }
 }

-static SmallVector<int64_t, 2>
-mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
-                         int numWarps) {
+SmallVector<int64_t, 2> mmaVersionToShapePerWarp(int version) {
   if (version == 1)
     return {16, 16};
   else if (version == 2)
@@ -608,12 +606,11 @@ mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
   }
 }

-SmallVector<unsigned, 2> warpsPerTileV1(triton::DotOp dotOp,
-                                        const ArrayRef<int64_t> shape,
+SmallVector<unsigned, 2> warpsPerTileV1(const ArrayRef<int64_t> shape,
                                         int numWarps) {
   SmallVector<unsigned, 2> ret = {1, 1};
   SmallVector<int64_t, 2> shapePerWarp =
-      mmaVersionToShapePerWarp(1, shape, numWarps);
+      mmaVersionToShapePerWarp(1 /*version*/);
   bool changed = false;
   do {
     changed = false;
@@ -669,7 +666,7 @@ SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,

 class OptimizeBlockedToShared : public mlir::RewritePattern {
 public:
-  OptimizeBlockedToShared(mlir::MLIRContext *context)
+  explicit OptimizeBlockedToShared(mlir::MLIRContext *context)
       : RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
                        context) {}

@@ -717,7 +714,7 @@ public:

 class OptimizeConvertToDotOperand : public mlir::RewritePattern {
 public:
-  OptimizeConvertToDotOperand(mlir::MLIRContext *context)
+  explicit OptimizeConvertToDotOperand(mlir::MLIRContext *context)
       : RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
                        context) {}

@@ -729,11 +726,12 @@ public:
     auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
     // order
     ArrayRef<unsigned> order;
-    if(auto srcBlockedLayout =
+    if (auto srcBlockedLayout =
            srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
       order = srcBlockedLayout.getOrder();
-    else if(auto srcSharedLayout =
-           srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>())
+    else if (auto srcSharedLayout =
+                 srcType.getEncoding()
+                     .dyn_cast<triton::gpu::SharedEncodingAttr>())
       order = srcSharedLayout.getOrder();
     else
       return failure();
@@ -742,20 +740,18 @@ public:
         dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
     if (!dstDotOperandLayout)
       return failure();
-    unsigned opIdx = dstDotOperandLayout.getOpIdx();
-    if(!dstDotOperandLayout.getIsMMAv1Row())
+    if (!dstDotOperandLayout.getIsMMAv1Row())
       return failure();
-    bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
-    if((order[0] == 1 && isMMAv1Row) ||
-       (order[0] == 0 && !isMMAv1Row))
+    bool isMMAv1Row =
+        dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
+    if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row))
       return failure();
     auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
     auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
-        op->getContext(), dstDotOperandLayout.getOpIdx(), dstDotOperandLayout.getParent(),
-        newIsRow);
+        op->getContext(), dstDotOperandLayout.getOpIdx(),
+        dstDotOperandLayout.getParent(), newIsRow);
     auto newDstType = RankedTensorType::get(
-        dstType.getShape(),
-        dstType.getElementType(), newDstEncoding);
+        dstType.getShape(), dstType.getElementType(), newDstEncoding);
     auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
         op->getLoc(), newDstType, cvt.getOperand());
     rewriter.replaceOp(op, newCvt.getResult());
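
Note: since `order[0]` is 0 or 1 for a 2-D tensor, the guard above says the destination's row flag already matches the source order, in which case there is nothing to flip. Equivalently (sketch):

    bool rowFlagMatchesOrder = (order[0] == 1) == isMMAv1Row;
    if (rowFlagMatchesOrder)
      return failure(); // flipping isMMAv1Row would not remove a conversion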
@@ -763,7 +759,6 @@ public:
   }
 };

-
 class BlockedToMMA : public mlir::RewritePattern {
   int computeCapability;

@@ -777,7 +772,7 @@ public:
                                    int version, int numWarps) {
     switch (version) {
     case 1:
-      return warpsPerTileV1(dotOp, shape, numWarps);
+      return warpsPerTileV1(shape, numWarps);
     case 2:
       return warpsPerTileV2(dotOp, shape, numWarps);
     default:
@@ -821,27 +816,31 @@ public:
     Value b = dotOp.b();
     auto oldAType = a.getType().cast<RankedTensorType>();
     auto oldBType = b.getType().cast<RankedTensorType>();
-    auto oldAOrder = oldAType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
-        .getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
-    auto oldBOrder = oldBType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
-        .getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
+    auto oldAOrder = oldAType.getEncoding()
+                         .cast<triton::gpu::DotOperandEncodingAttr>()
+                         .getParent()
+                         .cast<triton::gpu::BlockedEncodingAttr>()
+                         .getOrder();
+    auto oldBOrder = oldBType.getEncoding()
+                         .cast<triton::gpu::DotOperandEncodingAttr>()
+                         .getParent()
+                         .cast<triton::gpu::BlockedEncodingAttr>()
+                         .getOrder();
     Attribute isMMAv1RowA;
     Attribute isMMAv1RowB;
-    if(version == 1){
+    if (version == 1) {
       isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
       isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
     }

     auto newAType = RankedTensorType::get(
         oldAType.getShape(), oldAType.getElementType(),
-        triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
-                                                 newRetType.getEncoding(),
-                                                 isMMAv1RowA));
+        triton::gpu::DotOperandEncodingAttr::get(
+            oldAType.getContext(), 0, newRetType.getEncoding(), isMMAv1RowA));
     auto newBType = RankedTensorType::get(
         oldBType.getShape(), oldBType.getElementType(),
-        triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
-                                                 newRetType.getEncoding(),
-                                                 isMMAv1RowB));
+        triton::gpu::DotOperandEncodingAttr::get(
+            oldBType.getContext(), 1, newRetType.getEncoding(), isMMAv1RowB));

     a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
     b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
@@ -857,9 +856,8 @@ public:
 class FixupLoop : public mlir::RewritePattern {

 public:
-  FixupLoop(mlir::MLIRContext *context)
-      : mlir::RewritePattern(scf::ForOp::getOperationName(), 2,
-                             context) {}
+  explicit FixupLoop(mlir::MLIRContext *context)
+      : mlir::RewritePattern(scf::ForOp::getOperationName(), 2, context) {}

   mlir::LogicalResult
   matchAndRewrite(mlir::Operation *op,
@@ -869,17 +867,17 @@ public:
     // Rewrite init argument
     SmallVector<Value, 4> newInitArgs = forOp.getInitArgs();
     bool shouldRematerialize = false;
-    for(size_t i = 0; i < newInitArgs.size(); i++){
+    for (size_t i = 0; i < newInitArgs.size(); i++) {
       auto initArg = newInitArgs[i];
       auto regionArg = forOp.getRegionIterArgs()[i];
-      if(newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType()){
+      if (newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType()) {
         shouldRematerialize = true;
         break;
       }
     }
-    if(!shouldRematerialize)
+    if (!shouldRematerialize)
       return failure();

     scf::ForOp newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
         forOp.getStep(), newInitArgs);
@@ -894,8 +892,6 @@ public:
     }
     rewriter.replaceOp(forOp, newForOp.getResults());
     return success();
-
-
   }
 };
