[FRONTEND][BACKEND] Clean-up transpositions (#953)

This commit is contained in:
Philippe Tillet
2022-12-06 09:32:13 -08:00
committed by GitHub
parent 16e973edf2
commit 532e10cf87
12 changed files with 31 additions and 53 deletions

View File

@@ -942,11 +942,6 @@ struct MMA16816ConversionHelper {
SmallVector<int64_t> shape(aTensorTy.getShape().begin(),
aTensorTy.getShape().end());
// TODO[Superjomn]: transA cannot be accessed in ConvertLayoutOp.
bool transA = false;
if (transA) {
std::swap(shape[0], shape[1]);
}
ValueTable ha;
std::function<void(int, int)> loadFn;
@@ -1052,8 +1047,6 @@ struct MMA16816ConversionHelper {
SmallVector<int64_t> aShape(aTensorTy.getShape().begin(),
aTensorTy.getShape().end());
if (op.transA())
std::swap(aShape[0], aShape[1]);
auto dShape = dTensorTy.getShape();
@@ -1462,8 +1455,8 @@ Value DotOpMmaV1ConversionHelper::loadB(
SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
sharedLayout.getOrder().end());
Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
bool isBRow = order[0] != 0;
bool isBVec4 = isBRow && shape[order[0]] <= 16;
// TODO[Superjomn]: Support the case when isBVec4=false later