[FRONTEND][BACKEND] Clean-up transpositions (#953)

This commit is contained in:
Philippe Tillet
2022-12-06 09:32:13 -08:00
committed by GitHub
parent 16e973edf2
commit 532e10cf87
12 changed files with 31 additions and 53 deletions

View File

@@ -942,11 +942,6 @@ struct MMA16816ConversionHelper {
SmallVector<int64_t> shape(aTensorTy.getShape().begin(),
aTensorTy.getShape().end());
// TODO[Superjomn]: transA cannot be accessed in ConvertLayoutOp.
bool transA = false;
if (transA) {
std::swap(shape[0], shape[1]);
}
ValueTable ha;
std::function<void(int, int)> loadFn;
@@ -1052,8 +1047,6 @@ struct MMA16816ConversionHelper {
SmallVector<int64_t> aShape(aTensorTy.getShape().begin(),
aTensorTy.getShape().end());
if (op.transA())
std::swap(aShape[0], aShape[1]);
auto dShape = dTensorTy.getShape();
@@ -1462,8 +1455,8 @@ Value DotOpMmaV1ConversionHelper::loadB(
SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
sharedLayout.getOrder().end());
Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
bool isBRow = order[0] != 0;
bool isBVec4 = isBRow && shape[order[0]] <= 16;
// TODO[Superjomn]: Support the case when isBVec4=false later

View File

@@ -3253,7 +3253,7 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
// Here we assume the DotOp's operands always comes from shared memory.
auto AShape = A.getType().cast<RankedTensorType>().getShape();
size_t reduceAxis = op.transA() ? 0 : 1;
size_t reduceAxis = 1;
unsigned K = AShape[reduceAxis];
bool isOuter = K == 1;
@@ -3492,22 +3492,10 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
auto DShape = DTensorTy.getShape();
auto wpt = mmaLayout.getWarpsPerCTA();
bool transA = op.transA();
bool transB = op.transB();
// TODO[Superjomn]: order cannot accessed in DotOp.
SmallVector<unsigned> AOrder({1, 0});
SmallVector<unsigned> BOrder({1, 0});
if (transA) {
std::swap(AShape[0], AShape[1]);
std::swap(AOrder[0], AOrder[1]);
}
if (transB) {
std::swap(BShape[0], BShape[1]);
std::swap(BOrder[0], BOrder[0]);
}
bool isARow = AOrder[0] != 0;
bool isBRow = BOrder[0] != 0;
bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes

View File

@@ -245,9 +245,8 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
bType.getElementType(), encoding);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
}
rewriter.replaceOpWithNewOp<triton::DotOp>(
op, retType, a, b, adaptor.c(), adaptor.allowTF32(), adaptor.transA(),
adaptor.transB());
rewriter.replaceOpWithNewOp<triton::DotOp>(op, retType, a, b, adaptor.c(),
adaptor.allowTF32());
return success();
}
};

View File

@@ -12,21 +12,21 @@ include "triton/Dialect/Triton/IR/TritonOps.td"
// AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
// AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
def CombineDotAddIPattern : Pat<
(Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB)),
(TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
(Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
(TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;
def CombineDotAddFPattern : Pat<
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB)),
(TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
(TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;
def CombineDotAddIRevPattern : Pat<
(Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB), $d),
(TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
(Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
(TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;
def CombineDotAddFRevPattern : Pat<
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB), $d),
(TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
(TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;

View File

@@ -781,8 +781,7 @@ public:
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
auto newDot = rewriter.create<triton::DotOp>(
dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.allowTF32(),
dotOp.transA(), dotOp.transB());
dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.allowTF32());
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
op, oldRetType, newDot.getResult());