[BACKEND] Support dot op when the output is mma encoding and allowtf32 is true (#937)

This commit is contained in:
Keren Zhou
2022-12-03 11:14:12 -08:00
committed by GitHub
parent 8edfe813a5
commit f2fcaeabf3
5 changed files with 105 additions and 72 deletions

View File

@@ -220,14 +220,17 @@ def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLO
assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err), check_dtype=False)
@pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K', [
[32, 32, 16, 4, 32, 32, 16],
[32, 16, 16, 4, 32, 32, 16],
[128, 8, 8, 4, 32, 32, 16],
# TODO[Superjomn]: fix it later
# [127, 41, 43, 4, 32, 32, 16],
@pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K,allow_tf32', [
[32, 32, 16, 4, 32, 32, 16, False],
[32, 32, 16, 4, 32, 32, 16, True],
[32, 16, 16, 4, 32, 32, 16, False],
[32, 16, 16, 4, 32, 32, 16, True],
[127, 41, 43, 4, 32, 32, 16, False],
[127, 41, 43, 4, 32, 32, 16, True],
[128, 8, 8, 4, 32, 32, 16, False],
[128, 8, 8, 4, 32, 32, 16, True]
])
def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
@triton.jit
def matmul_kernel(
a_ptr, b_ptr, c_ptr,
@@ -236,6 +239,7 @@ def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
ALLOW_TF32: tl.constexpr
):
pid = tl.program_id(axis=0)
# num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
@@ -253,10 +257,9 @@ def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
for k in range(0, K, BLOCK_SIZE_K):
a_mask = (offs_am[:, None] < M) & (offs_k[None, :] < K)
b_mask = (offs_k[:, None] < K) & (offs_bn[None, :] < N)
a = tl.load(a_ptrs, a_mask)
b = tl.load(b_ptrs, b_mask)
# NOTE the allow_tf32 should be false to force the dot op to do fmadot lowering
accumulator += tl.dot(a, b, allow_tf32=False)
a = tl.load(a_ptrs, a_mask, other=0.0)
b = tl.load(b_ptrs, b_mask, other=0.0)
accumulator += tl.dot(a, b, allow_tf32=ALLOW_TF32)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
offs_k += BLOCK_SIZE_K
@@ -267,6 +270,9 @@ def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, c_mask)
# Configure the pytorch counterpart
torch.backends.cuda.matmul.allow_tf32 = allow_tf32
a = torch.randn((M, K), device='cuda', dtype=torch.float32)
b = torch.randn((K, N), device='cuda', dtype=torch.float32)
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
@@ -277,8 +283,12 @@ def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
stride_am=a.stride(0), stride_ak=a.stride(1),
stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_cm=c.stride(0), stride_cn=c.stride(1),
BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, BLOCK_SIZE_K=block_K)
BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, BLOCK_SIZE_K=block_K, ALLOW_TF32=allow_tf32)
golden = torch.matmul(a, b)
golden_abs_err, golden_rel_err = get_proper_err(a, b, golden)
torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))
if allow_tf32:
# TF32 is not accurate enough
torch.testing.assert_close(c, golden, rtol=max(1e-2, 1.5 * golden_rel_err), atol=max(1e-2, 1.5 * golden_abs_err))
else:
torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))