testing things...

This commit is contained in:
Phil Tillet
2022-12-09 19:31:34 -08:00
parent fa6dbbff60
commit 58d2867fe6
6 changed files with 105 additions and 30 deletions

View File

@@ -432,13 +432,12 @@ section 9.7.13.4.1 for more details.
let builders = [ let builders = [
AttrBuilder<(ins "unsigned":$opIdx, AttrBuilder<(ins "unsigned":$opIdx,
"Attribute":$parent), [{ "Attribute":$parent), [{
Attribute isMMAv1Row;
if(parent.isa<MmaEncodingAttr>() && if(parent.isa<MmaEncodingAttr>() &&
parent.cast<MmaEncodingAttr>().getVersion() == 1){ parent.cast<MmaEncodingAttr>().getVersion() == 1){
llvm::report_fatal_error("DotOperand for MMAv1 must have isMMAv1Row field"); isMMAv1Row = BoolAttr::get(context, true);
return {};
} }
Attribute none; return $_get(context, opIdx, parent, isMMAv1Row);
return $_get(context, opIdx, parent, none);
}]> }]>
]; ];

View File

@@ -1356,6 +1356,20 @@ Value DotOpMmaV1ConversionHelper::loadA(
Value tensor, bool transA, const SharedMemoryObject &smemObj, Value thread, Value tensor, bool transA, const SharedMemoryObject &smemObj, Value thread,
Location loc, ConversionPatternRewriter &rewriter) const { Location loc, ConversionPatternRewriter &rewriter) const {
// [1, 0] (isRow = True)
// x x x x || x x x x
// x x x x || x x x x
// stride = [8, 1]
// strideA0 = strideAk = 1
// strideA1 = strideAm = 8
// [0, 1] (isRow = False)
// x x x x || x x x x
// x x x x || x x x x
// stride = [1, 2]
// strideA0 = strideAm = 1
// strideA1 = strideAk = 2
auto *ctx = rewriter.getContext(); auto *ctx = rewriter.getContext();
auto tensorTy = tensor.getType().cast<RankedTensorType>(); auto tensorTy = tensor.getType().cast<RankedTensorType>();
auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>(); auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
@@ -1364,8 +1378,8 @@ Value DotOpMmaV1ConversionHelper::loadA(
SmallVector<unsigned> order(sharedLayout.getOrder().begin(), SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
sharedLayout.getOrder().end()); sharedLayout.getOrder().end());
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]); // Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter); Value smemBase = smemObj.base;
bool isARow = order[0] != 0; bool isARow = order[0] != 0;
AParam param(isARow); AParam param(isARow);
@@ -1387,6 +1401,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
Value strideA0 = isARow ? strideAK : strideAM; Value strideA0 = isARow ? strideAK : strideAM;
Value strideA1 = isARow ? strideAM : strideAK; Value strideA1 = isARow ? strideAM : strideAK;
smemBase = gep(ptr_ty(f16_ty), smemBase, Value(smemObj.offsets[1]));
int strideRepM = wpt[0] * fpw[0] * 8; int strideRepM = wpt[0] * fpw[0] * 8;
int strideRepK = 1; int strideRepK = 1;
@@ -1401,7 +1416,9 @@ Value DotOpMmaV1ConversionHelper::loadA(
Value offA0 = isARow ? offsetAK : offsetAM; Value offA0 = isARow ? offsetAK : offsetAM;
Value offA1 = isARow ? offsetAM : offsetAK; Value offA1 = isARow ? offsetAM : offsetAK;
Value phaseA = urem(udiv(offA1, i32_val(perPhaseA)), i32_val(maxPhaseA)); Value phaseA = urem(udiv(offA1, i32_val(perPhaseA)), i32_val(maxPhaseA));
offA0 = add(offA0, cSwizzleOffset); // offA0 = add(offA0, smemObj.offsets[order[0]]);
// offA1 = add(offA1, smemObj.offsets[order[1]]);
SmallVector<Value> offA(numPtrA); SmallVector<Value> offA(numPtrA);
for (int i = 0; i < numPtrA; i++) { for (int i = 0; i < numPtrA; i++) {
Value offA0I = add(offA0, i32_val(i * (isARow ? 4 : strideRepM))); Value offA0I = add(offA0, i32_val(i * (isARow ? 4 : strideRepM)));
@@ -1422,6 +1439,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
Type f16PtrTy = ptr_ty(f16_ty); Type f16PtrTy = ptr_ty(f16_ty);
auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) { auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) {
vals[{m, k}] = {val0, val1}; vals[{m, k}] = {val0, val1};
}; };
@@ -1451,7 +1469,10 @@ Value DotOpMmaV1ConversionHelper::loadA(
} }
}; };
unsigned numM = getNumM(shape, order); unsigned numM = getNumM(shape, order);
llvm::outs() << "LOAD A " << numM << " " << NK << "\n";
for (unsigned k = 0; k < NK; k += 4) for (unsigned k = 0; k < NK; k += 4)
for (unsigned m = 0; m < numM / 2; ++m) for (unsigned m = 0; m < numM / 2; ++m)
loadA(m, k); loadA(m, k);

View File

@@ -3427,6 +3427,7 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
} }
} else if (!isOuter && mmaLayout.getVersion() == 1 && } else if (!isOuter && mmaLayout.getVersion() == 1 &&
isHMMA) { // tensor core v1 isHMMA) { // tensor core v1
// vprintf("offset 0", smemObj.offsets[0]}, rewriter);
DotOpMmaV1ConversionHelper helper(mmaLayout); DotOpMmaV1ConversionHelper helper(mmaLayout);
bool isMMAv1Row = bool isMMAv1Row =
dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue(); dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
@@ -3443,6 +3444,7 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
} }
if (dotOperandLayout.getOpIdx() == 0) { // operand $a if (dotOperandLayout.getOpIdx() == 0) { // operand $a
// LLVM::vprintf_array(i32_val(0), smemObj.offsets, "offsets ", "%d", rewriter);
// TODO[Superjomn]: transA is not available here. // TODO[Superjomn]: transA is not available here.
bool transA = false; bool transA = false;
res = helper.loadA(src, transA, smemObj, getThreadId(rewriter, loc), loc, res = helper.loadA(src, transA, smemObj, getThreadId(rewriter, loc), loc,
@@ -4716,6 +4718,47 @@ private:
}); });
} }
void rewriteConvertToDotOperand(ModuleOp mod) {
mod.walk([&](triton::gpu::ConvertLayoutOp cvt){
OpBuilder builder(cvt);
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
// order
ArrayRef<unsigned> order;
if(auto srcBlockedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
order = srcBlockedLayout.getOrder();
else if(auto srcSharedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>())
order = srcSharedLayout.getOrder();
else
return;
// dot operand output
auto dstDotOperandLayout =
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (!dstDotOperandLayout)
return;
unsigned opIdx = dstDotOperandLayout.getOpIdx();
if(!dstDotOperandLayout.getIsMMAv1Row())
return;
bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
if((order[0] == 1 && isMMAv1Row) ||
(order[0] == 0 && !isMMAv1Row))
return;
auto newIsRow = BoolAttr::get(cvt.getContext(), !isMMAv1Row);
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
cvt.getContext(), dstDotOperandLayout.getOpIdx(), dstDotOperandLayout.getParent(),
newIsRow);
auto newDstType = RankedTensorType::get(
dstType.getShape(),
dstType.getElementType(), newDstEncoding);
auto newCvt = builder.create<triton::gpu::ConvertLayoutOp>(
cvt.getLoc(), newDstType, cvt.getOperand());
cvt.replaceAllUsesWith(newCvt.getResult());
cvt.erase();
});
}
void decomposeInsertSliceAsyncOp(ModuleOp mod) { void decomposeInsertSliceAsyncOp(ModuleOp mod) {
AxisInfoAnalysis axisInfoAnalysis(mod.getContext()); AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
axisInfoAnalysis.run(mod); axisInfoAnalysis.run(mod);
@@ -4835,6 +4878,7 @@ public:
// separation between 1/4 is that, step 3 is out of the scope of Dialect // separation between 1/4 is that, step 3 is out of the scope of Dialect
// Conversion, thus we need to make sure the smem is not revised during the // Conversion, thus we need to make sure the smem is not revised during the
// conversion of step 4. // conversion of step 4.
rewriteConvertToDotOperand(mod);
decomposeMmaToDotOperand(mod, numWarps); decomposeMmaToDotOperand(mod, numWarps);
decomposeBlockedToDotOperand(mod); decomposeBlockedToDotOperand(mod);
@@ -4845,6 +4889,7 @@ public:
MembarAnalysis membarPass(&allocation); MembarAnalysis membarPass(&allocation);
membarPass.run(); membarPass.run();
llvm::outs() << mod << "\n";
RewritePatternSet scf_patterns(context); RewritePatternSet scf_patterns(context);
mlir::populateLoopToStdConversionPatterns(scf_patterns); mlir::populateLoopToStdConversionPatterns(scf_patterns);
mlir::ConversionTarget scf_target(*context); mlir::ConversionTarget scf_target(*context);

View File

@@ -713,9 +713,9 @@ public:
} }
}; };
class OptimizeBlockedToDotOperand : public mlir::RewritePattern { class OptimizeConvertToDotOperand : public mlir::RewritePattern {
public: public:
OptimizeBlockedToDotOperand(mlir::MLIRContext *context) OptimizeConvertToDotOperand(mlir::MLIRContext *context)
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1, : RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
context) {} context) {}
@@ -725,18 +725,27 @@ public:
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op); auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>(); auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvt.getResult().getType().cast<RankedTensorType>(); auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
auto srcBlockedLayout = // order
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>(); ArrayRef<unsigned> order;
if(auto srcBlockedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
order = srcBlockedLayout.getOrder();
else if(auto srcSharedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>())
order = srcSharedLayout.getOrder();
else
return failure();
// dot operand output
auto dstDotOperandLayout = auto dstDotOperandLayout =
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>(); dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (!srcBlockedLayout || !dstDotOperandLayout) if (!dstDotOperandLayout)
return failure(); return failure();
unsigned opIdx = dstDotOperandLayout.getOpIdx(); unsigned opIdx = dstDotOperandLayout.getOpIdx();
if(!dstDotOperandLayout.getIsMMAv1Row()) if(!dstDotOperandLayout.getIsMMAv1Row())
return failure(); return failure();
bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue(); bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
if((srcBlockedLayout.getOrder()[0] == 1 && isMMAv1Row) || if((order[0] == 1 && isMMAv1Row) ||
(srcBlockedLayout.getOrder()[0] == 0 && !isMMAv1Row)) (order[0] == 0 && !isMMAv1Row))
return failure(); return failure();
auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row); auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get( auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
@@ -862,7 +871,7 @@ public:
mlir::RewritePatternSet patterns(context); mlir::RewritePatternSet patterns(context);
patterns.add<OptimizeBlockedToShared>(context); patterns.add<OptimizeBlockedToShared>(context);
patterns.add<OptimizeBlockedToDotOperand>(context); // patterns.add<OptimizeConvertToDotOperand>(context);
patterns.add<SimplifyConversion>(context); patterns.add<SimplifyConversion>(context);
patterns.add<DecomposeDotOperand>(context); patterns.add<DecomposeDotOperand>(context);
patterns.add<RematerializeBackward>(context); patterns.add<RematerializeBackward>(context);
@@ -873,6 +882,7 @@ public:
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) { if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
signalPassFailure(); signalPassFailure();
} }
} }
}; };

View File

@@ -297,22 +297,22 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
# NOTE this is useful only on Volta GPU. # NOTE this is useful only on Volta GPU.
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [ @pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [
# Non-forloop # Non-forloop
[16, 16, 16, 1, 16, 16, 16, False, False], # [16, 16, 16, 1, 16, 16, 16, False, False],
[16, 16, 32, 1, 16, 16, 32, False, False], # [16, 16, 32, 1, 16, 16, 32, False, False],
[32, 16, 32, 1, 32, 16, 32, False, False], # [32, 16, 32, 1, 32, 16, 32, False, False],
[32, 32, 32, 1, 32, 32, 32, False, False], # [32, 32, 32, 1, 32, 32, 32, False, False],
[128, 32, 32, 1, 128, 32, 32, False, False], # [128, 32, 32, 1, 128, 32, 32, False, False],
[128, 32, 32, 1, 128, 32, 32, True, False], # [128, 32, 32, 1, 128, 32, 32, True, False],
[128, 32, 32, 1, 128, 32, 32, True, True], # [128, 32, 32, 1, 128, 32, 32, True, True],
# split-K # # split-K
[16, 16, 32, 1, 16, 16, 16, False, False], # [16, 16, 32, 1, 16, 16, 16, False, False],
[64, 64, 128, 1, 64, 64, 32, False, False], # [64, 64, 128, 1, 64, 64, 32, False, False],
[16, 16, 32, 1, 16, 16, 16, True, False], # [16, 16, 32, 1, 16, 16, 16, True, False],
[16, 16, 32, 1, 16, 16, 16, True, True], # [16, 16, 32, 1, 16, 16, 16, True, True],
[64, 64, 128, 1, 64, 64, 32, True, True], [64, 64, 64, 1, 64, 64, 32, True, False],
]) ])
def test_gemm_for_mmav1(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B): def test_gemm_for_mmav1(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B) test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B)

View File

@@ -1402,9 +1402,9 @@ def compile(fn, **kwargs):
"ttir": (lambda path: _triton.ir.parse_mlir_module(path, context), "ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)), lambda src: ast_to_ttir(src, signature, configs[0], constants)),
"ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context), "ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)), lambda src: ttir_to_ttgir(src, num_warps, num_stages, 70)),
"llir": (lambda path: Path(path).read_bytes(), "llir": (lambda path: Path(path).read_bytes(),
lambda src: ttgir_to_llir(src, extern_libs, capability)), lambda src: ttgir_to_llir(src, extern_libs, 70)),
"ptx": (lambda path: Path(path).read_text(), "ptx": (lambda path: Path(path).read_text(),
lambda src: llir_to_ptx(src, capability)), lambda src: llir_to_ptx(src, capability)),
"cubin": (lambda path: Path(path).read_bytes(), "cubin": (lambda path: Path(path).read_bytes(),