Merge branch 'triton-mlir' into keren/insert-slice-other-nonzero
This commit is contained in:
@@ -942,11 +942,6 @@ struct MMA16816ConversionHelper {
|
||||
|
||||
SmallVector<int64_t> shape(aTensorTy.getShape().begin(),
|
||||
aTensorTy.getShape().end());
|
||||
// TODO[Superjomn]: transA cannot be accessed in ConvertLayoutOp.
|
||||
bool transA = false;
|
||||
if (transA) {
|
||||
std::swap(shape[0], shape[1]);
|
||||
}
|
||||
|
||||
ValueTable ha;
|
||||
std::function<void(int, int)> loadFn;
|
||||
@@ -1052,8 +1047,6 @@ struct MMA16816ConversionHelper {
|
||||
|
||||
SmallVector<int64_t> aShape(aTensorTy.getShape().begin(),
|
||||
aTensorTy.getShape().end());
|
||||
if (op.transA())
|
||||
std::swap(aShape[0], aShape[1]);
|
||||
|
||||
auto dShape = dTensorTy.getShape();
|
||||
|
||||
@@ -1462,8 +1455,8 @@ Value DotOpMmaV1ConversionHelper::loadB(
|
||||
SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
|
||||
sharedLayout.getOrder().end());
|
||||
|
||||
Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
|
||||
|
||||
Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
|
||||
bool isBRow = order[0] != 0;
|
||||
bool isBVec4 = isBRow && shape[order[0]] <= 16;
|
||||
// TODO[Superjomn]: Support the case when isBVec4=false later
|
||||
|
@@ -3253,7 +3253,7 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
|
||||
|
||||
// Here we assume the DotOp's operands always comes from shared memory.
|
||||
auto AShape = A.getType().cast<RankedTensorType>().getShape();
|
||||
size_t reduceAxis = op.transA() ? 0 : 1;
|
||||
size_t reduceAxis = 1;
|
||||
unsigned K = AShape[reduceAxis];
|
||||
bool isOuter = K == 1;
|
||||
|
||||
@@ -3492,22 +3492,10 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
|
||||
auto DShape = DTensorTy.getShape();
|
||||
auto wpt = mmaLayout.getWarpsPerCTA();
|
||||
|
||||
bool transA = op.transA();
|
||||
bool transB = op.transB();
|
||||
|
||||
// TODO[Superjomn]: order cannot accessed in DotOp.
|
||||
SmallVector<unsigned> AOrder({1, 0});
|
||||
SmallVector<unsigned> BOrder({1, 0});
|
||||
|
||||
if (transA) {
|
||||
std::swap(AShape[0], AShape[1]);
|
||||
std::swap(AOrder[0], AOrder[1]);
|
||||
}
|
||||
if (transB) {
|
||||
std::swap(BShape[0], BShape[1]);
|
||||
std::swap(BOrder[0], BOrder[0]);
|
||||
}
|
||||
|
||||
bool isARow = AOrder[0] != 0;
|
||||
bool isBRow = BOrder[0] != 0;
|
||||
bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
|
||||
|
@@ -245,9 +245,8 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
bType.getElementType(), encoding);
|
||||
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<triton::DotOp>(
|
||||
op, retType, a, b, adaptor.c(), adaptor.allowTF32(), adaptor.transA(),
|
||||
adaptor.transB());
|
||||
rewriter.replaceOpWithNewOp<triton::DotOp>(op, retType, a, b, adaptor.c(),
|
||||
adaptor.allowTF32());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
Reference in New Issue
Block a user