4 Commits

Author SHA1 Message Date
Phil Tillet
58d2867fe6 testing things... 2022-12-09 19:31:34 -08:00
Phil Tillet
fa6dbbff60 . 2022-12-08 21:36:47 -08:00
Yan Chunwei
13644e7ac4 adapt isMMAv1Row in backend (#969) 2022-12-08 20:52:43 -08:00
Phil Tillet
0d27912554 [OPTIMIZER] Added isRow attribute for dotOperandEncoding of MMAv1 layout 2022-12-08 18:31:16 -08:00
8 changed files with 213 additions and 27 deletions

View File

@@ -416,15 +416,35 @@ In TritonGPU dialect, considering `d = tt.dot a, b, c`
tt.dot's operands a and b must be of DotOperandEncodingAttr layout.
a's opIdx is 0, b's opIdx is 1.
The parent field in DotOperandEncodingAttr is the layout of d.
For MMA v1, an additional attribute `isMMAv1Row` determines whether e.g. the a operand is used
in the context of an mma.884.row.col or an mma.884.col.col operation. See the PTX ISA documentation
section 9.7.13.4.1 for more details.
}];
let parameters = (
ins
"unsigned":$opIdx,
"Attribute":$parent
"Attribute":$parent,
"Attribute":$isMMAv1Row
);
let builders = [
AttrBuilder<(ins "unsigned":$opIdx,
"Attribute":$parent), [{
Attribute isMMAv1Row;
if(parent.isa<MmaEncodingAttr>() &&
parent.cast<MmaEncodingAttr>().getVersion() == 1){
isMMAv1Row = BoolAttr::get(context, true);
}
return $_get(context, opIdx, parent, isMMAv1Row);
}]>
];
let extraClassDeclaration = extraBaseClassDeclaration;
}
#endif

View File

@@ -1356,6 +1356,20 @@ Value DotOpMmaV1ConversionHelper::loadA(
Value tensor, bool transA, const SharedMemoryObject &smemObj, Value thread,
Location loc, ConversionPatternRewriter &rewriter) const {
// [1, 0] (isRow = True)
// x x x x || x x x x
// x x x x || x x x x
// stride = [8, 1]
// strideA0 = strideAk = 1
// strideA1 = strideAm = 8
// [0, 1] (isRow = False)
// x x x x || x x x x
// x x x x || x x x x
// stride = [1, 2]
// strideA0 = strideAm = 1
// strideA1 = strideAk = 2
auto *ctx = rewriter.getContext();
auto tensorTy = tensor.getType().cast<RankedTensorType>();
auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
@@ -1364,8 +1378,8 @@ Value DotOpMmaV1ConversionHelper::loadA(
SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
sharedLayout.getOrder().end());
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
// Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
Value smemBase = smemObj.base;
bool isARow = order[0] != 0;
AParam param(isARow);
@@ -1387,6 +1401,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
Value strideA0 = isARow ? strideAK : strideAM;
Value strideA1 = isARow ? strideAM : strideAK;
smemBase = gep(ptr_ty(f16_ty), smemBase, Value(smemObj.offsets[1]));
int strideRepM = wpt[0] * fpw[0] * 8;
int strideRepK = 1;
@@ -1401,7 +1416,9 @@ Value DotOpMmaV1ConversionHelper::loadA(
Value offA0 = isARow ? offsetAK : offsetAM;
Value offA1 = isARow ? offsetAM : offsetAK;
Value phaseA = urem(udiv(offA1, i32_val(perPhaseA)), i32_val(maxPhaseA));
offA0 = add(offA0, cSwizzleOffset);
// offA0 = add(offA0, smemObj.offsets[order[0]]);
// offA1 = add(offA1, smemObj.offsets[order[1]]);
SmallVector<Value> offA(numPtrA);
for (int i = 0; i < numPtrA; i++) {
Value offA0I = add(offA0, i32_val(i * (isARow ? 4 : strideRepM)));
@@ -1422,6 +1439,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
Type f16PtrTy = ptr_ty(f16_ty);
auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) {
vals[{m, k}] = {val0, val1};
};
@@ -1451,7 +1469,10 @@ Value DotOpMmaV1ConversionHelper::loadA(
}
};
unsigned numM = getNumM(shape, order);
llvm::outs() << "LOAD A " << numM << " " << NK << "\n";
for (unsigned k = 0; k < NK; k += 4)
for (unsigned m = 0; m < numM / 2; ++m)
loadA(m, k);

View File

@@ -3427,8 +3427,24 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
}
} else if (!isOuter && mmaLayout.getVersion() == 1 &&
isHMMA) { // tensor core v1
// vprintf("offset 0", smemObj.offsets[0]}, rewriter);
DotOpMmaV1ConversionHelper helper(mmaLayout);
bool isMMAv1Row =
dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
auto srcSharedLayout = src.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<SharedEncodingAttr>();
// Can only convert [1, 0] to row or [0, 1] to col for now
if ((srcSharedLayout.getOrder()[0] == 1 && !isMMAv1Row) ||
(srcSharedLayout.getOrder()[0] == 0 && isMMAv1Row)) {
llvm::errs() << "Unsupported Shared -> DotOperand[MMAv1] conversion\n";
return Value();
}
if (dotOperandLayout.getOpIdx() == 0) { // operand $a
// LLVM::vprintf_array(i32_val(0), smemObj.offsets, "offsets ", "%d", rewriter);
// TODO[Superjomn]: transA is not available here.
bool transA = false;
res = helper.loadA(src, transA, smemObj, getThreadId(rewriter, loc), loc,
@@ -3540,6 +3556,14 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
.cast<RankedTensorType>()
.getEncoding()
.cast<MmaEncodingAttr>();
auto ALayout = A.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<DotOperandEncodingAttr>();
auto BLayout = B.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<DotOperandEncodingAttr>();
auto ATensorTy = A.getType().cast<RankedTensorType>();
auto BTensorTy = B.getType().cast<RankedTensorType>();
@@ -3551,12 +3575,8 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
auto DShape = DTensorTy.getShape();
auto wpt = mmaLayout.getWarpsPerCTA();
// TODO[Superjomn]: order cannot accessed in DotOp.
SmallVector<unsigned> AOrder({1, 0});
SmallVector<unsigned> BOrder({1, 0});
bool isARow = AOrder[0] != 0;
bool isBRow = BOrder[0] != 0;
bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
bool isBVec4 = isBRow && BShape[isBRow] <= 16;
// TODO[Superjomn]: ld.v4 is not supported.
@@ -4698,6 +4718,47 @@ private:
});
}
void rewriteConvertToDotOperand(ModuleOp mod) {
mod.walk([&](triton::gpu::ConvertLayoutOp cvt){
OpBuilder builder(cvt);
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
// order
ArrayRef<unsigned> order;
if(auto srcBlockedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
order = srcBlockedLayout.getOrder();
else if(auto srcSharedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>())
order = srcSharedLayout.getOrder();
else
return;
// dot operand output
auto dstDotOperandLayout =
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (!dstDotOperandLayout)
return;
unsigned opIdx = dstDotOperandLayout.getOpIdx();
if(!dstDotOperandLayout.getIsMMAv1Row())
return;
bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
if((order[0] == 1 && isMMAv1Row) ||
(order[0] == 0 && !isMMAv1Row))
return;
auto newIsRow = BoolAttr::get(cvt.getContext(), !isMMAv1Row);
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
cvt.getContext(), dstDotOperandLayout.getOpIdx(), dstDotOperandLayout.getParent(),
newIsRow);
auto newDstType = RankedTensorType::get(
dstType.getShape(),
dstType.getElementType(), newDstEncoding);
auto newCvt = builder.create<triton::gpu::ConvertLayoutOp>(
cvt.getLoc(), newDstType, cvt.getOperand());
cvt.replaceAllUsesWith(newCvt.getResult());
cvt.erase();
});
}
void decomposeInsertSliceAsyncOp(ModuleOp mod) {
AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
axisInfoAnalysis.run(mod);
@@ -4817,6 +4878,7 @@ public:
// separation between 1/4 is that, step 3 is out of the scope of Dialect
// Conversion, thus we need to make sure the smem is not revised during the
// conversion of step 4.
rewriteConvertToDotOperand(mod);
decomposeMmaToDotOperand(mod, numWarps);
decomposeBlockedToDotOperand(mod);
@@ -4827,6 +4889,7 @@ public:
MembarAnalysis membarPass(&allocation);
membarPass.run();
llvm::outs() << mod << "\n";
RewritePatternSet scf_patterns(context);
mlir::populateLoopToStdConversionPatterns(scf_patterns);
mlir::ConversionTarget scf_target(*context);

View File

@@ -589,15 +589,24 @@ Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
return {};
unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
Attribute parent = attrs.get("parent");
Attribute isMMAv1Row;
if(parent.isa<MmaEncodingAttr>() &&
parent.cast<MmaEncodingAttr>().getVersion() == 1){
isMMAv1Row = attrs.get("isMMAv1Row");
if(!isMMAv1Row)
llvm::report_fatal_error("isMMAv1Row attribute is missing");
}
return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
parent);
parent, isMMAv1Row);
}
void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
printer << "<{"
<< "opIdx = " << getOpIdx() << ", "
<< "parent = " << getParent() << "}>";
<< "parent = " << getParent();
if(getIsMMAv1Row())
printer << ", isMMAv1Row = " << getIsMMAv1Row();
printer << "}>";
}
//===----------------------------------------------------------------------===//

View File

@@ -713,6 +713,55 @@ public:
}
};
class OptimizeConvertToDotOperand : public mlir::RewritePattern {
public:
OptimizeConvertToDotOperand(mlir::MLIRContext *context)
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
// order
ArrayRef<unsigned> order;
if(auto srcBlockedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
order = srcBlockedLayout.getOrder();
else if(auto srcSharedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>())
order = srcSharedLayout.getOrder();
else
return failure();
// dot operand output
auto dstDotOperandLayout =
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (!dstDotOperandLayout)
return failure();
unsigned opIdx = dstDotOperandLayout.getOpIdx();
if(!dstDotOperandLayout.getIsMMAv1Row())
return failure();
bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
if((order[0] == 1 && isMMAv1Row) ||
(order[0] == 0 && !isMMAv1Row))
return failure();
auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
op->getContext(), dstDotOperandLayout.getOpIdx(), dstDotOperandLayout.getParent(),
newIsRow);
auto newDstType = RankedTensorType::get(
dstType.getShape(),
dstType.getElementType(), newDstEncoding);
auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newDstType, cvt.getOperand());
rewriter.replaceOp(op, newCvt.getResult());
return success();
}
};
class BlockedToMMA : public mlir::RewritePattern {
int computeCapability;
@@ -770,14 +819,28 @@ public:
Value b = dotOp.b();
auto oldAType = a.getType().cast<RankedTensorType>();
auto oldBType = b.getType().cast<RankedTensorType>();
auto oldAOrder = oldAType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
.getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
auto oldBOrder = oldBType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
.getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
Attribute isMMAv1RowA;
Attribute isMMAv1RowB;
if(version == 1){
isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
}
auto newAType = RankedTensorType::get(
oldAType.getShape(), oldAType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
newRetType.getEncoding()));
newRetType.getEncoding(),
isMMAv1RowA));
auto newBType = RankedTensorType::get(
oldBType.getShape(), oldBType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
newRetType.getEncoding()));
newRetType.getEncoding(),
isMMAv1RowB));
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
auto newDot = rewriter.create<triton::DotOp>(
@@ -808,6 +871,7 @@ public:
mlir::RewritePatternSet patterns(context);
patterns.add<OptimizeBlockedToShared>(context);
// patterns.add<OptimizeConvertToDotOperand>(context);
patterns.add<SimplifyConversion>(context);
patterns.add<DecomposeDotOperand>(context);
patterns.add<RematerializeBackward>(context);
@@ -818,6 +882,7 @@ public:
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
signalPassFailure();
}
}
};

View File

@@ -297,15 +297,23 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
# NOTE this is useful only on Volta GPU.
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [
# Non-forloop
[16, 16, 16, 1, 16, 16, 16, False, False],
[16, 16, 32, 1, 16, 16, 32, False, False],
[32, 16, 32, 1, 32, 16, 32, False, False],
[32, 32, 32, 1, 32, 32, 32, False, False],
[128, 32, 32, 1, 128, 32, 32, False, False],
# [16, 16, 16, 1, 16, 16, 16, False, False],
# [16, 16, 32, 1, 16, 16, 32, False, False],
# [32, 16, 32, 1, 32, 16, 32, False, False],
# [32, 32, 32, 1, 32, 32, 32, False, False],
# [128, 32, 32, 1, 128, 32, 32, False, False],
# split-K
[16, 16, 32, 1, 16, 16, 16, False, False],
[64, 64, 128, 1, 64, 64, 32, False, False],
# [128, 32, 32, 1, 128, 32, 32, True, False],
# [128, 32, 32, 1, 128, 32, 32, True, True],
# # split-K
# [16, 16, 32, 1, 16, 16, 16, False, False],
# [64, 64, 128, 1, 64, 64, 32, False, False],
# [16, 16, 32, 1, 16, 16, 16, True, False],
# [16, 16, 32, 1, 16, 16, 16, True, True],
[64, 64, 64, 1, 64, 64, 32, True, False],
])
def test_gemm_for_mmav1(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B)

View File

@@ -1402,9 +1402,9 @@ def compile(fn, **kwargs):
"ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
"ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
lambda src: ttir_to_ttgir(src, num_warps, num_stages, 70)),
"llir": (lambda path: Path(path).read_bytes(),
lambda src: ttgir_to_llir(src, extern_libs, capability)),
lambda src: ttgir_to_llir(src, extern_libs, 70)),
"ptx": (lambda path: Path(path).read_text(),
lambda src: llir_to_ptx(src, capability)),
"cubin": (lambda path: Path(path).read_bytes(),

View File

@@ -879,8 +879,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#mma = #triton_gpu.mma<{version = 1, warpsPerCTA = [2, 2]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, isMMAv1Row=true}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, isMMAv1Row=true}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
%a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) {