[Triton-MLIR][BACKEND] some code clean on the backend (#978)

This commit is contained in:
Yan Chunwei
2022-12-12 17:46:16 +08:00
committed by GitHub
parent e5cfa0f633
commit 0cfe909df8
4 changed files with 97 additions and 137 deletions

View File

@@ -105,11 +105,9 @@ struct DotOpMmaV1ConversionHelper {
}
// Get the number of fp16x2 elements for $a.
// \param shapeTransed: the shape or reordered shape if transpose needed.
// \param shapeTransed: A's shape or reordered shape if transpose needed.
// \param orderTransed: the order or reordered order if transpose needed.
unsigned getNumM(ArrayRef<int64_t> shapeTransed,
ArrayRef<unsigned> orderTransed) const {
bool isARow = orderTransed[0] != 0;
unsigned getNumM(ArrayRef<int64_t> shapeTransed, bool isARow) const {
AParam param(isARow);
unsigned numM = param.rep[0] * shapeTransed[0] / (param.spw[0] * wpt[0]);
@@ -117,11 +115,9 @@ struct DotOpMmaV1ConversionHelper {
}
// Get the number of fp16x2 elements for $b.
// \param shapeTransed: the shape or reordered shape if transpose needed.
// \param shapeTransed: B' shape or reordered shape if transpose needed.
// \param orderTransed: the order or reordered order if transpose needed.
unsigned getNumN(ArrayRef<int64_t> shapeTransed,
ArrayRef<unsigned> orderTransed) const {
bool isBRow = orderTransed[0] != 0;
unsigned getNumN(ArrayRef<int64_t> shapeTransed, bool isBRow) const {
BParam param(isBRow);
unsigned numN = param.rep[1] * shapeTransed[1] / (param.spw[1] * wpt[1]);
@@ -130,7 +126,7 @@ struct DotOpMmaV1ConversionHelper {
int numElemsPerThreadA(ArrayRef<int64_t> shapeTransed,
ArrayRef<unsigned> orderTransed) const {
int numM = getNumM(shapeTransed, orderTransed);
int numM = getNumM(shapeTransed, orderTransed[0] == 1);
int NK = shapeTransed[1];
// NOTE: We couldn't get the vec from the shared layout.
@@ -143,7 +139,7 @@ struct DotOpMmaV1ConversionHelper {
int numElemsPerThreadB(ArrayRef<int64_t> shapeTransed,
ArrayRef<unsigned> orderTransed) const {
unsigned numN = getNumN(shapeTransed, orderTransed);
unsigned numN = getNumN(shapeTransed, orderTransed[0] == 1);
int NK = shapeTransed[0];
// NOTE: We couldn't get the vec from the shared layout.
// int vecB = sharedLayout.getVec();
@@ -1451,7 +1447,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
}
};
unsigned numM = getNumM(shape, order);
unsigned numM = getNumM(shape, order[0] == 1);
for (unsigned k = 0; k < NK; k += 4)
for (unsigned m = 0; m < numM / 2; ++m)
loadA(m, k);
@@ -1563,7 +1559,7 @@ Value DotOpMmaV1ConversionHelper::loadB(
}
};
unsigned numN = getNumN(shape, order);
unsigned numN = getNumN(shape, order[0] == 1);
for (unsigned k = 0; k < NK; k += 4)
for (unsigned n = 0; n < numN / 2; ++n) {
if (!hbs.count({n, k}))