Merge branch 'triton-mlir' into keren/insert-slice-other-nonzero

This commit is contained in:
Jokeren
2022-12-06 13:25:20 -08:00
14 changed files with 55 additions and 65 deletions

View File

@@ -3253,7 +3253,7 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
// Here we assume the DotOp's operands always comes from shared memory.
auto AShape = A.getType().cast<RankedTensorType>().getShape();
size_t reduceAxis = op.transA() ? 0 : 1;
size_t reduceAxis = 1;
unsigned K = AShape[reduceAxis];
bool isOuter = K == 1;
@@ -3492,22 +3492,10 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
auto DShape = DTensorTy.getShape();
auto wpt = mmaLayout.getWarpsPerCTA();
bool transA = op.transA();
bool transB = op.transB();
// TODO[Superjomn]: order cannot accessed in DotOp.
SmallVector<unsigned> AOrder({1, 0});
SmallVector<unsigned> BOrder({1, 0});
if (transA) {
std::swap(AShape[0], AShape[1]);
std::swap(AOrder[0], AOrder[1]);
}
if (transB) {
std::swap(BShape[0], BShape[1]);
std::swap(BOrder[0], BOrder[0]);
}
bool isARow = AOrder[0] != 0;
bool isBRow = BOrder[0] != 0;
bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes