[Triton-MLIR] tt.dot
operands now must have DotOperand layout; also added prefetch pass prototype (#712)
Co-authored-by: Jokeren <kerenzhou@openai.com> Co-authored-by: Phil Tillet <phil@openai.com> Co-authored-by: Superjomn <yanchunwei@outlook.com>
This commit is contained in:
@@ -171,63 +171,65 @@ def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLO
|
||||
assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err), check_dtype=False)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K', [
|
||||
[32, 32, 16, 4, 32, 32, 16],
|
||||
[32, 16, 16, 4, 32, 32, 16],
|
||||
[128, 8, 8, 4, 32, 32, 16],
|
||||
[127, 41, 43, 4, 32, 32, 16],
|
||||
])
|
||||
def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
|
||||
@triton.jit
|
||||
def matmul_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
# num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
pid_m = pid // num_pid_n
|
||||
pid_n = pid % num_pid_n
|
||||
|
||||
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
|
||||
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, K, BLOCK_SIZE_K):
|
||||
a_mask = (offs_am[:, None] < M) & (offs_k[None, :] < K)
|
||||
b_mask = (offs_k[:, None] < K) & (offs_bn[None, :] < N)
|
||||
a = tl.load(a_ptrs, a_mask)
|
||||
b = tl.load(b_ptrs, b_mask)
|
||||
# NOTE the allow_tf32 should be false to force the dot op to do fmadot lowering
|
||||
accumulator += tl.dot(a, b, allow_tf32=False)
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
offs_k += BLOCK_SIZE_K
|
||||
|
||||
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
|
||||
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||
tl.store(c_ptrs, accumulator, c_mask)
|
||||
|
||||
a = torch.randn((M, K), device='cuda', dtype=torch.float32)
|
||||
b = torch.randn((K, N), device='cuda', dtype=torch.float32)
|
||||
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
|
||||
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
|
||||
matmul_kernel[grid](a, b, c,
|
||||
M, N, K,
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1),
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1),
|
||||
stride_cm=c.stride(0), stride_cn=c.stride(1),
|
||||
BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, BLOCK_SIZE_K=block_K)
|
||||
|
||||
golden = torch.matmul(a, b)
|
||||
torch.testing.assert_close(c, golden)
|
||||
# XXX(Keren): Temporarily disable this test until we have shared -> dot conversion implemented
|
||||
#@pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K', [
|
||||
# [32, 32, 16, 4, 32, 32, 16],
|
||||
# [32, 16, 16, 4, 32, 32, 16],
|
||||
# [128, 8, 8, 4, 32, 32, 16],
|
||||
# [127, 41, 43, 4, 32, 32, 16],
|
||||
#])
|
||||
#def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
|
||||
# @triton.jit
|
||||
# def matmul_kernel(
|
||||
# a_ptr, b_ptr, c_ptr,
|
||||
# M, N, K,
|
||||
# stride_am, stride_ak,
|
||||
# stride_bk, stride_bn,
|
||||
# stride_cm, stride_cn,
|
||||
# BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
# ):
|
||||
# pid = tl.program_id(axis=0)
|
||||
# # num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
# num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
# pid_m = pid // num_pid_n
|
||||
# pid_n = pid % num_pid_n
|
||||
#
|
||||
# offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
# offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
# offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
# a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
|
||||
# b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
#
|
||||
# accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
# for k in range(0, K, BLOCK_SIZE_K):
|
||||
# a_mask = (offs_am[:, None] < M) & (offs_k[None, :] < K)
|
||||
# b_mask = (offs_k[:, None] < K) & (offs_bn[None, :] < N)
|
||||
# a = tl.load(a_ptrs, a_mask)
|
||||
# b = tl.load(b_ptrs, b_mask)
|
||||
# # NOTE the allow_tf32 should be false to force the dot op to do fmadot lowering
|
||||
# accumulator += tl.dot(a, b, allow_tf32=False)
|
||||
# a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
# b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
# offs_k += BLOCK_SIZE_K
|
||||
#
|
||||
# offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
# offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
# c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
|
||||
# c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||
# tl.store(c_ptrs, accumulator, c_mask)
|
||||
#
|
||||
# a = torch.randn((M, K), device='cuda', dtype=torch.float32)
|
||||
# b = torch.randn((K, N), device='cuda', dtype=torch.float32)
|
||||
# c = torch.empty((M, N), device=a.device, dtype=torch.float32)
|
||||
#
|
||||
# grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
|
||||
# matmul_kernel[grid](a, b, c,
|
||||
# M, N, K,
|
||||
# stride_am=a.stride(0), stride_ak=a.stride(1),
|
||||
# stride_bk=b.stride(0), stride_bn=b.stride(1),
|
||||
# stride_cm=c.stride(0), stride_cn=c.stride(1),
|
||||
# BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, BLOCK_SIZE_K=block_K)
|
||||
#
|
||||
# golden = torch.matmul(a, b)
|
||||
# torch.testing.assert_close(c, golden)
|
||||
#
|
||||
|
Reference in New Issue
Block a user