[Triton-MLIR][BACKEND] make MMAv1 splitk works (#960)

This commit is contained in:
Yan Chunwei
2022-12-07 16:58:38 +08:00
committed by GitHub
parent b2b793dfb5
commit 4eab9dcedf
6 changed files with 58 additions and 24 deletions

View File

@@ -88,8 +88,10 @@ jobs:
- name: Run python tests on V100
if: ${{matrix.runner[0] == 'self-hosted' && matrix.runner[1] == 'V100'}}
run: |
# TODO[Superjomn]: Remove the forloop-unroll setting after pipeline pass works
cd python/tests
pytest test_gemm.py::test_gemm_no_scf_for_mmav1
export TRITON_STATIC_LOOP_UNROLLING=1
pytest test_gemm.py::test_gemm_for_mmav1
- name: Run CXX unittests
run: |

View File

@@ -72,6 +72,11 @@ struct DotOpMmaV1ConversionHelper {
bool isARow = order[0] != 0;
bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes
// TODO[Superjomn]: Support the case when isAVec4=false later
// Currently, we only support ld.v2, for the mma layout varies with
// different ld vector width.
isAVec4 = true;
int packSize0 = (isARow || isAVec4) ? 1 : 2;
SmallVector<int> fpw({2, 2, 1});
@@ -98,6 +103,11 @@ struct DotOpMmaV1ConversionHelper {
auto order = getOrder();
bool isBRow = order[0] != 0;
bool isBVec4 = isBRow && shape[order[0]] <= 16;
// TODO[Superjomn]: Support the case when isBVec4=false later
// Currently, we only support ld.v2, for the mma layout varies with
// different ld vector width.
isBVec4 = true;
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
SmallVector<int> fpw({2, 2, 1});
SmallVector<int> rep({0, 2 * packSize1, 1}); // pad M with 0
@@ -1455,7 +1465,6 @@ Value DotOpMmaV1ConversionHelper::loadB(
SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
sharedLayout.getOrder().end());
Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
bool isBRow = order[0] != 0;
bool isBVec4 = isBRow && shape[order[0]] <= 16;

View File

@@ -3555,14 +3555,18 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
auto has = helper.extractLoadedOperand(loadedA, NK, rewriter);
auto hbs = helper.extractLoadedOperand(loadedB, NK, rewriter);
// initialize accumulators
// Initialize accumulators with external values, the acc holds the accumulator
// value that is shared between the MMA instructions inside a DotOp, we can
// call the order of the values the accumulator-internal order.
SmallVector<Value> acc = getElementsFromStruct(loc, loadedC, rewriter);
size_t resSize = acc.size();
// The resVals holds the final result of the DotOp.
// NOTE The current order of resVals is different from acc, we call it the
// accumulator-external order. and
SmallVector<Value> resVals(resSize);
auto callMMA = [&](unsigned m, unsigned n, unsigned k) {
auto ha = has.at({m, k});
auto hb = hbs.at({n, k});
auto getIdx = [&](int m, int n) {
std::vector<size_t> idx{{
(m * 2 + 0) + (n * 4 + 0) * numM, // row0
(m * 2 + 0) + (n * 4 + 1) * numM,
@@ -3573,8 +3577,29 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
(m * 2 + 1) + (n * 4 + 2) * numM, // row3
(m * 2 + 1) + (n * 4 + 3) * numM,
}};
return idx;
};
{ // convert the acc's value from accumuator-external order to
// accumulator-internal order.
SmallVector<Value> accInit(acc.size());
for (unsigned m = 0; m < numM / 2; ++m)
for (unsigned n = 0; n < numN / 2; ++n) {
auto idx = getIdx(m, n);
for (unsigned i = 0; i < 8; ++i)
accInit[idx[i]] = acc[(m * numN / 2 + n) * 8 + i];
}
acc = accInit;
}
auto callMMA = [&](unsigned m, unsigned n, unsigned k) {
auto ha = has.at({m, k});
auto hb = hbs.at({n, k});
PTXBuilder builder;
auto idx = getIdx(m, n);
auto *resOprs = builder.newListOperand(8, "=f");
auto *AOprs = builder.newListOperand({
@@ -3606,8 +3631,6 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
for (unsigned i = 0; i < 8; i++) {
Value elem = extract_val(f32_ty, res, getIntAttr(i));
acc[idx[i]] = elem;
// TODO[goostavz]: double confirm this when m/n/k = [32, 32, x] has been
// verified before MMA
resVals[(m * numN / 2 + n) * 8 + i] = elem;
}
};

View File

@@ -1382,10 +1382,11 @@ void init_triton_translation(py::module &m) {
llvm::SMDiagnostic error;
std::unique_ptr<llvm::Module> module =
llvm::parseIR(buffer->getMemBufferRef(), error, context);
if (!module)
if (!module) {
llvm::report_fatal_error(
"failed to parse IR: " + error.getMessage() +
"lineno: " + std::to_string(error.getLineNo()));
}
// translate module to PTX
auto ptxCode =

View File

@@ -295,18 +295,17 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
# NOTE this is useful only on Volta GPU.
@pytest.mark.parametrize('SHAPE,NUM_WARPS,TRANS_A,TRANS_B', [
(shape, num_warps, trans_a, trans_b)
for shape in [
[16, 16, 16],
[16, 16, 32],
[32, 16, 16],
[32, 32, 32],
[128, 16, 16],
]
for num_warps in [1]
for trans_a in [False]
for trans_b in [False]
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [
# Non-forloop
[16, 16, 16, 1, 16, 16, 16, False, False],
[16, 16, 32, 1, 16, 16, 32, False, False],
[32, 16, 32, 1, 32, 16, 32, False, False],
[32, 32, 32, 1, 32, 32, 32, False, False],
[128, 32, 32, 1, 128, 32, 32, False, False],
# split-K
[16, 16, 32, 1, 16, 16, 16, False, False],
[64, 64, 128, 1, 64, 64, 32, False, False],
])
def test_gemm_no_scf_for_mmav1(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B)
def test_gemm_for_mmav1(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B)

View File

@@ -583,7 +583,7 @@ class CodeGenerator(ast.NodeVisitor):
isinstance(step, triton.language.constexpr):
sta_range = iterator(lb.value, ub.value, step.value)
static_unrolling = os.environ.get('TRITON_STATIC_LOOP_UNROLLING', False)
if static_unrolling and len(range) <= 10:
if static_unrolling and len(sta_range) <= 10:
for i in sta_range:
self.lscope[node.target.id] = triton.language.constexpr(i)
self.visit_compound_statement(node.body)