[Triton-MLIR][BACKEND] some code clean on the backend (#978)
This commit is contained in:
@@ -105,11 +105,9 @@ struct DotOpMmaV1ConversionHelper {
|
||||
}
|
||||
|
||||
// Get the number of fp16x2 elements for $a.
|
||||
// \param shapeTransed: the shape or reordered shape if transpose needed.
|
||||
// \param shapeTransed: A's shape or reordered shape if transpose needed.
|
||||
// \param orderTransed: the order or reordered order if transpose needed.
|
||||
unsigned getNumM(ArrayRef<int64_t> shapeTransed,
|
||||
ArrayRef<unsigned> orderTransed) const {
|
||||
bool isARow = orderTransed[0] != 0;
|
||||
unsigned getNumM(ArrayRef<int64_t> shapeTransed, bool isARow) const {
|
||||
AParam param(isARow);
|
||||
|
||||
unsigned numM = param.rep[0] * shapeTransed[0] / (param.spw[0] * wpt[0]);
|
||||
@@ -117,11 +115,9 @@ struct DotOpMmaV1ConversionHelper {
|
||||
}
|
||||
|
||||
// Get the number of fp16x2 elements for $b.
|
||||
// \param shapeTransed: the shape or reordered shape if transpose needed.
|
||||
// \param shapeTransed: B's shape or reordered shape if transpose needed.
|
||||
// \param orderTransed: the order or reordered order if transpose needed.
|
||||
unsigned getNumN(ArrayRef<int64_t> shapeTransed,
|
||||
ArrayRef<unsigned> orderTransed) const {
|
||||
bool isBRow = orderTransed[0] != 0;
|
||||
unsigned getNumN(ArrayRef<int64_t> shapeTransed, bool isBRow) const {
|
||||
BParam param(isBRow);
|
||||
|
||||
unsigned numN = param.rep[1] * shapeTransed[1] / (param.spw[1] * wpt[1]);
|
||||
@@ -130,7 +126,7 @@ struct DotOpMmaV1ConversionHelper {
|
||||
|
||||
int numElemsPerThreadA(ArrayRef<int64_t> shapeTransed,
|
||||
ArrayRef<unsigned> orderTransed) const {
|
||||
int numM = getNumM(shapeTransed, orderTransed);
|
||||
int numM = getNumM(shapeTransed, orderTransed[0] == 1);
|
||||
int NK = shapeTransed[1];
|
||||
|
||||
// NOTE: We couldn't get the vec from the shared layout.
|
||||
@@ -143,7 +139,7 @@ struct DotOpMmaV1ConversionHelper {
|
||||
|
||||
int numElemsPerThreadB(ArrayRef<int64_t> shapeTransed,
|
||||
ArrayRef<unsigned> orderTransed) const {
|
||||
unsigned numN = getNumN(shapeTransed, orderTransed);
|
||||
unsigned numN = getNumN(shapeTransed, orderTransed[0] == 1);
|
||||
int NK = shapeTransed[0];
|
||||
// NOTE: We couldn't get the vec from the shared layout.
|
||||
// int vecB = sharedLayout.getVec();
|
||||
@@ -1451,7 +1447,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
||||
}
|
||||
};
|
||||
|
||||
unsigned numM = getNumM(shape, order);
|
||||
unsigned numM = getNumM(shape, order[0] == 1);
|
||||
for (unsigned k = 0; k < NK; k += 4)
|
||||
for (unsigned m = 0; m < numM / 2; ++m)
|
||||
loadA(m, k);
|
||||
@@ -1563,7 +1559,7 @@ Value DotOpMmaV1ConversionHelper::loadB(
|
||||
}
|
||||
};
|
||||
|
||||
unsigned numN = getNumN(shape, order);
|
||||
unsigned numN = getNumN(shape, order[0] == 1);
|
||||
for (unsigned k = 0; k < NK; k += 4)
|
||||
for (unsigned n = 0; n < numN / 2; ++n) {
|
||||
if (!hbs.count({n, k}))
|
||||
|
Reference in New Issue
Block a user