[Triton-MLIR][BACKEND] make MMAv1 splitk works (#960)
This commit is contained in:
4
.github/workflows/integration-tests.yml
vendored
4
.github/workflows/integration-tests.yml
vendored
@@ -88,8 +88,10 @@ jobs:
|
||||
- name: Run python tests on V100
|
||||
if: ${{matrix.runner[0] == 'self-hosted' && matrix.runner[1] == 'V100'}}
|
||||
run: |
|
||||
# TODO[Superjomn]: Remove the forloop-unroll setting after pipeline pass works
|
||||
cd python/tests
|
||||
pytest test_gemm.py::test_gemm_no_scf_for_mmav1
|
||||
export TRITON_STATIC_LOOP_UNROLLING=1
|
||||
pytest test_gemm.py::test_gemm_for_mmav1
|
||||
|
||||
- name: Run CXX unittests
|
||||
run: |
|
||||
|
@@ -72,6 +72,11 @@ struct DotOpMmaV1ConversionHelper {
|
||||
|
||||
bool isARow = order[0] != 0;
|
||||
bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes
|
||||
// TODO[Superjomn]: Support the case when isAVec4=false later
|
||||
// Currently, we only support ld.v2, for the mma layout varies with
|
||||
// different ld vector width.
|
||||
isAVec4 = true;
|
||||
|
||||
int packSize0 = (isARow || isAVec4) ? 1 : 2;
|
||||
|
||||
SmallVector<int> fpw({2, 2, 1});
|
||||
@@ -98,6 +103,11 @@ struct DotOpMmaV1ConversionHelper {
|
||||
auto order = getOrder();
|
||||
bool isBRow = order[0] != 0;
|
||||
bool isBVec4 = isBRow && shape[order[0]] <= 16;
|
||||
// TODO[Superjomn]: Support the case when isBVec4=false later
|
||||
// Currently, we only support ld.v2, for the mma layout varies with
|
||||
// different ld vector width.
|
||||
isBVec4 = true;
|
||||
|
||||
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
|
||||
SmallVector<int> fpw({2, 2, 1});
|
||||
SmallVector<int> rep({0, 2 * packSize1, 1}); // pad M with 0
|
||||
@@ -1455,7 +1465,6 @@ Value DotOpMmaV1ConversionHelper::loadB(
|
||||
SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
|
||||
sharedLayout.getOrder().end());
|
||||
|
||||
|
||||
Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
|
||||
bool isBRow = order[0] != 0;
|
||||
bool isBVec4 = isBRow && shape[order[0]] <= 16;
|
||||
|
@@ -3555,14 +3555,18 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
|
||||
auto has = helper.extractLoadedOperand(loadedA, NK, rewriter);
|
||||
auto hbs = helper.extractLoadedOperand(loadedB, NK, rewriter);
|
||||
|
||||
// initialize accumulators
|
||||
// Initialize accumulators with external values, the acc holds the accumulator
|
||||
// value that is shared between the MMA instructions inside a DotOp, we can
|
||||
// call the order of the values the accumulator-internal order.
|
||||
SmallVector<Value> acc = getElementsFromStruct(loc, loadedC, rewriter);
|
||||
size_t resSize = acc.size();
|
||||
|
||||
// The resVals holds the final result of the DotOp.
|
||||
// NOTE The current order of resVals is different from acc, we call it the
|
||||
// accumulator-external order. and
|
||||
SmallVector<Value> resVals(resSize);
|
||||
|
||||
auto callMMA = [&](unsigned m, unsigned n, unsigned k) {
|
||||
auto ha = has.at({m, k});
|
||||
auto hb = hbs.at({n, k});
|
||||
auto getIdx = [&](int m, int n) {
|
||||
std::vector<size_t> idx{{
|
||||
(m * 2 + 0) + (n * 4 + 0) * numM, // row0
|
||||
(m * 2 + 0) + (n * 4 + 1) * numM,
|
||||
@@ -3573,8 +3577,29 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
|
||||
(m * 2 + 1) + (n * 4 + 2) * numM, // row3
|
||||
(m * 2 + 1) + (n * 4 + 3) * numM,
|
||||
}};
|
||||
return idx;
|
||||
};
|
||||
|
||||
{ // convert the acc's value from accumuator-external order to
|
||||
// accumulator-internal order.
|
||||
SmallVector<Value> accInit(acc.size());
|
||||
|
||||
for (unsigned m = 0; m < numM / 2; ++m)
|
||||
for (unsigned n = 0; n < numN / 2; ++n) {
|
||||
auto idx = getIdx(m, n);
|
||||
for (unsigned i = 0; i < 8; ++i)
|
||||
accInit[idx[i]] = acc[(m * numN / 2 + n) * 8 + i];
|
||||
}
|
||||
|
||||
acc = accInit;
|
||||
}
|
||||
|
||||
auto callMMA = [&](unsigned m, unsigned n, unsigned k) {
|
||||
auto ha = has.at({m, k});
|
||||
auto hb = hbs.at({n, k});
|
||||
|
||||
PTXBuilder builder;
|
||||
auto idx = getIdx(m, n);
|
||||
|
||||
auto *resOprs = builder.newListOperand(8, "=f");
|
||||
auto *AOprs = builder.newListOperand({
|
||||
@@ -3606,8 +3631,6 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
|
||||
for (unsigned i = 0; i < 8; i++) {
|
||||
Value elem = extract_val(f32_ty, res, getIntAttr(i));
|
||||
acc[idx[i]] = elem;
|
||||
// TODO[goostavz]: double confirm this when m/n/k = [32, 32, x] has been
|
||||
// verified before MMA
|
||||
resVals[(m * numN / 2 + n) * 8 + i] = elem;
|
||||
}
|
||||
};
|
||||
|
@@ -1382,10 +1382,11 @@ void init_triton_translation(py::module &m) {
|
||||
llvm::SMDiagnostic error;
|
||||
std::unique_ptr<llvm::Module> module =
|
||||
llvm::parseIR(buffer->getMemBufferRef(), error, context);
|
||||
if (!module)
|
||||
if (!module) {
|
||||
llvm::report_fatal_error(
|
||||
"failed to parse IR: " + error.getMessage() +
|
||||
"lineno: " + std::to_string(error.getLineNo()));
|
||||
}
|
||||
|
||||
// translate module to PTX
|
||||
auto ptxCode =
|
||||
|
@@ -295,18 +295,17 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
|
||||
|
||||
|
||||
# NOTE this is useful only on Volta GPU.
|
||||
@pytest.mark.parametrize('SHAPE,NUM_WARPS,TRANS_A,TRANS_B', [
|
||||
(shape, num_warps, trans_a, trans_b)
|
||||
for shape in [
|
||||
[16, 16, 16],
|
||||
[16, 16, 32],
|
||||
[32, 16, 16],
|
||||
[32, 32, 32],
|
||||
[128, 16, 16],
|
||||
]
|
||||
for num_warps in [1]
|
||||
for trans_a in [False]
|
||||
for trans_b in [False]
|
||||
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [
|
||||
# Non-forloop
|
||||
[16, 16, 16, 1, 16, 16, 16, False, False],
|
||||
[16, 16, 32, 1, 16, 16, 32, False, False],
|
||||
[32, 16, 32, 1, 32, 16, 32, False, False],
|
||||
[32, 32, 32, 1, 32, 32, 32, False, False],
|
||||
[128, 32, 32, 1, 128, 32, 32, False, False],
|
||||
|
||||
# split-K
|
||||
[16, 16, 32, 1, 16, 16, 16, False, False],
|
||||
[64, 64, 128, 1, 64, 64, 32, False, False],
|
||||
])
|
||||
def test_gemm_no_scf_for_mmav1(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
|
||||
test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B)
|
||||
def test_gemm_for_mmav1(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
|
||||
test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B)
|
||||
|
@@ -583,7 +583,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
isinstance(step, triton.language.constexpr):
|
||||
sta_range = iterator(lb.value, ub.value, step.value)
|
||||
static_unrolling = os.environ.get('TRITON_STATIC_LOOP_UNROLLING', False)
|
||||
if static_unrolling and len(range) <= 10:
|
||||
if static_unrolling and len(sta_range) <= 10:
|
||||
for i in sta_range:
|
||||
self.lscope[node.target.id] = triton.language.constexpr(i)
|
||||
self.visit_compound_statement(node.body)
|
||||
|
Reference in New Issue
Block a user