[Triton-MLIR][Backend] Fix mma<v2> int8 precision error (#850)

Fix mma.16816 s8 precision error
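
For context, "mma.16816" refers to the m16n8k16 Tensor Core MMA shape (PTX mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 for s8 operands), which multiplies s8 inputs and accumulates in s32; since the arithmetic is exact integer math, the test added below can compare against a CPU reference with no tolerance.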

Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
Author:    Yan Chunwei
Date:      2022-11-09 12:23:43 +08:00
Committer: GitHub
Parent:    e517b58d59
Commit:    de5b84c476
4 changed files with 146 additions and 42 deletions


@@ -55,6 +55,33 @@ def test_gemm_no_scf(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
    assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False)


@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS', [
    [64, 128, 128, 1],
    [128, 128, 128, 4],
    [16, 8, 32, 1],
    [32, 16, 64, 2],
    [32, 16, 64, 4],
])
def test_gemm_no_scf_int8(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
    a = torch.randint(-5, 5, (SIZE_M, SIZE_K), device='cuda', dtype=torch.int8)
    b = torch.randint(-5, 5, (SIZE_K, SIZE_N), device='cuda', dtype=torch.int8)
    c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.int32)
    grid = lambda META: (1, )
    matmul_no_scf_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
                               stride_am=a.stride(0), stride_ak=a.stride(1),
                               stride_bk=b.stride(0), stride_bn=b.stride(1),
                               stride_cm=c.stride(0), stride_cn=c.stride(1),
                               M=SIZE_M, N=SIZE_N, K=SIZE_K,
                               num_warps=NUM_WARPS)
    aa = a.cpu()
    bb = b.cpu()
    golden = torch.matmul(aa.float(), bb.float()).int()
    torch.set_printoptions(profile="full")
    torch.testing.assert_close(c.cpu(), golden, check_dtype=False)
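
matmul_no_scf_kernel is defined earlier in this file and not shown in this hunk. As a rough sketch of what the call above assumes (a single-tile dot with no K loop, hence "no_scf"; not the patch's own code):

import triton
import triton.language as tl

@triton.jit
def matmul_no_scf_kernel(a_ptr, b_ptr, c_ptr,
                         stride_am, stride_ak,
                         stride_bk, stride_bn,
                         stride_cm, stride_cn,
                         M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    # A single program covers the full M x N output; K fits in one tl.dot,
    # so no scf.for loop is emitted.
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a = tl.load(a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b = tl.load(b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
    # int8 x int8 dots accumulate in int32 on the mma path exercised here.
    c = tl.dot(a, b)
    tl.store(c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn, c)
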
@triton.jit
def matmul_kernel(
        a_ptr, b_ptr, c_ptr,
@@ -80,8 +107,6 @@ def matmul_kernel(
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, accumulator)

# TODO: DotConversion in TritonGPUToLLVM cannot support non-splat C for the moment
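
Here C means the accumulator operand of the dot: a "splat" C starts every element from the same constant (tl.zeros), which is what these kernels use, while an accumulator loaded from memory would be non-splat. A hypothetical illustration, not code from this patch:

acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)  # splat C: supported
acc += tl.dot(a, b)
# acc = tl.load(c_ptrs)  # non-splat C: not yet lowered by DotConversion
# acc += tl.dot(a, b)
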
def get_variant_golden(a, b):
    SIZE_M = a.shape[0]