|
|
|
@@ -39,15 +39,6 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr;
|
|
|
|
|
using ::mlir::triton::gpu::MmaEncodingAttr;
|
|
|
|
|
using ::mlir::triton::gpu::SharedEncodingAttr;
|
|
|
|
|
|
|
|
|
|
// Forward declaration necessary functions locates in TritonGPUToLLVM.cpp .
|
|
|
|
|
llvm::SmallVector<mlir::Value>
|
|
|
|
|
getElementsFromStruct(mlir::Location loc, mlir::Value llvmStruct,
|
|
|
|
|
mlir::ConversionPatternRewriter &rewriter);
|
|
|
|
|
|
|
|
|
|
mlir::LLVM::SharedMemoryObject
|
|
|
|
|
getSharedMemoryObjectFromStruct(mlir::Location loc, mlir::Value llvmStruct,
|
|
|
|
|
mlir::ConversionPatternRewriter &rewriter);
|
|
|
|
|
|
|
|
|
|
// Helper for conversion of DotOp with mma<version=1>, that is sm<80
|
|
|
|
|
struct DotOpMmaV1ConversionHelper {
|
|
|
|
|
MmaEncodingAttr mmaLayout;
|
|
|
|
@@ -710,17 +701,13 @@ public:
|
|
|
|
|
if (kOrder == 1) {
|
|
|
|
|
elems[0] = load(gep(elemPtrTy, ptr, sOffsetElemVal));
|
|
|
|
|
elems[1] = load(gep(elemPtrTy, ptr2, sOffsetElemVal));
|
|
|
|
|
elems[2] =
|
|
|
|
|
load(gep(elemPtrTy, ptr, sOffsetArrElemVal));
|
|
|
|
|
elems[3] =
|
|
|
|
|
load(gep(elemPtrTy, ptr2, sOffsetArrElemVal));
|
|
|
|
|
elems[2] = load(gep(elemPtrTy, ptr, sOffsetArrElemVal));
|
|
|
|
|
elems[3] = load(gep(elemPtrTy, ptr2, sOffsetArrElemVal));
|
|
|
|
|
} else {
|
|
|
|
|
elems[0] = load(gep(elemPtrTy, ptr, sOffsetElemVal));
|
|
|
|
|
elems[2] = load(gep(elemPtrTy, ptr2, sOffsetElemVal));
|
|
|
|
|
elems[1] =
|
|
|
|
|
load(gep(elemPtrTy, ptr, sOffsetArrElemVal));
|
|
|
|
|
elems[3] =
|
|
|
|
|
load(gep(elemPtrTy, ptr2, sOffsetArrElemVal));
|
|
|
|
|
elems[1] = load(gep(elemPtrTy, ptr, sOffsetArrElemVal));
|
|
|
|
|
elems[3] = load(gep(elemPtrTy, ptr2, sOffsetArrElemVal));
|
|
|
|
|
}
|
|
|
|
|
return {elems[0], elems[1], elems[2], elems[3]};
|
|
|
|
|
|
|
|
|
@@ -952,7 +939,6 @@ struct MMA16816ConversionHelper {
|
|
|
|
|
// Loading $a from smem to registers, returns a LLVM::Struct.
|
|
|
|
|
Value loadA(Value tensor, const SharedMemoryObject &smemObj) const {
|
|
|
|
|
auto aTensorTy = tensor.getType().cast<RankedTensorType>();
|
|
|
|
|
auto layout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
|
|
|
|
|
|
|
|
|
|
SmallVector<int64_t> shape(aTensorTy.getShape().begin(),
|
|
|
|
|
aTensorTy.getShape().end());
|
|
|
|
@@ -973,12 +959,13 @@ struct MMA16816ConversionHelper {
|
|
|
|
|
if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
|
|
|
|
|
Value warpM = getWarpM(shape[0]);
|
|
|
|
|
// load from smem
|
|
|
|
|
int wpt = std::min<int>(mmaLayout.getWarpsPerCTA()[0], shape[0] / matShapeM);
|
|
|
|
|
loadFn = getLoadMatrixFn(
|
|
|
|
|
tensor, smemObj, mmaLayout, wpt /*wpt*/,
|
|
|
|
|
1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShape*/,
|
|
|
|
|
{matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/,
|
|
|
|
|
true /*isA*/);
|
|
|
|
|
int wpt =
|
|
|
|
|
std::min<int>(mmaLayout.getWarpsPerCTA()[0], shape[0] / matShapeM);
|
|
|
|
|
loadFn =
|
|
|
|
|
getLoadMatrixFn(tensor, smemObj, mmaLayout, wpt /*wpt*/, 1 /*kOrder*/,
|
|
|
|
|
{mmaInstrM, mmaInstrK} /*instrShape*/,
|
|
|
|
|
{matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/,
|
|
|
|
|
ha /*vals*/, true /*isA*/);
|
|
|
|
|
} else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
|
|
|
|
|
// load from registers, used in gemm fuse
|
|
|
|
|
// TODO(Superjomn) Port the logic.
|
|
|
|
@@ -1000,7 +987,6 @@ struct MMA16816ConversionHelper {
|
|
|
|
|
Value loadB(Value tensor, const SharedMemoryObject &smemObj) {
|
|
|
|
|
ValueTable hb;
|
|
|
|
|
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
|
|
|
|
auto layout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
|
|
|
|
|
|
|
|
|
|
SmallVector<int64_t> shape(tensorTy.getShape().begin(),
|
|
|
|
|
tensorTy.getShape().end());
|
|
|
|
@@ -1017,12 +1003,13 @@ struct MMA16816ConversionHelper {
|
|
|
|
|
int numRepN = getNumRepN(tensorTy, shape[1]);
|
|
|
|
|
|
|
|
|
|
Value warpN = getWarpN(shape[1]);
|
|
|
|
|
int wpt = std::min<int>(mmaLayout.getWarpsPerCTA()[1], shape[1] / matShapeN);
|
|
|
|
|
auto loadFn = getLoadMatrixFn(
|
|
|
|
|
tensor, smemObj, mmaLayout, wpt /*wpt*/,
|
|
|
|
|
0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/,
|
|
|
|
|
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/,
|
|
|
|
|
false /*isA*/);
|
|
|
|
|
int wpt =
|
|
|
|
|
std::min<int>(mmaLayout.getWarpsPerCTA()[1], shape[1] / matShapeN);
|
|
|
|
|
auto loadFn =
|
|
|
|
|
getLoadMatrixFn(tensor, smemObj, mmaLayout, wpt /*wpt*/, 0 /*kOrder*/,
|
|
|
|
|
{mmaInstrK, mmaInstrN} /*instrShape*/,
|
|
|
|
|
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/,
|
|
|
|
|
hb /*vals*/, false /*isA*/);
|
|
|
|
|
|
|
|
|
|
for (int n = 0; n < std::max(numRepN / 2, 1); ++n) {
|
|
|
|
|
for (int k = 0; k < numRepK; ++k)
|
|
|
|
@@ -1167,6 +1154,7 @@ private:
|
|
|
|
|
SmallVector<Value> ptrs(numPtrs);
|
|
|
|
|
|
|
|
|
|
Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
|
|
|
|
|
|
|
|
|
|
Type smemPtrTy = helper.getShemPtrTy();
|
|
|
|
|
for (int i = 0; i < numPtrs; ++i) {
|
|
|
|
|
ptrs[i] =
|
|
|
|
@@ -1292,7 +1280,6 @@ struct DotOpFMAConversionHelper {
|
|
|
|
|
auto blockedLayout = dotOpLayout.getParent().cast<BlockedEncodingAttr>();
|
|
|
|
|
auto shapePerCTA = getShapePerCTA(blockedLayout);
|
|
|
|
|
auto sizePerThread = getSizePerThread(blockedLayout);
|
|
|
|
|
auto order = blockedLayout.getOrder();
|
|
|
|
|
|
|
|
|
|
// TODO[Superjomn]: we assume the k aixs is fixed for $a and $b here, fix it
|
|
|
|
|
// if not.
|
|
|
|
@@ -1342,17 +1329,15 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
|
|
|
|
SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
|
|
|
|
|
sharedLayout.getOrder().end());
|
|
|
|
|
|
|
|
|
|
// TODO [Superjomn]: transA cannot be accessed in ConvertLayoutOp.
|
|
|
|
|
bool transA = false;
|
|
|
|
|
if (transA) {
|
|
|
|
|
std::swap(shape[0], shape[1]);
|
|
|
|
|
std::swap(order[0], order[1]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
|
|
|
|
|
Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
|
|
|
|
|
|
|
|
|
|
bool isARow = order[0] != 0;
|
|
|
|
|
bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes
|
|
|
|
|
// TODO[Superjomn]: Support the case when isAVec4=false later
|
|
|
|
|
// Currently, we only support ld.v2, for the mma layout varies with different
|
|
|
|
|
// ld vector width.
|
|
|
|
|
isAVec4 = true;
|
|
|
|
|
int packSize0 = (isARow || isAVec4) ? 1 : 2;
|
|
|
|
|
|
|
|
|
|
SmallVector<int> fpw({2, 2, 1});
|
|
|
|
@@ -1362,6 +1347,16 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
|
|
|
|
SmallVector<int> rep({repM, 0, repK}); // pad N with 0
|
|
|
|
|
SmallVector<int> spw({spwM, 0, 1}); // pad N with 0
|
|
|
|
|
|
|
|
|
|
auto [offsetAM, offsetAK, _0, _1] =
|
|
|
|
|
computeOffsets(thread, isARow, false, fpw, spw, rep, rewriter, loc);
|
|
|
|
|
// TODO [Superjomn]: transA cannot be accessed in ConvertLayoutOp.
|
|
|
|
|
bool transA = false;
|
|
|
|
|
if (transA) {
|
|
|
|
|
std::swap(shape[0], shape[1]);
|
|
|
|
|
std::swap(offsetAM, offsetAK);
|
|
|
|
|
std::swap(order[0], order[1]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int vecA = sharedLayout.getVec();
|
|
|
|
|
|
|
|
|
|
auto strides = smemObj.strides;
|
|
|
|
@@ -1373,9 +1368,6 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
|
|
|
|
int strideRepM = wpt[0] * fpw[0] * 8;
|
|
|
|
|
int strideRepK = 1;
|
|
|
|
|
|
|
|
|
|
auto [offsetAM, offsetAK, _0, _1] =
|
|
|
|
|
computeOffsets(thread, isARow, false, fpw, spw, rep, rewriter, loc);
|
|
|
|
|
|
|
|
|
|
// swizzling
|
|
|
|
|
int perPhaseA = sharedLayout.getPerPhase();
|
|
|
|
|
int maxPhaseA = sharedLayout.getMaxPhase();
|
|
|
|
@@ -1398,19 +1390,14 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Type f16x2Ty = vec_ty(f16_ty, 2);
|
|
|
|
|
// One thread get 8 elements as result
|
|
|
|
|
Type retTy =
|
|
|
|
|
LLVM::LLVMStructType::getLiteral(ctx, SmallVector(8, type::f32Ty(ctx)));
|
|
|
|
|
|
|
|
|
|
// prepare arguments
|
|
|
|
|
SmallVector<Value> ptrA(numPtrA);
|
|
|
|
|
|
|
|
|
|
std::map<std::pair<int, int>, std::pair<Value, Value>> has;
|
|
|
|
|
auto smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
|
|
|
|
|
for (int i = 0; i < numPtrA; i++)
|
|
|
|
|
ptrA[i] = gep(ptr_ty(f16_ty), smem, offA[i]);
|
|
|
|
|
ptrA[i] = gep(ptr_ty(f16_ty), smemBase, offA[i]);
|
|
|
|
|
|
|
|
|
|
auto instrShape = getMmaInstrShape();
|
|
|
|
|
unsigned numM = std::max<int>(rep[0] * shape[0] / (spw[0] * wpt[0]), 1);
|
|
|
|
|
|
|
|
|
|
Type f16PtrTy = ptr_ty(f16_ty);
|
|
|
|
@@ -1420,7 +1407,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
|
|
|
|
};
|
|
|
|
|
auto loadA = [&](int m, int k) {
|
|
|
|
|
int offidx = (isARow ? k / 4 : m) % numPtrA;
|
|
|
|
|
Value thePtrA = gep(f16PtrTy, smem, offA[offidx]);
|
|
|
|
|
Value thePtrA = gep(f16PtrTy, smemBase, offA[offidx]);
|
|
|
|
|
|
|
|
|
|
int stepAM = isARow ? m : m / numPtrA * numPtrA;
|
|
|
|
|
int stepAK = isARow ? k / (numPtrA * vecA) * (numPtrA * vecA) : k;
|
|
|
|
@@ -1446,12 +1433,10 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
|
|
|
|
|
|
|
|
|
for (unsigned k = 0; k < NK; k += 4)
|
|
|
|
|
for (unsigned m = 0; m < numM / 2; ++m)
|
|
|
|
|
if (!has.count({m, k}))
|
|
|
|
|
loadA(m, k);
|
|
|
|
|
loadA(m, k);
|
|
|
|
|
|
|
|
|
|
SmallVector<Value> elems;
|
|
|
|
|
elems.reserve(has.size() * 2);
|
|
|
|
|
auto vecTy = vec_ty(f16_ty, 2);
|
|
|
|
|
for (auto item : has) { // has is a map, the key should be ordered.
|
|
|
|
|
elems.push_back(item.second.first);
|
|
|
|
|
elems.push_back(item.second.second);
|
|
|
|
@@ -1466,7 +1451,6 @@ Value DotOpMmaV1ConversionHelper::loadB(
|
|
|
|
|
Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
|
// smem
|
|
|
|
|
Value smem = smemObj.base;
|
|
|
|
|
auto strides = smemObj.strides;
|
|
|
|
|
|
|
|
|
|
auto *ctx = rewriter.getContext();
|
|
|
|
@@ -1478,21 +1462,20 @@ Value DotOpMmaV1ConversionHelper::loadB(
|
|
|
|
|
SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
|
|
|
|
|
sharedLayout.getOrder().end());
|
|
|
|
|
|
|
|
|
|
// TODO [Superjomn]: transB cannot be accessed in ConvertLayoutOp.
|
|
|
|
|
bool transB = false;
|
|
|
|
|
|
|
|
|
|
if (transB) {
|
|
|
|
|
std::swap(order[0], order[1]);
|
|
|
|
|
std::swap(shape[0], shape[1]);
|
|
|
|
|
}
|
|
|
|
|
Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
|
|
|
|
|
|
|
|
|
|
bool isBRow = order[0] != 0;
|
|
|
|
|
bool isBVec4 = isBRow && shape[order[0]] <= 16;
|
|
|
|
|
// TODO[Superjomn]: Support the case when isBVec4=false later
|
|
|
|
|
// Currently, we only support ld.v2, for the mma layout varies with different
|
|
|
|
|
// ld vector width.
|
|
|
|
|
isBVec4 = true;
|
|
|
|
|
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
|
|
|
|
|
SmallVector<int> fpw({2, 2, 1});
|
|
|
|
|
SmallVector<int> rep({0, 2 * packSize1, 1}); // pad M with 0
|
|
|
|
|
SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
|
|
|
|
|
int vecB = sharedLayout.getVec();
|
|
|
|
|
|
|
|
|
|
Value strideBN = isBRow ? i32_val(1) : strides[1];
|
|
|
|
|
Value strideBK = isBRow ? strides[0] : i32_val(1);
|
|
|
|
|
Value strideB0 = isBRow ? strideBN : strideBK;
|
|
|
|
@@ -1500,24 +1483,29 @@ Value DotOpMmaV1ConversionHelper::loadB(
|
|
|
|
|
int strideRepN = wpt[1] * fpw[1] * 8;
|
|
|
|
|
int strideRepK = 1;
|
|
|
|
|
|
|
|
|
|
// TODO [Superjomn]: transB cannot be accessed in ConvertLayoutOp.
|
|
|
|
|
bool transB = false;
|
|
|
|
|
|
|
|
|
|
auto [_0, _1, offsetBN, offsetBK] =
|
|
|
|
|
computeOffsets(thread, false, isBRow, fpw, spw, rep, rewriter, loc);
|
|
|
|
|
if (transB) {
|
|
|
|
|
std::swap(order[0], order[1]);
|
|
|
|
|
std::swap(shape[0], shape[1]);
|
|
|
|
|
std::swap(offsetBK, offsetBN);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// swizzling
|
|
|
|
|
int perPhaseA = sharedLayout.getPerPhase();
|
|
|
|
|
int maxPhaseA = sharedLayout.getMaxPhase();
|
|
|
|
|
int perPhaseB = sharedLayout.getPerPhase();
|
|
|
|
|
int maxPhaseB = sharedLayout.getMaxPhase();
|
|
|
|
|
int stepB0 = isBRow ? strideRepN : strideRepK;
|
|
|
|
|
int numPtrB = std::max(2 * perPhaseB * maxPhaseB / stepB0, 1);
|
|
|
|
|
int NK = shape[0];
|
|
|
|
|
|
|
|
|
|
auto [_0, _1, offsetBN, offsetBK] =
|
|
|
|
|
computeOffsets(thread, false, isBRow, fpw, spw, rep, rewriter, loc);
|
|
|
|
|
if (transB)
|
|
|
|
|
std::swap(offsetBK, offsetBN);
|
|
|
|
|
|
|
|
|
|
Value offB0 = isBRow ? offsetBN : offsetBK;
|
|
|
|
|
Value offB1 = isBRow ? offsetBK : offsetBN;
|
|
|
|
|
Value phaseB = urem(udiv(offB1, i32_val(perPhaseB)), i32_val(maxPhaseB));
|
|
|
|
|
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
|
|
|
|
|
|
|
|
|
|
offB0 = add(offB0, cSwizzleOffset);
|
|
|
|
|
SmallVector<Value> offB(numPtrB);
|
|
|
|
|
for (int i = 0; i < numPtrB; ++i) {
|
|
|
|
@@ -1549,6 +1537,7 @@ Value DotOpMmaV1ConversionHelper::loadB(
|
|
|
|
|
Value offset = add(mul(i32_val(stepBN * strideRepN), strideBN),
|
|
|
|
|
mul(i32_val(stepBK), strideBK));
|
|
|
|
|
Value pb = gep(f16PtrTy, thePtrB, offset);
|
|
|
|
|
|
|
|
|
|
Value hb =
|
|
|
|
|
load(bitcast(pb, ptr_ty(vec_ty(i32_ty, std::max(vecB / 2, 1)), 3)));
|
|
|
|
|
// record lds that needs to be moved
|
|
|
|
@@ -1651,9 +1640,12 @@ DotOpMmaV1ConversionHelper::extractLoadedOperand(
|
|
|
|
|
SmallVector<Value> elems =
|
|
|
|
|
getElementsFromStruct(llStruct.getLoc(), llStruct, rewriter);
|
|
|
|
|
|
|
|
|
|
for (int k = 0, offset = 0, i = 0; k < NK && offset < elems.size();
|
|
|
|
|
k += 4, i++, offset += 2) {
|
|
|
|
|
rcds[{i, k}] = std::make_pair(elems[offset], elems[offset + 1]);
|
|
|
|
|
int offset = 0;
|
|
|
|
|
for (int i = 0; offset < elems.size(); ++i) {
|
|
|
|
|
for (int k = 0; k < NK; k += 4) {
|
|
|
|
|
rcds[{i, k}] = std::make_pair(elems[offset], elems[offset + 1]);
|
|
|
|
|
offset += 2;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return rcds;
|
|
|
|
@@ -1675,9 +1667,7 @@ Value DotOpFMAConversionHelper::loadA(
|
|
|
|
|
int strideAK = isARow ? 1 : aShape[0];
|
|
|
|
|
int strideA0 = isARow ? strideAK : strideAM;
|
|
|
|
|
int strideA1 = isARow ? strideAM : strideAK;
|
|
|
|
|
int lda = isARow ? strideAM : strideAK;
|
|
|
|
|
int aNumPtr = 8;
|
|
|
|
|
int bNumPtr = 8;
|
|
|
|
|
int NK = aShape[1];
|
|
|
|
|
|
|
|
|
|
auto shapePerCTA = getShapePerCTA(dLayout);
|
|
|
|
@@ -1686,13 +1676,11 @@ Value DotOpFMAConversionHelper::loadA(
|
|
|
|
|
Value _0 = i32_val(0);
|
|
|
|
|
|
|
|
|
|
Value mContig = i32_val(sizePerThread[order[1]]);
|
|
|
|
|
Value nContig = i32_val(sizePerThread[order[0]]);
|
|
|
|
|
|
|
|
|
|
// threadId in blocked layout
|
|
|
|
|
auto threadIds = getThreadIds(thread, shapePerCTA, order, rewriter, loc);
|
|
|
|
|
|
|
|
|
|
Value threadIdM = threadIds[0];
|
|
|
|
|
Value threadIdN = threadIds[1];
|
|
|
|
|
|
|
|
|
|
Value offA0 = isARow ? _0 : mul(threadIdM, mContig);
|
|
|
|
|
Value offA1 = isARow ? mul(threadIdM, mContig) : _0;
|
|
|
|
@@ -1745,7 +1733,6 @@ Value DotOpFMAConversionHelper::loadB(
|
|
|
|
|
int strideBK = isBRow ? bShape[1] : 1;
|
|
|
|
|
int strideB0 = isBRow ? strideBN : strideBK;
|
|
|
|
|
int strideB1 = isBRow ? strideBK : strideBN;
|
|
|
|
|
int ldb = isBRow ? strideBK : strideBN;
|
|
|
|
|
int bNumPtr = 8;
|
|
|
|
|
int NK = bShape[0];
|
|
|
|
|
|
|
|
|
@@ -1754,7 +1741,6 @@ Value DotOpFMAConversionHelper::loadB(
|
|
|
|
|
|
|
|
|
|
Value _0 = i32_val(0);
|
|
|
|
|
|
|
|
|
|
Value mContig = i32_val(sizePerThread[order[1]]);
|
|
|
|
|
Value nContig = i32_val(sizePerThread[order[0]]);
|
|
|
|
|
|
|
|
|
|
// threadId in blocked layout
|
|
|
|
|