diff --git a/python/tests/test_gemm.py b/python/tests/test_gemm.py index df337988c..dba85694e 100644 --- a/python/tests/test_gemm.py +++ b/python/tests/test_gemm.py @@ -148,46 +148,51 @@ def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLO # Precession regression for FMADot is not done yet due to some issue on the optimizer failed to give a blocked layout to dot op. # TODO[Superjomn]: Uncomment this test and continue to finish precession regression latter. -# @pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K', [ -# [128, 256, 128, 4, 128, 256, 32], -# [256, 128, 64, 4, 256, 128, 16], -# [128, 64, 128, 4, 128, 64, 32], -# ]) -# def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K): -# @triton.jit -# def matmul_kernel( -# a_ptr, b_ptr, c_ptr, -# stride_am, stride_ak, -# stride_bk, stride_bn, -# stride_cm, stride_cn, -# K: tl.constexpr, -# BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, -# ): -# offs_m = tl.arange(0, BLOCK_SIZE_M) -# offs_n = tl.arange(0, BLOCK_SIZE_N) -# offs_k = tl.arange(0, BLOCK_SIZE_K) -# a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak -# b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn -# accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) -# for k in range(0, K, BLOCK_SIZE_K): -# a = tl.load(a_ptrs) -# b = tl.load(b_ptrs) -# accumulator += tl.dot(a, b, allow_tf32=True) -# a_ptrs += BLOCK_SIZE_K * stride_ak -# b_ptrs += BLOCK_SIZE_K * stride_bk +@pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K', [ + [128, 256, 128, 4, 128, 256, 32], + [256, 128, 64, 4, 256, 128, 16], + [128, 64, 128, 4, 128, 64, 32], +]) +def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K): + @triton.jit + def matmul_kernel( + a_ptr, b_ptr, c_ptr, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + ): + offs_m = tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, K, BLOCK_SIZE_K): + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + # NOTE the allow_tf32 should be false to force the dot op to do fmadot lowering + accumulator += tl.dot(a, b, allow_tf32=False) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk -# c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn -# tl.store(c_ptrs, accumulator) + c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn + tl.store(c_ptrs, accumulator) -# a = torch.randn((M, K), device='cuda', dtype=torch.float32) -# b = torch.randn((K, N), device='cuda', dtype=torch.float) -# c = torch.empty((M, N), device=a.device, dtype=torch.float32) -# grid = lambda META: (1, ) -# matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, -# stride_am=a.stride(0), stride_ak=a.stride(1), -# stride_bk=b.stride(0), stride_bn=b.stride(1), -# stride_cm=c.stride(0), stride_cn=c.stride(1), -# K=a.shape[1], BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, -# BLOCK_SIZE_K=block_K, num_warps=num_warps) -# golden = torch.matmul(a, b) -# torch.testing.assert_close(c, golden) + a = torch.randn((M, K), device='cuda', dtype=torch.float32) + b = torch.randn((K, N), device='cuda', dtype=torch.float) + c = torch.empty((M, N), device=a.device, dtype=torch.float32) + grid = lambda META: (1, ) + matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, + stride_am=a.stride(0), stride_ak=a.stride(1), + stride_bk=b.stride(0), stride_bn=b.stride(1), + stride_cm=c.stride(0), stride_cn=c.stride(1), + K=a.shape[1], BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, + BLOCK_SIZE_K=block_K, num_warps=num_warps) + golden = torch.matmul(a, b) + torch.testing.assert_close(c, golden) + + +#test_gemm_no_scf(*[64, 128, 128, 2]) +test_gemm_fmadot(*[128, 64, 128, 4, 128, 64, 32])