[Triton-MLIR][BACKEND] Make mmav1 work on basic cases (#944)
TODO:
- Add more cases.
- For now, vec is simply set to 4 to make the basic cases pass.

Issue:
- The vec in the shared layout differs from the one on the master branch.
- With vec=1 we hit a CUDA misalignment error; this configuration does not work on the master branch either.
- When vec is set to the same value as on the master branch, the MMA works.
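For context on the misalignment: a vectorized shared-memory access must be aligned to the full vector width, so a layout that advertises one vec while the generated loads assume another can fault. Below is a minimal sketch of that alignment rule, assuming 4-byte (fp32) elements; the function and names are illustrative, not the backend's actual code.

# Illustrative sketch only: the alignment constraint behind the CUDA misalignment error.
def is_aligned(byte_offset, vec, elem_bytes=4):
    # A vectorized access of `vec` elements must be aligned to vec * elem_bytes bytes.
    return byte_offset % (vec * elem_bytes) == 0

assert is_aligned(16, vec=4)       # 16-byte access at a 16-byte boundary: OK
assert not is_aligned(4, vec=4)    # the same access at a 4-byte offset: faults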
@@ -1383,6 +1383,11 @@ void init_triton_translation(py::module &m) {
 llvm::SMDiagnostic error;
 std::unique_ptr<llvm::Module> module =
     llvm::parseIR(buffer->getMemBufferRef(), error, context);
+if (!module)
+  llvm::report_fatal_error(
+      "failed to parse IR: " + error.getMessage() +
+      "lineno: " + std::to_string(error.getLineNo()));
+
 // translate module to PTX
 auto ptxCode =
     triton::translateLLVMIRToPTX(*module, capability, version);
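The hunk above adds an error check around parsing the textual LLVM IR before lowering it to PTX. A hedged Python-side usage sketch follows; the module path and the name translate_llvmir_to_ptx are assumptions inferred from init_triton_translation, not confirmed by this diff.

# Assumed binding name and module path; illustrative only.
import triton._C.libtriton.triton as _triton

llvm_ir = "..."               # textual LLVM IR from the earlier lowering stages
capability, version = 70, 63  # illustrative values: sm_70 (Volta) and PTX ISA 6.3
ptx = _triton.translate_llvmir_to_ptx(llvm_ir, capability, version)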
@@ -172,7 +172,7 @@ def get_proper_err(a, b, golden):
 [128, 64, 128, 4, 128, 64, 128, False, False],
 [16, 16, 16, 16, 16, 16, 16, False, False], # wpt overflow issue
 # K-Forloop
-#[16, 16, 64, 4, 8, 8, 8, False, False], # Wrap threads
+# [16, 16, 64, 4, 8, 8, 8, False, False], # Wrap threads
 [32, 32, 64, 4, 32, 32, 32, False, False], # Single shared encoding
 [16, 16, 128, 4, 16, 16, 16, False, False], # Single shared encoding and small k
 [64, 32, 128, 4, 64, 32, 64, False, False],
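Each row packs [M, N, K, num_warps, block_M, block_N, block_K, trans_a, trans_b]; the first seven match the test_gemm_fp32 signature visible in the next hunk header, and the two booleans are read here as transpose flags. The sketch below shows which rows under # K-Forloop actually need more than one K iteration; the rows are copied from above, while the trip-count formula is an assumption about how the test tiles K.

# Trip count over K for each config: K // block_K iterations of the inner loop.
configs = [
    (128, 64, 128, 4, 128, 64, 128),
    (16, 16, 16, 16, 16, 16, 16),
    (32, 32, 64, 4, 32, 32, 32),
    (16, 16, 128, 4, 16, 16, 16),
    (64, 32, 128, 4, 64, 32, 64),
]
for M, N, K, num_warps, bM, bN, bK in configs:
    print(f"M={M} N={N} K={K} block_K={bK}: {K // bK} K-iteration(s)")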
@@ -292,3 +292,21 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
         torch.testing.assert_close(c, golden, rtol=max(1e-2, 1.5 * golden_rel_err), atol=max(1e-2, 1.5 * golden_abs_err))
     else:
         torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))
+
+
+# NOTE this is useful only on Volta GPU.
+@pytest.mark.parametrize('SHAPE,NUM_WARPS,TRANS_A,TRANS_B', [
+    (shape, num_warps, trans_a, trans_b)
+    for shape in [
+        [16, 16, 16],
+        [16, 16, 32],
+        [32, 16, 16],
+        [32, 32, 32],
+        [128, 16, 16],
+    ]
+    for num_warps in [1]
+    for trans_a in [False]
+    for trans_b in [False]
+])
+def test_gemm_no_scf_for_mmav1(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
+    test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B)
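test_gemm_no_scf (defined earlier in the file, outside this diff) covers the case where the whole problem fits in a single tile, so no scf.for loop over K is emitted. A minimal sketch of the kind of kernel such a test exercises follows, assuming row-major fp16 inputs; this is an illustration, not the file's actual implementation.

import torch
import triton
import triton.language as tl

@triton.jit
def gemm_no_scf_kernel(a_ptr, b_ptr, c_ptr,
                       M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    # Entire GEMM in one tile: no loop over K, hence "no scf".
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
    c = tl.dot(a, b)  # fp16 x fp16 -> fp32 accumulator
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], c)

M, N, K = 16, 16, 16
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
c = torch.empty((M, N), device='cuda', dtype=torch.float32)
gemm_no_scf_kernel[(1,)](a, b, c, M=M, N=N, K=K)
torch.testing.assert_close(c, torch.matmul(a, b).float(), rtol=1e-2, atol=1e-2)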