[Triton-MLIR][BACKEND] make MMAv1 splitk works (#960)
This commit is contained in:
@@ -1382,10 +1382,11 @@ void init_triton_translation(py::module &m) {
|
||||
llvm::SMDiagnostic error;
|
||||
std::unique_ptr<llvm::Module> module =
|
||||
llvm::parseIR(buffer->getMemBufferRef(), error, context);
|
||||
if (!module)
|
||||
if (!module) {
|
||||
llvm::report_fatal_error(
|
||||
"failed to parse IR: " + error.getMessage() +
|
||||
"lineno: " + std::to_string(error.getLineNo()));
|
||||
}
|
||||
|
||||
// translate module to PTX
|
||||
auto ptxCode =
|
||||
|
@@ -295,18 +295,17 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
|
||||
|
||||
|
||||
# NOTE this is useful only on Volta GPU.
|
||||
@pytest.mark.parametrize('SHAPE,NUM_WARPS,TRANS_A,TRANS_B', [
|
||||
(shape, num_warps, trans_a, trans_b)
|
||||
for shape in [
|
||||
[16, 16, 16],
|
||||
[16, 16, 32],
|
||||
[32, 16, 16],
|
||||
[32, 32, 32],
|
||||
[128, 16, 16],
|
||||
]
|
||||
for num_warps in [1]
|
||||
for trans_a in [False]
|
||||
for trans_b in [False]
|
||||
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [
|
||||
# Non-forloop
|
||||
[16, 16, 16, 1, 16, 16, 16, False, False],
|
||||
[16, 16, 32, 1, 16, 16, 32, False, False],
|
||||
[32, 16, 32, 1, 32, 16, 32, False, False],
|
||||
[32, 32, 32, 1, 32, 32, 32, False, False],
|
||||
[128, 32, 32, 1, 128, 32, 32, False, False],
|
||||
|
||||
# split-K
|
||||
[16, 16, 32, 1, 16, 16, 16, False, False],
|
||||
[64, 64, 128, 1, 64, 64, 32, False, False],
|
||||
])
|
||||
def test_gemm_no_scf_for_mmav1(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
|
||||
test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B)
|
||||
def test_gemm_for_mmav1(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
|
||||
test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B)
|
||||
|
@@ -583,7 +583,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
isinstance(step, triton.language.constexpr):
|
||||
sta_range = iterator(lb.value, ub.value, step.value)
|
||||
static_unrolling = os.environ.get('TRITON_STATIC_LOOP_UNROLLING', False)
|
||||
if static_unrolling and len(range) <= 10:
|
||||
if static_unrolling and len(sta_range) <= 10:
|
||||
for i in sta_range:
|
||||
self.lscope[node.target.id] = triton.language.constexpr(i)
|
||||
self.visit_compound_statement(node.body)
|
||||
|
Reference in New Issue
Block a user