commit b39cc56f93 (parent: db64477153)
Author: Superjomn
Date:   2022-11-04 18:04:20 +08:00


@@ -148,46 +148,51 @@ def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLO
 # Precession regression for FMADot is not done yet due to some issue on the optimizer failed to give a blocked layout to dot op.
 # TODO[Superjomn]: Uncomment this test and continue to finish precession regression latter.
-# @pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K', [
-#     [128, 256, 128, 4, 128, 256, 32],
-#     [256, 128, 64, 4, 256, 128, 16],
-#     [128, 64, 128, 4, 128, 64, 32],
-# ])
-# def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
-#     @triton.jit
-#     def matmul_kernel(
-#         a_ptr, b_ptr, c_ptr,
-#         stride_am, stride_ak,
-#         stride_bk, stride_bn,
-#         stride_cm, stride_cn,
-#         K: tl.constexpr,
-#         BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
-#     ):
-#         offs_m = tl.arange(0, BLOCK_SIZE_M)
-#         offs_n = tl.arange(0, BLOCK_SIZE_N)
-#         offs_k = tl.arange(0, BLOCK_SIZE_K)
-#         a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
-#         b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
-#         accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-#         for k in range(0, K, BLOCK_SIZE_K):
-#             a = tl.load(a_ptrs)
-#             b = tl.load(b_ptrs)
-#             accumulator += tl.dot(a, b, allow_tf32=True)
-#             a_ptrs += BLOCK_SIZE_K * stride_ak
-#             b_ptrs += BLOCK_SIZE_K * stride_bk
-#         c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
-#         tl.store(c_ptrs, accumulator)
-#     a = torch.randn((M, K), device='cuda', dtype=torch.float32)
-#     b = torch.randn((K, N), device='cuda', dtype=torch.float)
-#     c = torch.empty((M, N), device=a.device, dtype=torch.float32)
-#     grid = lambda META: (1, )
-#     matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
-#                         stride_am=a.stride(0), stride_ak=a.stride(1),
-#                         stride_bk=b.stride(0), stride_bn=b.stride(1),
-#                         stride_cm=c.stride(0), stride_cn=c.stride(1),
-#                         K=a.shape[1], BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N,
-#                         BLOCK_SIZE_K=block_K, num_warps=num_warps)
-#     golden = torch.matmul(a, b)
-#     torch.testing.assert_close(c, golden)
+@pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K', [
+    [128, 256, 128, 4, 128, 256, 32],
+    [256, 128, 64, 4, 256, 128, 16],
+    [128, 64, 128, 4, 128, 64, 32],
+])
+def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
+    @triton.jit
+    def matmul_kernel(
+        a_ptr, b_ptr, c_ptr,
+        stride_am, stride_ak,
+        stride_bk, stride_bn,
+        stride_cm, stride_cn,
+        K: tl.constexpr,
+        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+    ):
+        offs_m = tl.arange(0, BLOCK_SIZE_M)
+        offs_n = tl.arange(0, BLOCK_SIZE_N)
+        offs_k = tl.arange(0, BLOCK_SIZE_K)
+        a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
+        b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
+        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+        for k in range(0, K, BLOCK_SIZE_K):
+            a = tl.load(a_ptrs)
+            b = tl.load(b_ptrs)
+            # NOTE the allow_tf32 should be false to force the dot op to do fmadot lowering
+            accumulator += tl.dot(a, b, allow_tf32=False)
+            a_ptrs += BLOCK_SIZE_K * stride_ak
+            b_ptrs += BLOCK_SIZE_K * stride_bk
+        c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
+        tl.store(c_ptrs, accumulator)
+    a = torch.randn((M, K), device='cuda', dtype=torch.float32)
+    b = torch.randn((K, N), device='cuda', dtype=torch.float)
+    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
+    grid = lambda META: (1, )
+    matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
+                        stride_am=a.stride(0), stride_ak=a.stride(1),
+                        stride_bk=b.stride(0), stride_bn=b.stride(1),
+                        stride_cm=c.stride(0), stride_cn=c.stride(1),
+                        K=a.shape[1], BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N,
+                        BLOCK_SIZE_K=block_K, num_warps=num_warps)
+    golden = torch.matmul(a, b)
+    torch.testing.assert_close(c, golden)
+
+#test_gemm_no_scf(*[64, 128, 128, 2])
+test_gemm_fmadot(*[128, 64, 128, 4, 128, 64, 32])
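
For context on the NOTE in the diff: with `allow_tf32=True` the dot may be lowered to TF32 tensor-core MMA instructions (inputs rounded to a 10-bit mantissa), so the output no longer matches a full-FP32 reference exactly; per the commit's own comment, `allow_tf32=False` forces the FMA lowering this test is meant to exercise. Below is a minimal sketch of the same precision effect using plain PyTorch; it is not part of the commit, assumes a CUDA GPU, and on pre-Ampere devices TF32 is a no-op so both results coincide.

```python
import torch

# Sketch (not from the commit): why an exact FP32 golden comparison needs TF32 off.
a = torch.randn((128, 128), device='cuda', dtype=torch.float32)
b = torch.randn((128, 64), device='cuda', dtype=torch.float32)

torch.backends.cuda.matmul.allow_tf32 = False
fp32_out = torch.matmul(a, b)   # full-precision FP32 matmul

torch.backends.cuda.matmul.allow_tf32 = True
tf32_out = torch.matmul(a, b)   # TF32 matmul: operands rounded to a 10-bit mantissa

# The two results generally differ slightly on TF32-capable GPUs, which is why
# the kernel under test uses allow_tf32=False when checked against an FP32 golden.
print((fp32_out - tf32_out).abs().max())
```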