From 58d2867fe6bfa466b39e749fde973dcf37b8abca Mon Sep 17 00:00:00 2001
From: Phil Tillet
Date: Fri, 9 Dec 2022 19:31:34 -0800
Subject: [PATCH] testing things...

---
 .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td |  7 ++-
 lib/Conversion/TritonGPUToLLVM/DotHelpers.h   | 27 +++++++++--
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp       | 45 +++++++++++++++++++
 lib/Dialect/TritonGPU/Transforms/Combine.cpp  | 26 +++++++----
 python/tests/test_gemm.py                     | 26 +++++------
 python/triton/compiler.py                     |  4 +-
 6 files changed, 105 insertions(+), 30 deletions(-)

diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
index cd58e5858..c658667ce 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -432,13 +432,12 @@ section 9.7.13.4.1 for more details.
   let builders = [
     AttrBuilder<(ins "unsigned":$opIdx,
                      "Attribute":$parent), [{
+      Attribute isMMAv1Row;
       if(parent.isa<MmaEncodingAttr>() &&
          parent.cast<MmaEncodingAttr>().getVersion() == 1){
-        llvm::report_fatal_error("DotOperand for MMAv1 must have isMMAv1Row field");
-        return {};
+        isMMAv1Row = BoolAttr::get(context, true);
       }
-      Attribute none;
-      return $_get(context, opIdx, parent, none);
+      return $_get(context, opIdx, parent, isMMAv1Row);
     }]>
   ];
 
diff --git a/lib/Conversion/TritonGPUToLLVM/DotHelpers.h b/lib/Conversion/TritonGPUToLLVM/DotHelpers.h
index 8703bebcb..30c05debc 100644
--- a/lib/Conversion/TritonGPUToLLVM/DotHelpers.h
+++ b/lib/Conversion/TritonGPUToLLVM/DotHelpers.h
@@ -1356,6 +1356,20 @@ Value DotOpMmaV1ConversionHelper::loadA(
     Value tensor, bool transA, const SharedMemoryObject &smemObj, Value thread,
     Location loc, ConversionPatternRewriter &rewriter) const {
 
+  // [1, 0] (isRow = True)
+  //   x x x x || x x x x
+  //   x x x x || x x x x
+  //   stride = [8, 1]
+  //   strideA0 = strideAk = 1
+  //   strideA1 = strideAm = 8
+
+  // [0, 1] (isRow = False)
+  //   x x x x || x x x x
+  //   x x x x || x x x x
+  //   stride = [1, 2]
+  //   strideA0 = strideAm = 1
+  //   strideA1 = strideAk = 2
+
   auto *ctx = rewriter.getContext();
   auto tensorTy = tensor.getType().cast<RankedTensorType>();
   auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
@@ -1364,8 +1378,8 @@
   SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
                               sharedLayout.getOrder().end());
 
-  Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
-  Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
+  // Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
+  Value smemBase = smemObj.base;
 
   bool isARow = order[0] != 0;
   AParam param(isARow);
@@ -1387,6 +1401,7 @@
   Value strideA0 = isARow ? strideAK : strideAM;
   Value strideA1 = isARow ? strideAM : strideAK;
 
+  smemBase = gep(ptr_ty(f16_ty), smemBase, Value(smemObj.offsets[1]));
   int strideRepM = wpt[0] * fpw[0] * 8;
   int strideRepK = 1;
@@ -1401,7 +1416,9 @@
   Value offA0 = isARow ? offsetAK : offsetAM;
   Value offA1 = isARow ? offsetAM : offsetAK;
   Value phaseA = urem(udiv(offA1, i32_val(perPhaseA)), i32_val(maxPhaseA));
-  offA0 = add(offA0, cSwizzleOffset);
+  // offA0 = add(offA0, smemObj.offsets[order[0]]);
+  // offA1 = add(offA1, smemObj.offsets[order[1]]);
+
   SmallVector<Value> offA(numPtrA);
   for (int i = 0; i < numPtrA; i++) {
     Value offA0I = add(offA0, i32_val(i * (isARow ? 4 : strideRepM)));
@@ -1422,6 +1439,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
 
   Type f16PtrTy = ptr_ty(f16_ty);
 
+
   auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) {
     vals[{m, k}] = {val0, val1};
   };
@@ -1451,7 +1469,10 @@
     }
   };
 
+
   unsigned numM = getNumM(shape, order);
+  llvm::outs() << "LOAD A " << numM << " " << NK << "\n";
+
   for (unsigned k = 0; k < NK; k += 4)
     for (unsigned m = 0; m < numM / 2; ++m)
       loadA(m, k);
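The stride comments added to loadA above encode one invariant: the contiguous axis is order[0], so its stride is 1 and the other axis's stride is the extent of the contiguous axis. A minimal Python sketch of that bookkeeping (make_strides is a hypothetical helper, not part of Triton), using the 2x8 tile drawn in the comment:

    # Sketch only: mirrors the stride arithmetic described in the loadA comment.
    # shape is [M, K]; order[0] == 1 means row-major (K contiguous),
    # order[0] == 0 means column-major (M contiguous).
    def make_strides(shape, order):
        M, K = shape
        is_row = order[0] == 1
        stride_am, stride_ak = (K, 1) if is_row else (1, M)
        # loadA reads the contiguous axis first: strideA0 is the unit stride.
        stride_a0 = stride_ak if is_row else stride_am
        stride_a1 = stride_am if is_row else stride_ak
        return (stride_am, stride_ak), (stride_a0, stride_a1)

    assert make_strides([2, 8], [1, 0]) == ((8, 1), (1, 8))  # stride = [8, 1]
    assert make_strides([2, 8], [0, 1]) == ((1, 2), (1, 2))  # stride = [1, 2]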
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 89917a494..3ca682024 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -3427,6 +3427,7 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
     }
   } else if (!isOuter && mmaLayout.getVersion() == 1 && isHMMA) { // tensor core v1
+    // vprintf("offset 0", {smemObj.offsets[0]}, rewriter);
     DotOpMmaV1ConversionHelper helper(mmaLayout);
     bool isMMAv1Row =
         dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
@@ -3443,6 +3444,7 @@
     }
     if (dotOperandLayout.getOpIdx() == 0) { // operand $a
+      // LLVM::vprintf_array(i32_val(0), smemObj.offsets, "offsets ", "%d", rewriter);
       // TODO[Superjomn]: transA is not available here.
       bool transA = false;
       res = helper.loadA(src, transA, smemObj, getThreadId(rewriter, loc), loc,
@@ -4716,6 +4718,47 @@ private:
     });
   }
 
+  void rewriteConvertToDotOperand(ModuleOp mod) {
+    mod.walk([&](triton::gpu::ConvertLayoutOp cvt){
+      OpBuilder builder(cvt);
+      auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
+      auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
+      // order
+      ArrayRef<unsigned> order;
+      if(auto srcBlockedLayout =
+             srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
+        order = srcBlockedLayout.getOrder();
+      else if(auto srcSharedLayout =
+                  srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>())
+        order = srcSharedLayout.getOrder();
+      else
+        return;
+      // dot operand output
+      auto dstDotOperandLayout =
+          dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
+      if (!dstDotOperandLayout)
+        return;
+      unsigned opIdx = dstDotOperandLayout.getOpIdx();
+      if(!dstDotOperandLayout.getIsMMAv1Row())
+        return;
+      bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
+      if((order[0] == 1 && isMMAv1Row) ||
+         (order[0] == 0 && !isMMAv1Row))
+        return;
+      auto newIsRow = BoolAttr::get(cvt.getContext(), !isMMAv1Row);
+      auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
+          cvt.getContext(), dstDotOperandLayout.getOpIdx(), dstDotOperandLayout.getParent(),
+          newIsRow);
+      auto newDstType = RankedTensorType::get(
+          dstType.getShape(),
+          dstType.getElementType(), newDstEncoding);
+      auto newCvt = builder.create<triton::gpu::ConvertLayoutOp>(
+          cvt.getLoc(), newDstType, cvt.getOperand());
+      cvt.replaceAllUsesWith(newCvt.getResult());
+      cvt.erase();
+    });
+  }
+
   void decomposeInsertSliceAsyncOp(ModuleOp mod) {
     AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
     axisInfoAnalysis.run(mod);
@@ -4835,6 +4878,7 @@ public:
     // separation between 1/4 is that, step 3 is out of the scope of Dialect
     // Conversion, thus we need to make sure the smem is not revised during the
     // conversion of step 4.
+    rewriteConvertToDotOperand(mod);
     decomposeMmaToDotOperand(mod, numWarps);
 
     decomposeBlockedToDotOperand(mod);
@@ -4845,6 +4889,7 @@
     MembarAnalysis membarPass(&allocation);
     membarPass.run();
 
+    llvm::outs() << mod << "\n";
     RewritePatternSet scf_patterns(context);
     mlir::populateLoopToStdConversionPatterns(scf_patterns);
     mlir::ConversionTarget scf_target(*context);
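The rewriteConvertToDotOperand hook added above duplicates, as a pre-pass, the check that the (now disabled) OptimizeConvertToDotOperand pattern performs in Combine.cpp below: when a convert's destination is a dot-operand layout whose isMMAv1Row flag disagrees with the source layout's order, rebuild the convert with the flag flipped. A Python sketch of just that decision, with plain values standing in for the MLIR attributes (the names are illustrative only):

    # Sketch only: the row/column consistency rule applied to a layout convert.
    # order is the source layout's axis order; is_mma_v1_row is the destination
    # dot-operand flag. Returns the flag the rewritten convert should carry.
    def fixed_row_flag(order, is_mma_v1_row):
        src_is_row = order[0] == 1
        if src_is_row == is_mma_v1_row:
            return is_mma_v1_row      # already consistent: leave the op alone
        return not is_mma_v1_row      # mismatch: flip and rebuild the convert

    assert fixed_row_flag([1, 0], False) is True   # row-major source forces isMMAv1Row
    assert fixed_row_flag([0, 1], True) is False

Flipping the flag rather than failing matches the .td builder change at the top of this patch, which now defaults isMMAv1Row to true for MMAv1 parents instead of calling report_fatal_error.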
diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
index 23d9c1b80..e762b0b88 100644
--- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
@@ -713,9 +713,9 @@ public:
   }
 };
 
-class OptimizeBlockedToDotOperand : public mlir::RewritePattern {
+class OptimizeConvertToDotOperand : public mlir::RewritePattern {
 public:
-  OptimizeBlockedToDotOperand(mlir::MLIRContext *context)
+  OptimizeConvertToDotOperand(mlir::MLIRContext *context)
       : RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
                        context) {}
 
@@ -725,18 +725,27 @@
     auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
     auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
     auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
-    auto srcBlockedLayout =
-        srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
+    // order
+    ArrayRef<unsigned> order;
+    if(auto srcBlockedLayout =
+           srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
+      order = srcBlockedLayout.getOrder();
+    else if(auto srcSharedLayout =
+                srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>())
+      order = srcSharedLayout.getOrder();
+    else
+      return failure();
+    // dot operand output
     auto dstDotOperandLayout =
         dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
-    if (!srcBlockedLayout || !dstDotOperandLayout)
+    if (!dstDotOperandLayout)
       return failure();
     unsigned opIdx = dstDotOperandLayout.getOpIdx();
     if(!dstDotOperandLayout.getIsMMAv1Row())
       return failure();
     bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
-    if((srcBlockedLayout.getOrder()[0] == 1 && isMMAv1Row) ||
-       (srcBlockedLayout.getOrder()[0] == 0 && !isMMAv1Row))
+    if((order[0] == 1 && isMMAv1Row) ||
+       (order[0] == 0 && !isMMAv1Row))
       return failure();
     auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
     auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
@@ -862,7 +871,7 @@ public:
     mlir::RewritePatternSet patterns(context);
 
     patterns.add<OptimizeBlockedToShared>(context);
-    patterns.add<OptimizeBlockedToDotOperand>(context);
+    // patterns.add<OptimizeConvertToDotOperand>(context);
     patterns.add<SimplifyConversion>(context);
     patterns.add<DecomposeDotOperand>(context);
     patterns.add<RematerializeBackward>(context);
@@ -873,6 +882,7 @@
     if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
       signalPassFailure();
     }
+
   }
 };
 
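The test_gemm.py hunk below narrows the MMAv1 matrix down to a single transposed-A case while this is being debugged; per the NOTE kept in the file, these cases only exercise MMAv1 on Volta. A hedged sketch of how such tests are commonly gated at collection time (requires_volta is illustrative and not in the test file; pytest.mark.skipif and torch.cuda.get_device_capability are standard APIs):

    import pytest
    import torch

    # MMAv1 is the sm_70 tensor-core path, so skip everywhere but Volta.
    _is_volta = torch.cuda.is_available() and \
        torch.cuda.get_device_capability() == (7, 0)
    requires_volta = pytest.mark.skipif(
        not _is_volta, reason="MMAv1 GEMM tests require a Volta (sm_70) GPU")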
diff --git a/python/tests/test_gemm.py b/python/tests/test_gemm.py
index 012e0771d..1908a29be 100644
--- a/python/tests/test_gemm.py
+++ b/python/tests/test_gemm.py
@@ -297,22 +297,22 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
 
 # NOTE this is useful only on Volta GPU.
 @pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [
     # Non-forloop
-    [16, 16, 16, 1, 16, 16, 16, False, False],
-    [16, 16, 32, 1, 16, 16, 32, False, False],
-    [32, 16, 32, 1, 32, 16, 32, False, False],
-    [32, 32, 32, 1, 32, 32, 32, False, False],
-    [128, 32, 32, 1, 128, 32, 32, False, False],
+    # [16, 16, 16, 1, 16, 16, 16, False, False],
+    # [16, 16, 32, 1, 16, 16, 32, False, False],
+    # [32, 16, 32, 1, 32, 16, 32, False, False],
+    # [32, 32, 32, 1, 32, 32, 32, False, False],
+    # [128, 32, 32, 1, 128, 32, 32, False, False],
 
-    [128, 32, 32, 1, 128, 32, 32, True, False],
-    [128, 32, 32, 1, 128, 32, 32, True, True],
+    # [128, 32, 32, 1, 128, 32, 32, True, False],
+    # [128, 32, 32, 1, 128, 32, 32, True, True],
 
-    # split-K
-    [16, 16, 32, 1, 16, 16, 16, False, False],
-    [64, 64, 128, 1, 64, 64, 32, False, False],
+    # # split-K
+    # [16, 16, 32, 1, 16, 16, 16, False, False],
+    # [64, 64, 128, 1, 64, 64, 32, False, False],
 
-    [16, 16, 32, 1, 16, 16, 16, True, False],
-    [16, 16, 32, 1, 16, 16, 16, True, True],
-    [64, 64, 128, 1, 64, 64, 32, True, True],
+    # [16, 16, 32, 1, 16, 16, 16, True, False],
+    # [16, 16, 32, 1, 16, 16, 16, True, True],
+    [64, 64, 64, 1, 64, 64, 32, True, False],
 ])
 def test_gemm_for_mmav1(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
     test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B)
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index 7d442b9f0..1821b6c90 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -1402,9 +1402,9 @@ def compile(fn, **kwargs):
         "ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
                  lambda src: ast_to_ttir(src, signature, configs[0], constants)),
         "ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
-                  lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
+                  lambda src: ttir_to_ttgir(src, num_warps, num_stages, 70)),
         "llir": (lambda path: Path(path).read_bytes(),
-                 lambda src: ttgir_to_llir(src, extern_libs, capability)),
+                 lambda src: ttgir_to_llir(src, extern_libs, 70)),
         "ptx": (lambda path: Path(path).read_text(),
                 lambda src: llir_to_ptx(src, capability)),
         "cubin": (lambda path: Path(path).read_bytes(),
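A note on the final compiler.py hunk: it pins the ttir_to_ttgir and ttgir_to_llir stages to capability 70, forcing the Volta/MMAv1 code paths regardless of the device, while llir_to_ptx still receives the detected capability. A sketch of what an explicit override could look like in place of the inline constant (target_capability is a hypothetical helper; only the torch query is a real API):

    import torch

    def target_capability(override=None):
        # override=70 reproduces this patch's behavior: always compile the
        # TTGIR/LLIR stages for Volta (MMAv1). Otherwise query the GPU.
        if override is not None:
            return override
        major, minor = torch.cuda.get_device_capability()
        return major * 10 + minor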