[Triton-MLIR][BACKEND] Tiny patch for MMAv1 and code clean (#964)

This PR:

- Several fixes to the MMAv1 code
- Remove the env var `TRITON_STATIC_LOOP_UNROLLING` from the V100 CI, since the pipeline pass now works
- Some code cleanup (a standalone sketch of the consolidated parameter helpers follows below)
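
For reference, here is a minimal standalone sketch (an assumed simplification, not the actual Triton sources) of what the cleanup amounts to: the `rep`/`spw` vectors that were previously recomputed inline in `loadA`, `loadB`, `numElemsPerThreadA`, and `numElemsPerThreadB` are now built once in small `AParam`/`BParam` helpers and reused by `getNumM`/`getNumN`. Plain `std::vector`/`std::array` stand in for `llvm::SmallVector`/`ArrayRef` so the sketch compiles on its own; the constants and formulas mirror the diff below.

```cpp
// Standalone sketch of the consolidated MMAv1 parameter helpers (assumed
// simplification, not the actual Triton code). std::vector and std::array
// stand in for llvm::SmallVector and ArrayRef so it compiles on its own.
#include <array>
#include <cstdint>
#include <vector>

static constexpr std::array<int, 3> fpw{{2, 2, 1}};

// Shared per-operand parameters for $a; previously recomputed inline in
// loadA and numElemsPerThreadA.
struct AParam {
  std::vector<int> rep;
  std::vector<int> spw;
  // Only the ld.v2 path is supported for now, so isAVec4 stays true.
  const bool isAVec4 = true;

  explicit AParam(bool isARow) {
    int packSize0 = (isARow || isAVec4) ? 1 : 2;
    int repM = 2 * packSize0;
    int repK = 1;
    int spwM = fpw[0] * 4 * repM;
    rep = {repM, 0, repK}; // pad N with 0
    spw = {spwM, 0, 1};    // pad N with 0
  }
};

// Shared per-operand parameters for $b; previously recomputed inline in
// loadB and numElemsPerThreadB.
struct BParam {
  std::vector<int> rep;
  std::vector<int> spw;
  const bool isBVec4 = true;

  explicit BParam(bool isBRow) {
    int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
    rep = {0, 2 * packSize1, 1};       // pad M with 0
    spw = {0, fpw[1] * 4 * rep[1], 1}; // pad M with 0
  }
};

// getNumM/getNumN then collapse to a single formula over these parameters.
unsigned getNumM(const std::vector<int64_t> &shape, bool isARow,
                 const std::vector<unsigned> &wpt) {
  AParam param(isARow);
  return param.rep[0] * shape[0] / (param.spw[0] * wpt[0]);
}

unsigned getNumN(const std::vector<int64_t> &shape, bool isBRow,
                 const std::vector<unsigned> &wpt) {
  BParam param(isBRow);
  return param.rep[1] * shape[1] / (param.spw[1] * wpt[1]);
}
```

With this in place, `loadA`/`loadB` and the element-count helpers in the diff below just construct an `AParam`/`BParam` from `isARow`/`isBRow` and read `rep`/`spw` from it, instead of repeating the `fpw`/pack-size arithmetic at every call site.
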
Author: Yan Chunwei
Date: 2022-12-08 16:39:32 +08:00
Committed by: GitHub
Parent: 18e683d9bb
Commit: f0885e9caf
3 changed files with 116 additions and 95 deletions

@@ -43,12 +43,51 @@ using ::mlir::triton::gpu::SharedEncodingAttr;
 struct DotOpMmaV1ConversionHelper {
   MmaEncodingAttr mmaLayout;
   ArrayRef<unsigned> wpt;
+  static constexpr std::array<int, 3> fpw{{2, 2, 1}};
   using ValueTable = std::map<std::pair<int, int>, std::pair<Value, Value>>;
   explicit DotOpMmaV1ConversionHelper(MmaEncodingAttr mmaLayout)
       : mmaLayout(mmaLayout), wpt(mmaLayout.getWarpsPerCTA()) {}
+  // Help to share some variables across multiple functions for A.
+  struct AParam {
+    SmallVector<int> rep;
+    SmallVector<int> spw;
+    // TODO[Superjomn]: Support the case when isAVec4=false later
+    // Currently, we only support ld.v2, for the mma layout varies with
+    // different ld vector width.
+    // bool isAVec4 = !isARow && shapeTransed[orderTransed[0]] <= 16;
+    const bool isAVec4{true};
+    explicit AParam(bool isARow) {
+      int packSize0 = (isARow || isAVec4) ? 1 : 2;
+      int repM = 2 * packSize0;
+      int repK = 1;
+      int spwM = fpw[0] * 4 * repM;
+      rep.assign({repM, 0, repK});
+      spw.assign({spwM, 0, 1});
+    }
+  };
+  // Help to share some variables across multiple functions for B.
+  struct BParam {
+    SmallVector<int> rep;
+    SmallVector<int> spw;
+    // TODO[Superjomn]: Support the case when isBVec4=false later
+    // Currently, we only support ld.v2, for the mma layout varies with
+    // different ld vector width.
+    // bool isBVec4 = isBRow && shapeTransed[orderTransed[0]] <= 16;
+    const bool isBVec4{true};
+    explicit BParam(bool isBRow) {
+      int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
+      rep.assign({0, 2 * packSize1, 1});
+      spw.assign({0, fpw[1] * 4 * rep[1], 1});
+    }
+  };
   int getRepM(int M) const {
     return std::max<int>(M / (wpt[0] * instrShape[0]), 1);
   }
@@ -65,29 +104,34 @@ struct DotOpMmaV1ConversionHelper {
     return struct_ty(SmallVector<Type>{8, fp32Ty});
   }
-  // number of fp16x2 elements for $a.
-  int numElemsPerThreadA(RankedTensorType tensorTy) const {
-    auto shape = tensorTy.getShape();
-    auto order = getOrder();
+  // Get the number of fp16x2 elements for $a.
+  // \param shapeTransed: the shape or reordered shape if transpose needed.
+  // \param orderTransed: the order or reordered order if transpose needed.
+  unsigned getNumM(ArrayRef<int64_t> shapeTransed,
+                   ArrayRef<unsigned> orderTransed) const {
+    bool isARow = orderTransed[0] != 0;
+    AParam param(isARow);
-    bool isARow = order[0] != 0;
-    bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes
-    // TODO[Superjomn]: Support the case when isAVec4=false later
-    // Currently, we only support ld.v2, for the mma layout varies with
-    // different ld vector width.
-    isAVec4 = true;
+    unsigned numM = param.rep[0] * shapeTransed[0] / (param.spw[0] * wpt[0]);
+    return numM;
+  }
-    int packSize0 = (isARow || isAVec4) ? 1 : 2;
+  // Get the number of fp16x2 elements for $b.
+  // \param shapeTransed: the shape or reordered shape if transpose needed.
+  // \param orderTransed: the order or reordered order if transpose needed.
+  unsigned getNumN(ArrayRef<int64_t> shapeTransed,
+                   ArrayRef<unsigned> orderTransed) const {
+    bool isBRow = orderTransed[0] != 0;
+    BParam param(isBRow);
-    SmallVector<int> fpw({2, 2, 1});
-    int repM = 2 * packSize0;
-    int repK = 1;
-    int spwM = fpw[0] * 4 * repM;
-    SmallVector<int> rep({repM, 0, repK}); // pad N with 0
-    SmallVector<int> spw({spwM, 0, 1});    // pad N with 0
+    unsigned numN = param.rep[1] * shapeTransed[1] / (param.spw[1] * wpt[1]);
+    return numN;
+  }
-    int NK = shape[1];
-    unsigned numM = rep[0] * shape[0] / (spw[0] * wpt[0]);
+  int numElemsPerThreadA(ArrayRef<int64_t> shapeTransed,
+                         ArrayRef<unsigned> orderTransed) const {
+    int numM = getNumM(shapeTransed, orderTransed);
+    int NK = shapeTransed[1];
     // NOTE: We couldn't get the vec from the shared layout.
     // int vecA = sharedLayout.getVec();
@@ -97,39 +141,27 @@ struct DotOpMmaV1ConversionHelper {
     return (numM / 2) * (NK / 4) * elemsPerLd;
   }
   // number of fp16x2 elements for $b.
-  int numElemsPerThreadB(RankedTensorType tensorTy) const {
-    auto shape = tensorTy.getShape();
-    auto order = getOrder();
-    bool isBRow = order[0] != 0;
-    bool isBVec4 = isBRow && shape[order[0]] <= 16;
-    // TODO[Superjomn]: Support the case when isBVec4=false later
-    // Currently, we only support ld.v2, for the mma layout varies with
-    // different ld vector width.
-    isBVec4 = true;
-    int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
-    SmallVector<int> fpw({2, 2, 1});
-    SmallVector<int> rep({0, 2 * packSize1, 1});       // pad M with 0
-    SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
+  int numElemsPerThreadB(ArrayRef<int64_t> shapeTransed,
+                         ArrayRef<unsigned> orderTransed) const {
+    unsigned numN = getNumN(shapeTransed, orderTransed);
+    int NK = shapeTransed[0];
     // NOTE: We couldn't get the vec from the shared layout.
     // int vecB = sharedLayout.getVec();
     // TODO[Superjomn]: Consider the case when vecA > 4
     bool vecGt4 = false;
     int elemsPerLd = vecGt4 ? 4 : 2;
-    int NK = shape[0];
-    unsigned numN = rep[1] * shape[1] / (spw[1] * wpt[0]);
     return (numN / 2) * (NK / 4) * elemsPerLd;
   }
   // Loading $a from smem to registers, returns a LLVM::Struct.
-  Value loadA(Value A, const SharedMemoryObject &smemObj, Value thread,
-              Location loc, ConversionPatternRewriter &rewriter) const;
+  Value loadA(Value A, bool transA, const SharedMemoryObject &smemObj,
+              Value thread, Location loc,
+              ConversionPatternRewriter &rewriter) const;
   // Loading $b from smem to registers, returns a LLVM::Struct.
-  Value loadB(Value B, const SharedMemoryObject &smemObj, Value thread,
-              Location loc, ConversionPatternRewriter &rewriter) const;
+  Value loadB(Value B, bool transB, const SharedMemoryObject &smemObj,
+              Value thread, Location loc,
+              ConversionPatternRewriter &rewriter) const;
   static ArrayRef<unsigned> getOrder() { return mmaOrder; }
@@ -1321,8 +1353,8 @@ struct DotOpFMAConversionHelper {
 };
 Value DotOpMmaV1ConversionHelper::loadA(
-    Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
-    ConversionPatternRewriter &rewriter) const {
+    Value tensor, bool transA, const SharedMemoryObject &smemObj, Value thread,
+    Location loc, ConversionPatternRewriter &rewriter) const {
   auto *ctx = rewriter.getContext();
   auto tensorTy = tensor.getType().cast<RankedTensorType>();
@@ -1336,24 +1368,11 @@ Value DotOpMmaV1ConversionHelper::loadA(
   Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
   bool isARow = order[0] != 0;
-  bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes
-  // TODO[Superjomn]: Support the case when isAVec4=false later
-  // Currently, we only support ld.v2, for the mma layout varies with different
-  // ld vector width.
-  isAVec4 = true;
-  int packSize0 = (isARow || isAVec4) ? 1 : 2;
+  AParam param(isARow);
-  SmallVector<int> fpw({2, 2, 1});
-  int repM = 2 * packSize0;
-  int repK = 1;
-  int spwM = fpw[0] * 4 * repM;
-  SmallVector<int> rep({repM, 0, repK}); // pad N with 0
-  SmallVector<int> spw({spwM, 0, 1});    // pad N with 0
+  auto [offsetAM, offsetAK, _0, _1] = computeOffsets(
+      thread, isARow, false, fpw, param.spw, param.rep, rewriter, loc);
-  auto [offsetAM, offsetAK, _0, _1] =
-      computeOffsets(thread, isARow, false, fpw, spw, rep, rewriter, loc);
-  // TODO [Superjomn]: transA cannot be accessed in ConvertLayoutOp.
-  bool transA = false;
   if (transA) {
     std::swap(shape[0], shape[1]);
     std::swap(offsetAM, offsetAK);
@@ -1401,8 +1420,6 @@ Value DotOpMmaV1ConversionHelper::loadA(
   for (int i = 0; i < numPtrA; i++)
     ptrA[i] = gep(ptr_ty(f16_ty), smemBase, offA[i]);
-  unsigned numM = std::max<int>(rep[0] * shape[0] / (spw[0] * wpt[0]), 1);
   Type f16PtrTy = ptr_ty(f16_ty);
   auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) {
@@ -1434,6 +1451,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
     }
   };
+  unsigned numM = getNumM(shape, order);
   for (unsigned k = 0; k < NK; k += 4)
     for (unsigned m = 0; m < numM / 2; ++m)
       loadA(m, k);
@@ -1451,8 +1469,8 @@ Value DotOpMmaV1ConversionHelper::loadA(
 }
 Value DotOpMmaV1ConversionHelper::loadB(
-    Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
-    ConversionPatternRewriter &rewriter) const {
+    Value tensor, bool transB, const SharedMemoryObject &smemObj, Value thread,
+    Location loc, ConversionPatternRewriter &rewriter) const {
   // smem
   auto strides = smemObj.strides;
@@ -1467,17 +1485,9 @@ Value DotOpMmaV1ConversionHelper::loadB(
   Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
   bool isBRow = order[0] != 0;
-  bool isBVec4 = isBRow && shape[order[0]] <= 16;
-  // TODO[Superjomn]: Support the case when isBVec4=false later
-  // Currently, we only support ld.v2, for the mma layout varies with different
-  // ld vector width.
-  isBVec4 = true;
-  int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
-  SmallVector<int> fpw({2, 2, 1});
-  SmallVector<int> rep({0, 2 * packSize1, 1});       // pad M with 0
-  SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
-  int vecB = sharedLayout.getVec();
+  BParam param(isBRow);
+  int vecB = sharedLayout.getVec();
   Value strideBN = isBRow ? i32_val(1) : strides[1];
   Value strideBK = isBRow ? strides[0] : i32_val(1);
   Value strideB0 = isBRow ? strideBN : strideBK;
@@ -1485,11 +1495,8 @@ Value DotOpMmaV1ConversionHelper::loadB(
   int strideRepN = wpt[1] * fpw[1] * 8;
   int strideRepK = 1;
-  // TODO [Superjomn]: transB cannot be accessed in ConvertLayoutOp.
-  bool transB = false;
-  auto [_0, _1, offsetBN, offsetBK] =
-      computeOffsets(thread, false, isBRow, fpw, spw, rep, rewriter, loc);
+  auto [_0, _1, offsetBN, offsetBK] = computeOffsets(
+      thread, false, isBRow, fpw, param.spw, param.rep, rewriter, loc);
   if (transB) {
     std::swap(order[0], order[1]);
     std::swap(shape[0], shape[1]);
@@ -1556,7 +1563,7 @@ Value DotOpMmaV1ConversionHelper::loadB(
     }
   };
-  unsigned numN = rep[1] * shape[1] / (spw[1] * wpt[0]);
+  unsigned numN = getNumN(shape, order);
   for (unsigned k = 0; k < NK; k += 4)
     for (unsigned n = 0; n < numN / 2; ++n) {
       if (!hbs.count({n, k}))