diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml
index d54e7061e..62c9e1dae 100644
--- a/.github/workflows/integration-tests.yml
+++ b/.github/workflows/integration-tests.yml
@@ -88,8 +88,10 @@ jobs:
       - name: Run python tests on V100
         if: ${{matrix.runner[0] == 'self-hosted' && matrix.runner[1] == 'V100'}}
         run: |
+          # TODO[Superjomn]: Remove the forloop-unroll setting once the pipeline pass works
           cd python/tests
-          pytest test_gemm.py::test_gemm_no_scf_for_mmav1
+          export TRITON_STATIC_LOOP_UNROLLING=1
+          pytest test_gemm.py::test_gemm_for_mmav1
 
       - name: Run CXX unittests
         run: |
diff --git a/lib/Conversion/TritonGPUToLLVM/DotHelpers.h b/lib/Conversion/TritonGPUToLLVM/DotHelpers.h
index 3943bc1b8..96a9b764c 100644
--- a/lib/Conversion/TritonGPUToLLVM/DotHelpers.h
+++ b/lib/Conversion/TritonGPUToLLVM/DotHelpers.h
@@ -72,6 +72,11 @@ struct DotOpMmaV1ConversionHelper {
     bool isARow = order[0] != 0;
     bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes
+    // TODO[Superjomn]: Support the isAVec4=false case later.
+    // Currently only ld.v2 is supported, since the mma layout varies with
+    // the ld vector width.
+    isAVec4 = true;
+
     int packSize0 = (isARow || isAVec4) ? 1 : 2;
 
     SmallVector<int> fpw({2, 2, 1});
@@ -98,6 +103,11 @@ struct DotOpMmaV1ConversionHelper {
     auto order = getOrder();
     bool isBRow = order[0] != 0;
     bool isBVec4 = isBRow && shape[order[0]] <= 16;
+    // TODO[Superjomn]: Support the isBVec4=false case later.
+    // Currently only ld.v2 is supported, since the mma layout varies with
+    // the ld vector width.
+    isBVec4 = true;
+
     int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
     SmallVector<int> fpw({2, 2, 1});
     SmallVector<int> rep({0, 2 * packSize1, 1}); // pad M with 0
@@ -1455,7 +1465,6 @@ Value DotOpMmaV1ConversionHelper::loadB(
   SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
                               sharedLayout.getOrder().end());
 
-  Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
   bool isBRow = order[0] != 0;
   bool isBVec4 = isBRow && shape[order[0]] <= 16;
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 8af05b0f3..dc32f7825 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -3555,14 +3555,18 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
   auto has = helper.extractLoadedOperand(loadedA, NK, rewriter);
   auto hbs = helper.extractLoadedOperand(loadedB, NK, rewriter);
 
-  // initialize accumulators
+  // Initialize the accumulators with the external values. acc holds the
+  // accumulator values shared by the MMA instructions inside a DotOp; we call
+  // their order the accumulator-internal order.
   SmallVector<Value> acc = getElementsFromStruct(loc, loadedC, rewriter);
   size_t resSize = acc.size();
+
+  // resVals holds the final result of the DotOp.
+  // NOTE: the order of resVals differs from that of acc; we call it the
+  // accumulator-external order.
   SmallVector<Value> resVals(resSize);
 
-  auto callMMA = [&](unsigned m, unsigned n, unsigned k) {
-    auto ha = has.at({m, k});
-    auto hb = hbs.at({n, k});
+  auto getIdx = [&](int m, int n) {
     std::vector<size_t> idx{{
         (m * 2 + 0) + (n * 4 + 0) * numM, // row0
         (m * 2 + 0) + (n * 4 + 1) * numM,
         (m * 2 + 0) + (n * 4 + 2) * numM, // row1
         (m * 2 + 0) + (n * 4 + 3) * numM,
         (m * 2 + 1) + (n * 4 + 0) * numM, // row2
         (m * 2 + 1) + (n * 4 + 1) * numM,
         (m * 2 + 1) + (n * 4 + 2) * numM, // row3
         (m * 2 + 1) + (n * 4 + 3) * numM,
@@ -3573,8 +3577,29 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
     }};
+    return idx;
+  };
+
+  { // Convert acc's values from accumulator-external order to
+    // accumulator-internal order.
+    SmallVector<Value> accInit(acc.size());
+
+    for (unsigned m = 0; m < numM / 2; ++m)
+      for (unsigned n = 0; n < numN / 2; ++n) {
+        auto idx = getIdx(m, n);
+        for (unsigned i = 0; i < 8; ++i)
+          accInit[idx[i]] = acc[(m * numN / 2 + n) * 8 + i];
+      }
+
+    acc = accInit;
+  }
+
+  auto callMMA = [&](unsigned m, unsigned n, unsigned k) {
+    auto ha = has.at({m, k});
+    auto hb = hbs.at({n, k});
 
     PTXBuilder builder;
+    auto idx = getIdx(m, n);
 
     auto *resOprs = builder.newListOperand(8, "=f");
     auto *AOprs = builder.newListOperand({
@@ -3606,8 +3631,6 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
     for (unsigned i = 0; i < 8; i++) {
       Value elem = extract_val(f32_ty, res, getIntAttr(i));
      acc[idx[i]] = elem;
-      // TODO[goostavz]: double confirm this when m/n/k = [32, 32, x] has been
-      // verified before MMA
      resVals[(m * numN / 2 + n) * 8 + i] = elem;
    }
  };
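The reindexing block added above is easy to sanity-check in isolation. Here is a minimal Python sketch (not part of the patch; the helper names are hypothetical, and num_m/num_n stand in for the C++ numM/numN) that mirrors getIdx and the conversion loop:

```python
def get_idx(m, n, num_m):
    # Mirrors the C++ getIdx lambda: the accumulator-internal positions of
    # the 8 values (2 rows x 4 columns) produced by the (m, n) sub-tile.
    return [(m * 2 + r) + (n * 4 + c) * num_m for r in (0, 1) for c in range(4)]

def external_to_internal(acc, num_m, num_n):
    # Mirrors the conversion block: scatter values stored contiguously per
    # (m, n) sub-tile (accumulator-external order) into the positions the
    # MMA instructions read and write (accumulator-internal order).
    acc_init = [None] * len(acc)
    for m in range(num_m // 2):
        for n in range(num_n // 2):
            idx = get_idx(m, n, num_m)
            for i in range(8):
                acc_init[idx[i]] = acc[(m * num_n // 2 + n) * 8 + i]
    return acc_init

# With num_m = num_n = 4 there are (4 // 2) * (4 // 2) * 8 = 32 accumulator
# values, and the mapping is a pure permutation of range(32).
assert sorted(external_to_internal(list(range(32)), 4, 4)) == list(range(32))
```

The inverse mapping is what the `resVals[(m * numN / 2 + n) * 8 + i] = elem` line computes, so the DotOp still returns its results in accumulator-external order.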
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 9d0a7f0d1..167c53061 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -1382,10 +1382,11 @@ void init_triton_translation(py::module &m) {
         llvm::SMDiagnostic error;
         std::unique_ptr<llvm::Module> module =
             llvm::parseIR(buffer->getMemBufferRef(), error, context);
-        if (!module)
+        if (!module) {
           llvm::report_fatal_error(
               "failed to parse IR: " + error.getMessage() +
               "lineno: " + std::to_string(error.getLineNo()));
+        }
 
         // translate module to PTX
         auto ptxCode =
diff --git a/python/tests/test_gemm.py b/python/tests/test_gemm.py
index 3555359dc..8af333c70 100644
--- a/python/tests/test_gemm.py
+++ b/python/tests/test_gemm.py
@@ -295,18 +295,17 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
 
 
 # NOTE this is useful only on Volta GPU.
-@pytest.mark.parametrize('SHAPE,NUM_WARPS,TRANS_A,TRANS_B', [
-    (shape, num_warps, trans_a, trans_b)
-    for shape in [
-        [16, 16, 16],
-        [16, 16, 32],
-        [32, 16, 16],
-        [32, 32, 32],
-        [128, 16, 16],
-    ]
-    for num_warps in [1]
-    for trans_a in [False]
-    for trans_b in [False]
+@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [
+    # Non-forloop
+    [16, 16, 16, 1, 16, 16, 16, False, False],
+    [16, 16, 32, 1, 16, 16, 32, False, False],
+    [32, 16, 32, 1, 32, 16, 32, False, False],
+    [32, 32, 32, 1, 32, 32, 32, False, False],
+    [128, 32, 32, 1, 128, 32, 32, False, False],
+
+    # K-forloop (SIZE_K is a multiple of BLOCK_SIZE_K)
+    [16, 16, 32, 1, 16, 16, 16, False, False],
+    [64, 64, 128, 1, 64, 64, 32, False, False],
 ])
-def test_gemm_no_scf_for_mmav1(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
-    test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B)
+def test_gemm_for_mmav1(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
+    test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B)
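The two K-forloop cases are the new coverage: whenever SIZE_K exceeds BLOCK_SIZE_K, the generated kernel has to iterate over the K dimension, which is the loop that the TRITON_STATIC_LOOP_UNROLLING=1 setting in the CI step above unrolls. A hypothetical classification of the trip counts (not part of the test file):

```python
# SIZE_K / BLOCK_SIZE_K gives the number of K iterations per kernel launch.
cases = [
    # (SIZE_K, BLOCK_SIZE_K)
    (16, 16),    # 1 iteration: non-forloop
    (32, 16),    # 2 iterations: K-forloop
    (128, 32),   # 4 iterations: K-forloop
]
for size_k, block_k in cases:
    trips = size_k // block_k
    kind = "non-forloop" if trips == 1 else "K-forloop"
    print(f"SIZE_K={size_k}, BLOCK_SIZE_K={block_k}: {trips} iteration(s), {kind}")
```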
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index d66dbfd50..7d442b9f0 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -583,7 +583,7 @@ class CodeGenerator(ast.NodeVisitor):
                 isinstance(step, triton.language.constexpr):
             sta_range = iterator(lb.value, ub.value, step.value)
             static_unrolling = os.environ.get('TRITON_STATIC_LOOP_UNROLLING', False)
-            if static_unrolling and len(range) <= 10:
+            if static_unrolling and len(sta_range) <= 10:
                 for i in sta_range:
                     self.lscope[node.target.id] = triton.language.constexpr(i)
                     self.visit_compound_statement(node.body)
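The fix above makes the trip-count guard read the materialized sta_range instead of the built-in range (on which len() would raise a TypeError). A standalone sketch of the unrolling behavior (maybe_unroll and emit_body are hypothetical names, not Triton APIs):

```python
import os

def maybe_unroll(lb, ub, step, emit_body):
    # Mirrors the guard in the diff above: unroll only when the environment
    # variable is set and the static trip count is at most 10.
    static_unrolling = os.environ.get('TRITON_STATIC_LOOP_UNROLLING', False)
    sta_range = range(lb, ub, step)
    if static_unrolling and len(sta_range) <= 10:
        for i in sta_range:
            emit_body(i)  # body re-emitted with i as a compile-time constant
        return True       # fully unrolled; no runtime loop is generated
    return False          # fall back to emitting a real loop

# Example: with the variable set, a 4-iteration K loop is fully unrolled.
os.environ['TRITON_STATIC_LOOP_UNROLLING'] = '1'
maybe_unroll(0, 128, 32, lambda i: print(f"unrolled iteration k = {i}"))
```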