Merge branch 'triton-mlir' into keren/insert-slice-other-nonzero

This commit is contained in:
Jokeren
2022-12-06 13:25:20 -08:00
14 changed files with 55 additions and 65 deletions

View File

@@ -942,11 +942,6 @@ struct MMA16816ConversionHelper {
SmallVector<int64_t> shape(aTensorTy.getShape().begin(),
aTensorTy.getShape().end());
// TODO[Superjomn]: transA cannot be accessed in ConvertLayoutOp.
bool transA = false;
if (transA) {
std::swap(shape[0], shape[1]);
}
ValueTable ha;
std::function<void(int, int)> loadFn;
@@ -1052,8 +1047,6 @@ struct MMA16816ConversionHelper {
SmallVector<int64_t> aShape(aTensorTy.getShape().begin(),
aTensorTy.getShape().end());
if (op.transA())
std::swap(aShape[0], aShape[1]);
auto dShape = dTensorTy.getShape();
@@ -1462,8 +1455,8 @@ Value DotOpMmaV1ConversionHelper::loadB(
SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
sharedLayout.getOrder().end());
Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
bool isBRow = order[0] != 0;
bool isBVec4 = isBRow && shape[order[0]] <= 16;
// TODO[Superjomn]: Support the case when isBVec4=false later

View File

@@ -3253,7 +3253,7 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
// Here we assume the DotOp's operands always comes from shared memory.
auto AShape = A.getType().cast<RankedTensorType>().getShape();
size_t reduceAxis = op.transA() ? 0 : 1;
size_t reduceAxis = 1;
unsigned K = AShape[reduceAxis];
bool isOuter = K == 1;
@@ -3492,22 +3492,10 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
auto DShape = DTensorTy.getShape();
auto wpt = mmaLayout.getWarpsPerCTA();
bool transA = op.transA();
bool transB = op.transB();
// TODO[Superjomn]: order cannot accessed in DotOp.
SmallVector<unsigned> AOrder({1, 0});
SmallVector<unsigned> BOrder({1, 0});
if (transA) {
std::swap(AShape[0], AShape[1]);
std::swap(AOrder[0], AOrder[1]);
}
if (transB) {
std::swap(BShape[0], BShape[1]);
std::swap(BOrder[0], BOrder[0]);
}
bool isARow = AOrder[0] != 0;
bool isBRow = BOrder[0] != 0;
bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes

View File

@@ -245,9 +245,8 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
bType.getElementType(), encoding);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
}
rewriter.replaceOpWithNewOp<triton::DotOp>(
op, retType, a, b, adaptor.c(), adaptor.allowTF32(), adaptor.transA(),
adaptor.transB());
rewriter.replaceOpWithNewOp<triton::DotOp>(op, retType, a, b, adaptor.c(),
adaptor.allowTF32());
return success();
}
};