[Triton-MLIR][BACKEND] some code clean on the backend (#978)
This commit is contained in:
@@ -105,11 +105,9 @@ struct DotOpMmaV1ConversionHelper {
|
||||
}
|
||||
|
||||
// Get the number of fp16x2 elements for $a.
|
||||
// \param shapeTransed: the shape or reordered shape if transpose needed.
|
||||
// \param shapeTransed: A's shape or reordered shape if transpose needed.
|
||||
// \param orderTransed: the order or reordered order if transpose needed.
|
||||
unsigned getNumM(ArrayRef<int64_t> shapeTransed,
|
||||
ArrayRef<unsigned> orderTransed) const {
|
||||
bool isARow = orderTransed[0] != 0;
|
||||
unsigned getNumM(ArrayRef<int64_t> shapeTransed, bool isARow) const {
|
||||
AParam param(isARow);
|
||||
|
||||
unsigned numM = param.rep[0] * shapeTransed[0] / (param.spw[0] * wpt[0]);
|
||||
@@ -117,11 +115,9 @@ struct DotOpMmaV1ConversionHelper {
|
||||
}
|
||||
|
||||
// Get the number of fp16x2 elements for $b.
|
||||
// \param shapeTransed: the shape or reordered shape if transpose needed.
|
||||
// \param shapeTransed: B's shape or reordered shape if transpose needed.
|
||||
// \param orderTransed: the order or reordered order if transpose needed.
|
||||
unsigned getNumN(ArrayRef<int64_t> shapeTransed,
|
||||
ArrayRef<unsigned> orderTransed) const {
|
||||
bool isBRow = orderTransed[0] != 0;
|
||||
unsigned getNumN(ArrayRef<int64_t> shapeTransed, bool isBRow) const {
|
||||
BParam param(isBRow);
|
||||
|
||||
unsigned numN = param.rep[1] * shapeTransed[1] / (param.spw[1] * wpt[1]);
|
||||
@@ -130,7 +126,7 @@ struct DotOpMmaV1ConversionHelper {
|
||||
|
||||
int numElemsPerThreadA(ArrayRef<int64_t> shapeTransed,
|
||||
ArrayRef<unsigned> orderTransed) const {
|
||||
int numM = getNumM(shapeTransed, orderTransed);
|
||||
int numM = getNumM(shapeTransed, orderTransed[0] == 1);
|
||||
int NK = shapeTransed[1];
|
||||
|
||||
// NOTE: We couldn't get the vec from the shared layout.
|
||||
@@ -143,7 +139,7 @@ struct DotOpMmaV1ConversionHelper {
|
||||
|
||||
int numElemsPerThreadB(ArrayRef<int64_t> shapeTransed,
|
||||
ArrayRef<unsigned> orderTransed) const {
|
||||
unsigned numN = getNumN(shapeTransed, orderTransed);
|
||||
unsigned numN = getNumN(shapeTransed, orderTransed[0] == 1);
|
||||
int NK = shapeTransed[0];
|
||||
// NOTE: We couldn't get the vec from the shared layout.
|
||||
// int vecB = sharedLayout.getVec();
|
||||
@@ -1451,7 +1447,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
||||
}
|
||||
};
|
||||
|
||||
unsigned numM = getNumM(shape, order);
|
||||
unsigned numM = getNumM(shape, order[0] == 1);
|
||||
for (unsigned k = 0; k < NK; k += 4)
|
||||
for (unsigned m = 0; m < numM / 2; ++m)
|
||||
loadA(m, k);
|
||||
@@ -1563,7 +1559,7 @@ Value DotOpMmaV1ConversionHelper::loadB(
|
||||
}
|
||||
};
|
||||
|
||||
unsigned numN = getNumN(shape, order);
|
||||
unsigned numN = getNumN(shape, order[0] == 1);
|
||||
for (unsigned k = 0; k < NK; k += 4)
|
||||
for (unsigned n = 0; n < numN / 2; ++n) {
|
||||
if (!hbs.count({n, k}))
|
||||
|
@@ -739,7 +739,6 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
|
||||
auto tensorTy = resType.cast<RankedTensorType>();
|
||||
if (tensorTy.getEncoding().isa<BlockedEncodingAttr>() ||
|
||||
tensorTy.getEncoding().isa<SliceEncodingAttr>()) {
|
||||
auto tensorTy = resType.cast<RankedTensorType>();
|
||||
auto srcType = typeConverter->convertType(elemType);
|
||||
auto llSrc = bitcast(constVal, srcType);
|
||||
size_t elemsPerThread = getElemsPerThread(tensorTy);
|
||||
@@ -981,7 +980,7 @@ struct LoadOpConversion
|
||||
size_t size = width / valueElemNbits;
|
||||
|
||||
auto vecTy = LLVM::getFixedVectorType(valueElemTy, size);
|
||||
Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
|
||||
Value v = undef(vecTy);
|
||||
for (size_t s = 0; s < size; ++s) {
|
||||
Value falseVal = otherElems[vecStart + ii * size + s];
|
||||
Value sVal = createIndexAttrConstant(
|
||||
@@ -1118,7 +1117,7 @@ struct StoreOpConversion
|
||||
SmallVector<std::pair<Value, std::string>> asmArgs;
|
||||
for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
|
||||
// llWord is a width-len composition
|
||||
Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
|
||||
Value llWord = undef(wordTy);
|
||||
// Insert each value element into the composition
|
||||
for (size_t elemIdx = 0; elemIdx < wordNElems; ++elemIdx) {
|
||||
const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx;
|
||||
@@ -1129,10 +1128,7 @@ struct StoreOpConversion
|
||||
elem = bitcast(elem, valueElemTy);
|
||||
|
||||
Type u32Ty = typeConverter->convertType(type::u32Ty(ctx));
|
||||
llWord =
|
||||
insert_element(wordTy, llWord, elem,
|
||||
rewriter.create<LLVM::ConstantOp>(
|
||||
loc, u32Ty, IntegerAttr::get(u32Ty, elemIdx)));
|
||||
llWord = insert_element(wordTy, llWord, elem, i32_val(elemIdx));
|
||||
}
|
||||
llWord = bitcast(llWord, valArgTy);
|
||||
std::string constraint =
|
||||
@@ -3570,43 +3566,27 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
|
||||
auto ATensorTy = A.getType().cast<RankedTensorType>();
|
||||
auto BTensorTy = B.getType().cast<RankedTensorType>();
|
||||
auto DTensorTy = D.getType().cast<RankedTensorType>();
|
||||
SmallVector<int> AShape(ATensorTy.getShape().begin(),
|
||||
ATensorTy.getShape().end());
|
||||
SmallVector<int> BShape(BTensorTy.getShape().begin(),
|
||||
BTensorTy.getShape().end());
|
||||
auto AShape = ATensorTy.getShape();
|
||||
auto BShape = BTensorTy.getShape();
|
||||
auto DShape = DTensorTy.getShape();
|
||||
auto wpt = mmaLayout.getWarpsPerCTA();
|
||||
|
||||
bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
|
||||
bool isBVec4 = isBRow && BShape[isBRow] <= 16;
|
||||
// TODO[Superjomn]: ld.v4 is not supported.
|
||||
isAVec4 = true;
|
||||
isBVec4 = true;
|
||||
|
||||
int packSize0 = (isARow || isAVec4) ? 1 : 2;
|
||||
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
|
||||
SmallVector<int> fpw({2, 2, 1});
|
||||
SmallVector<int> rep({2 * packSize0, 2 * packSize1, 1});
|
||||
SmallVector<int> spw({fpw[0] * 4 * rep[0], fpw[1] * 4 * rep[1], 1});
|
||||
|
||||
Value loadedA = adaptor.a();
|
||||
Value loadedB = adaptor.b();
|
||||
Value loadedC = adaptor.c();
|
||||
DotOpMmaV1ConversionHelper helper(mmaLayout);
|
||||
|
||||
unsigned numM = rep[0] * DShape[0] / (spw[0] * wpt[0]);
|
||||
unsigned numN = rep[1] * DShape[1] / (spw[1] * wpt[1]);
|
||||
unsigned numM = helper.getNumM(AShape, isARow);
|
||||
unsigned numN = helper.getNumN(BShape, isBRow);
|
||||
unsigned NK = AShape[1];
|
||||
|
||||
auto has = helper.extractLoadedOperand(loadedA, NK, rewriter);
|
||||
auto hbs = helper.extractLoadedOperand(loadedB, NK, rewriter);
|
||||
auto has = helper.extractLoadedOperand(adaptor.a(), NK, rewriter);
|
||||
auto hbs = helper.extractLoadedOperand(adaptor.b(), NK, rewriter);
|
||||
|
||||
// Initialize accumulators with external values, the acc holds the accumulator
|
||||
// value that is shared between the MMA instructions inside a DotOp, we can
|
||||
// call the order of the values the accumulator-internal order.
|
||||
SmallVector<Value> acc = getElementsFromStruct(loc, loadedC, rewriter);
|
||||
SmallVector<Value> acc = getElementsFromStruct(loc, adaptor.c(), rewriter);
|
||||
size_t resSize = acc.size();
|
||||
|
||||
// The resVals holds the final result of the DotOp.
|
||||
@@ -3719,38 +3699,19 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
|
||||
auto bShape = bTensorTy.getShape();
|
||||
auto cShape = cTensorTy.getShape();
|
||||
|
||||
ValueTable has, hbs;
|
||||
int mShapePerCTA{-1}, nShapePerCTA{-1};
|
||||
int mSizePerThread{-1}, nSizePerThread{-1};
|
||||
ArrayRef<unsigned> aOrder, bOrder;
|
||||
Value llA, llB;
|
||||
BlockedEncodingAttr dLayout =
|
||||
dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
|
||||
auto order = dLayout.getOrder();
|
||||
auto cc = getElementsFromStruct(loc, adaptor.c(), rewriter);
|
||||
|
||||
DotOpFMAConversionHelper helper(dLayout);
|
||||
if (auto aDotOpLayout =
|
||||
aTensorTy.getEncoding()
|
||||
.dyn_cast<DotOperandEncodingAttr>()) { // get input from
|
||||
// convert_layout
|
||||
auto bDotOpLayout =
|
||||
bTensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>();
|
||||
auto aLayout = aDotOpLayout.getParent().cast<BlockedEncodingAttr>();
|
||||
auto bLayout = bDotOpLayout.getParent().cast<BlockedEncodingAttr>();
|
||||
auto aDotOpLayout = aTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
|
||||
auto bDotOpLayout = bTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
|
||||
auto aLayout = aDotOpLayout.getParent().cast<BlockedEncodingAttr>();
|
||||
auto bLayout = bDotOpLayout.getParent().cast<BlockedEncodingAttr>();
|
||||
|
||||
assert(bLayout);
|
||||
llA = adaptor.a();
|
||||
llB = adaptor.b();
|
||||
} else if (auto aLayout =
|
||||
aTensorTy.getEncoding()
|
||||
.dyn_cast<SharedEncodingAttr>()) { // load input from smem
|
||||
auto bLayout = bTensorTy.getEncoding().dyn_cast<SharedEncodingAttr>();
|
||||
assert(bLayout);
|
||||
Value thread = getThreadId(rewriter, loc);
|
||||
llA = helper.loadA(A, adaptor.a(), dLayout, thread, loc, rewriter);
|
||||
llB = helper.loadB(B, adaptor.b(), dLayout, thread, loc, rewriter);
|
||||
}
|
||||
Value llA = adaptor.a();
|
||||
Value llB = adaptor.b();
|
||||
|
||||
auto sizePerThread = getSizePerThread(dLayout);
|
||||
auto shapePerCTA = getShapePerCTA(dLayout);
|
||||
@@ -3759,17 +3720,19 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
|
||||
int M = aShape[0];
|
||||
int N = bShape[1];
|
||||
|
||||
mShapePerCTA = order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
|
||||
mSizePerThread =
|
||||
int mShapePerCTA =
|
||||
order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
|
||||
int mSizePerThread =
|
||||
order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]];
|
||||
nShapePerCTA = order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
|
||||
nSizePerThread =
|
||||
int nShapePerCTA =
|
||||
order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
|
||||
int nSizePerThread =
|
||||
order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]];
|
||||
|
||||
has = helper.getValueTableFromStruct(llA, K, M, mShapePerCTA, mSizePerThread,
|
||||
rewriter, loc);
|
||||
hbs = helper.getValueTableFromStruct(llB, K, N, nShapePerCTA, nSizePerThread,
|
||||
rewriter, loc);
|
||||
auto has = helper.getValueTableFromStruct(llA, K, M, mShapePerCTA,
|
||||
mSizePerThread, rewriter, loc);
|
||||
auto hbs = helper.getValueTableFromStruct(llB, K, N, nShapePerCTA,
|
||||
nSizePerThread, rewriter, loc);
|
||||
|
||||
SmallVector<Value> ret = cc;
|
||||
for (unsigned k = 0; k < K; k++) {
|
||||
@@ -3780,7 +3743,6 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
|
||||
for (unsigned nn = 0; nn < nSizePerThread; ++nn) {
|
||||
ret[z] = rewriter.create<LLVM::FMulAddOp>(loc, has[{m + mm, k}],
|
||||
hbs[{n + nn, k}], ret[z]);
|
||||
|
||||
++z;
|
||||
}
|
||||
}
|
||||
@@ -4310,9 +4272,10 @@ struct ExpOpConversionApprox
|
||||
// For FP64 input, call __nv_expf for higher-precision calculation
|
||||
if (elemTy.getIntOrFloatBitWidth() == 64)
|
||||
return {};
|
||||
|
||||
const double log2e = 1.4426950408889634;
|
||||
Value prod =
|
||||
rewriter.create<LLVM::FMulOp>(loc, f32_ty, operands[0], f32_val(log2e));
|
||||
Value prod = fmul(f32_ty, operands[0], f32_val(log2e));
|
||||
|
||||
PTXBuilder ptxBuilder;
|
||||
auto &exp2 = ptxBuilder.create<PTXInstr>("ex2")->o("approx").o("f32");
|
||||
auto output = ptxBuilder.newOperand("=f");
|
||||
|
@@ -31,6 +31,7 @@
|
||||
#include <numeric>
|
||||
|
||||
// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
|
||||
// Operators
|
||||
#define inttoptr(...) rewriter.create<LLVM::IntToPtrOp>(loc, __VA_ARGS__)
|
||||
#define ptrtoint(...) rewriter.create<LLVM::PtrToIntOp>(loc, __VA_ARGS__)
|
||||
#define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
|
||||
@@ -40,6 +41,7 @@
|
||||
#define sub(...) rewriter.create<LLVM::SubOp>(loc, __VA_ARGS__)
|
||||
#define fadd(...) rewriter.create<LLVM::FAddOp>(loc, __VA_ARGS__)
|
||||
#define mul(...) rewriter.create<LLVM::MulOp>(loc, __VA_ARGS__)
|
||||
#define fmul(...) rewriter.create<LLVM::FMulOp>(loc, __VA_ARGS__)
|
||||
#define smax(...) rewriter.create<LLVM::SMaxOp>(loc, __VA_ARGS__)
|
||||
#define umax(...) rewriter.create<LLVM::UMaxOp>(loc, __VA_ARGS__)
|
||||
#define fmax(...) rewriter.create<LLVM::MaxNumOp>(loc, __VA_ARGS__)
|
||||
@@ -90,6 +92,8 @@
|
||||
#define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
|
||||
#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
|
||||
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
|
||||
|
||||
// Types
|
||||
#define i32_ty rewriter.getIntegerType(32)
|
||||
#define ui32_ty rewriter.getIntegerType(32, false)
|
||||
#define f16_ty rewriter.getF16Type()
|
||||
@@ -102,8 +106,9 @@
|
||||
#define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__)
|
||||
#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
|
||||
#define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__)
|
||||
#define array_ty(elemTy, count) LLVM::LLVMArrayType::get(elemTy, count)
|
||||
|
||||
// Creator for constant
|
||||
// Constants
|
||||
#define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__)
|
||||
#define int_val(width, val) \
|
||||
LLVM::createLLVMIntegerConstant(rewriter, loc, width, val)
|
||||
|
Reference in New Issue
Block a user