[Triton-MLIR][BACKEND] Make mmav1 work on basic cases (#944)

TODO:

- Add more cases
- For now, we hard-code vec to 4 to make the basic cases pass

Issue:

- the vec in the shared layout differs from the one on the master branch
- with vec=1 we hit a CUDA misalignment error; this fails on the master
branch as well
- setting vec to the same value as on the master branch makes the MMA
work
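To make the setting concrete, below is a minimal sketch (illustrative only, not code from this commit) of the kind of basic case the new tests exercise: a single-tile GEMM with no scf loop, whose tl.dot lowers to the mmav1 path on Volta and stages its operands through shared memory, which is where vec applies.

# Minimal sketch using Triton's public Python API; the kernel name and
# the single-tile layout are assumptions for illustration.
import triton
import triton.language as tl

@triton.jit
def matmul_no_scf_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                         BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                         BLOCK_K: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    # One MxK and one KxN row-major tile; no masking, since the block
    # shape equals the problem shape in these tests.
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
    # A single tile-level dot, so no K loop (no scf) is emitted; on
    # sm_70 this lowers to mmav1, reading A and B from shared memory.
    c = tl.dot(a, b)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], c)
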
Author: Yan Chunwei
Date: 2022-12-06 10:57:08 +08:00
Committed by: GitHub
Parent: 189491727a
Commit: e419781978
8 changed files with 134 additions and 100 deletions


@@ -1383,6 +1383,11 @@ void init_triton_translation(py::module &m) {
  llvm::SMDiagnostic error;
  std::unique_ptr<llvm::Module> module =
      llvm::parseIR(buffer->getMemBufferRef(), error, context);
  // Fail fast with the parser's diagnostic instead of dereferencing a
  // null module below; note the ", " separator so the message and the
  // line number do not run together.
  if (!module)
    llvm::report_fatal_error(
        "failed to parse IR: " + error.getMessage() +
        ", lineno: " + std::to_string(error.getLineNo()));
  // translate module to PTX
  auto ptxCode =
      triton::translateLLVMIRToPTX(*module, capability, version);


@@ -172,7 +172,7 @@ def get_proper_err(a, b, golden):
    [128, 64, 128, 4, 128, 64, 128, False, False],
    [16, 16, 16, 16, 16, 16, 16, False, False],  # wpt overflow issue
    # K-Forloop
    # [16, 16, 64, 4, 8, 8, 8, False, False],  # Wrap threads
    [32, 32, 64, 4, 32, 32, 32, False, False],  # Single shared encoding
    [16, 16, 128, 4, 16, 16, 16, False, False],  # Single shared encoding and small k
    [64, 32, 128, 4, 64, 32, 64, False, False],
@@ -292,3 +292,21 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
        torch.testing.assert_close(c, golden, rtol=max(1e-2, 1.5 * golden_rel_err), atol=max(1e-2, 1.5 * golden_abs_err))
    else:
        torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))


# NOTE: this is useful only on Volta GPUs.
@pytest.mark.parametrize('SHAPE,NUM_WARPS,TRANS_A,TRANS_B', [
    (shape, num_warps, trans_a, trans_b)
    for shape in [
        [16, 16, 16],
        [16, 16, 32],
        [32, 16, 16],
        [32, 32, 32],
        [128, 16, 16],
    ]
    for num_warps in [1]
    for trans_a in [False]
    for trans_b in [False]
])
def test_gemm_no_scf_for_mmav1(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
    test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B)
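
For context, test_gemm_no_scf is defined earlier in this file. A hedged sketch of the check it performs, reusing matmul_no_scf_kernel from the sketch above (names, dtypes, and tolerances are assumptions, not the file's actual code; TRANS_A and TRANS_B are always False in these cases, so transposes are skipped):

import torch

def check_gemm_no_scf(SHAPE):
    # Run the whole problem as one tile, so no K loop is emitted, then
    # compare against a torch reference with loose fp16 tolerances.
    M, N, K = SHAPE
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    c = torch.empty((M, N), device='cuda', dtype=torch.float32)  # fp32 accumulator out
    matmul_no_scf_kernel[(1,)](a, b, c, M, N, K,
                               BLOCK_M=M, BLOCK_N=N, BLOCK_K=K)
    torch.testing.assert_close(c, torch.matmul(a, b).float(),
                               rtol=1e-2, atol=1e-2)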