[Triton-MLIR][Backend] Fix mma<v2> int8 precision error (#850)
Fix mma.16816 s8 precision error.

Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
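Note: the new test below validates the int8 dot path against a CPU reference. A minimal sketch of that golden-reference pattern, with illustrative shapes not taken from the diff: the int8 operands are upcast to float32, multiplied, and truncated back to int32, since every product and partial sum at these magnitudes is exactly representable in float32.

    # Sketch of the golden-reference pattern used by the test below;
    # shapes here are illustrative, not from the diff.
    import torch

    a = torch.randint(-5, 5, (16, 32), dtype=torch.int8)
    b = torch.randint(-5, 5, (32, 8), dtype=torch.int8)

    # Upcast to float32 for the reference matmul, then truncate back to int32.
    # The largest possible dot product here is 5 * 5 * 32 = 800, far below
    # float32's exact-integer range (2^24), so the reference is exact.
    golden = torch.matmul(a.float(), b.float()).int()
    print(golden.dtype, golden.shape)  # torch.int32, torch.Size([16, 8])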
@@ -55,6 +55,33 @@ def test_gemm_no_scf(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
    assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False)


@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS', [
    [64, 128, 128, 1],
    [128, 128, 128, 4],
    [16, 8, 32, 1],
    [32, 16, 64, 2],
    [32, 16, 64, 4],
])
def test_gemm_no_scf_int8(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
    a = torch.randint(-5, 5, (SIZE_M, SIZE_K), device='cuda', dtype=torch.int8)
    b = torch.randint(-5, 5, (SIZE_K, SIZE_N), device='cuda', dtype=torch.int8)
    c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.int32)

    grid = lambda META: (1, )
    matmul_no_scf_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
                               stride_am=a.stride(0), stride_ak=a.stride(1),
                               stride_bk=b.stride(0), stride_bn=b.stride(1),
                               stride_cm=c.stride(0), stride_cn=c.stride(1),
                               M=SIZE_M, N=SIZE_N, K=SIZE_K,
                               num_warps=NUM_WARPS)

    aa = a.cpu()
    bb = b.cpu()
    golden = torch.matmul(aa.float(), bb.float()).int()
    torch.set_printoptions(profile="full")
    torch.testing.assert_close(c.cpu(), golden, check_dtype=False)


@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
@@ -80,8 +107,6 @@ def matmul_kernel(
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, accumulator)

# TODO: DotConversion in TritonGPUToLLVM cannot support non-splat C for the moment


def get_variant_golden(a, b):
    SIZE_M = a.shape[0]
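To exercise just the new parametrized test, one hypothetical runner is sketched below; the pytest node selection is an assumption, not part of this commit, and it assumes execution from the directory containing the test file.

    # Hypothetical: run only the new int8 GEMM test, across all parametrizations.
    import sys
    import pytest

    # -k matches on the test name; -q keeps the output terse.
    sys.exit(pytest.main(["-q", "-k", "test_gemm_no_scf_int8"]))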