1. Improve the pipeline pass's comments.
2. Decompose `insert_slice_async` when the load vector size is not supported.
3. Add a test that could fail our GEMM code.

Copying my review comments here: there's a knob that may cause a performance regression when decomposition has been performed; we should remove it once we have a thorough analysis of async wait. Currently, we decompose `insert_slice_async` into `load` and `insert_slice` without knowing which `async_wait` is responsible for the `insert_slice_async`. To guarantee correctness, we conservatively set the `async_wait` to wait for all async ops whenever any `insert_slice_async` has been decomposed. There are two options to improve this:

1. Perform a dataflow analysis in the backend to find the `async_wait` that is responsible for each `insert_slice_async`.
2. Modify the pipeline pass to perform the decomposition before the `async_wait` is inserted. However, this is also risky because the correct vectorized shape is not yet known in the pipeline pass, and making it aware of vectorization would introduce additional dependencies on the AxisInfoAnalysis and the Coalesce analysis.
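A schematic Python model of the conservative rule described above (illustrative only — the real pass operates on Triton-GPU IR in C++; op names are simplified strings, and the 4/8/16-byte set mirrors the per-thread copy widths `cp.async` supports on Ampere):

```python
CP_ASYNC_COPY_BYTES = {4, 8, 16}  # per-thread copy widths cp.async can handle

def decompose_unsupported_async_copies(ops, vec_bytes):
    # ops is a flat list of op-name strings; real IR is structured, this is a toy.
    decomposed_any = False
    new_ops = []
    for op in ops:
        if op == "insert_slice_async" and vec_bytes not in CP_ASYNC_COPY_BYTES:
            new_ops += ["load", "insert_slice"]  # no longer an async copy
            decomposed_any = True
        else:
            new_ops.append(op)
    if decomposed_any:
        # We cannot tell which async_wait guarded the removed copies, so make
        # every async_wait wait for ALL outstanding async ops (num = 0).
        new_ops = ["async_wait(num=0)" if op.startswith("async_wait") else op
                   for op in new_ops]
    return new_ops
```

On this toy encoding, `decompose_unsupported_async_copies(["insert_slice_async", "async_wait"], vec_bytes=2)` returns `["load", "insert_slice", "async_wait(num=0)"]`, which is the blanket wait the knob controls.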
import pytest
import torch
from torch.testing import assert_close

import triton
import triton.language as tl


@triton.jit
def matmul_no_scf_kernel(
    a_ptr, b_ptr, c_ptr,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr
):
    # Single-tile GEMM with no K-loop (no structured control flow): each
    # operand is loaded in one shot, so M/N/K must match the tile size.
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    a = tl.load(a_ptrs)
    b = tl.load(b_ptrs)

    c = tl.dot(a, b)

    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, c)


@pytest.mark.parametrize('SHAPE,NUM_WARPS,TRANS_A,TRANS_B', [
    (shape, num_warps, trans_a, trans_b)
    for shape in [
        [128, 256, 32],
        [256, 128, 16],
        [128, 16, 32],
        [32, 128, 64],
        [128, 128, 64],
        [64, 128, 128],
    ]
    for num_warps in [2, 4]
    for trans_a in [False, True]
    for trans_b in [False, True]
])
def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
    SIZE_M, SIZE_N, SIZE_K = SHAPE
    if TRANS_A:
        a = torch.randn((SIZE_K, SIZE_M), device='cuda', dtype=torch.float16).T
    else:
        a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)

    if TRANS_B:
        b = torch.randn((SIZE_N, SIZE_K), device='cuda', dtype=torch.float16).T
    else:
        b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)

    c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)
    grid = lambda META: (1, )
    matmul_no_scf_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
                               stride_am=a.stride(0), stride_ak=a.stride(1),
                               stride_bk=b.stride(0), stride_bn=b.stride(1),
                               stride_cm=c.stride(0), stride_cn=c.stride(1),
                               M=SIZE_M, N=SIZE_N, K=SIZE_K,
                               num_warps=NUM_WARPS)
    golden = torch.matmul(a, b)
    torch.set_printoptions(profile="full")
    assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False)


@pytest.mark.parametrize('SHAPE,NUM_WARPS,TRANS_A,TRANS_B', [
    (shape, num_warps, trans_a, trans_b)
    for shape in [
        [64, 128, 128],
        [128, 128, 128],
        [16, 8, 32],
        [32, 16, 64],
    ]
    for num_warps in [1, 2, 4]
    for trans_a in [False, True]
    for trans_b in [False, True]
])
def test_gemm_no_scf_int8(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
    SIZE_M, SIZE_N, SIZE_K = SHAPE

    if TRANS_A:
        a = torch.randint(-5, 5, (SIZE_K, SIZE_M), device='cuda', dtype=torch.int8).T
    else:
        a = torch.randint(-5, 5, (SIZE_M, SIZE_K), device='cuda', dtype=torch.int8)

    if TRANS_B:
        b = torch.randint(-5, 5, (SIZE_N, SIZE_K), device='cuda', dtype=torch.int8).T
    else:
        b = torch.randint(-5, 5, (SIZE_K, SIZE_N), device='cuda', dtype=torch.int8)

    c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.int32)

    grid = lambda META: (1, )
    matmul_no_scf_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
                               stride_am=a.stride(0), stride_ak=a.stride(1),
                               stride_bk=b.stride(0), stride_bn=b.stride(1),
                               stride_cm=c.stride(0), stride_cn=c.stride(1),
                               M=SIZE_M, N=SIZE_N, K=SIZE_K,
                               num_warps=NUM_WARPS)

    # int8 matmul is not generally supported by torch.matmul, so compute the
    # reference in float on the CPU and cast back to int. The small [-5, 5)
    # operand range keeps the products exactly representable in float.
    aa = a.cpu()
    bb = b.cpu()
    golden = torch.matmul(aa.float(), bb.float()).int()
    torch.set_printoptions(profile="full")
    torch.testing.assert_close(c.cpu(), golden, check_dtype=False)


@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    # Blocked GEMM: accumulate BLOCK_SIZE_K-wide slices of A and B over the
    # K dimension. This K-loop is what the software pipeliner targets.
    offs_m = tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_SIZE_K):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, accumulator)


def get_variant_golden(a, b):
    SIZE_M = a.shape[0]
    SIZE_K = a.shape[1]
    SIZE_N = b.shape[1]
    assert a.shape[1] == b.shape[0]
    zero_M_K = torch.zeros((SIZE_M, SIZE_K)).cuda()
    zero_3M_K = torch.zeros((3 * SIZE_M, SIZE_K)).cuda()
    zero_K_N = torch.zeros((SIZE_K, SIZE_N)).cuda()
    zero_3K_N = torch.zeros((3 * SIZE_K, SIZE_N)).cuda()
    a_padded = torch.cat((a, zero_M_K, zero_M_K), 0)
    a_padded = torch.cat((a_padded, zero_3M_K, zero_3M_K), 1)
    b_padded = torch.cat((b, zero_K_N, zero_K_N), 0)
    b_padded = torch.cat((b_padded, zero_3K_N, zero_3K_N), 1)
    c_padded = torch.matmul(a_padded, b_padded)
    return c_padded[:SIZE_M, :SIZE_N]


# It's not easy to pick a proper error threshold for different problem sizes.
# Here the gemm calculation is padded to a different size in order to get a
# variant version of the golden result, which changes the reduction order and
# hence the rounding. The error between golden and golden_variant then serves
# as a reference for selecting a proper rtol / atol.
def get_proper_err(a, b, golden):
    golden_variant = get_variant_golden(a, b)
    golden_diff = golden - golden_variant
    golden_abs_err = torch.max(torch.abs(golden_diff)).item()
    golden_rel_err = torch.max(torch.abs(golden_diff / golden)).item()
    return (golden_abs_err, golden_rel_err)


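# A minimal sketch of how the two helpers above drive tolerance selection
# (the shapes are arbitrary; the 1e-4 floor and 1.5x padding mirror the call
# sites in the tests below; the helper name is ours, not part of the suite):
def _example_tolerance_selection():
    a = torch.randn((64, 64), device='cuda', dtype=torch.float16)
    b = torch.randn((64, 64), device='cuda', dtype=torch.float16)
    golden = torch.matmul(a, b)
    abs_err, rel_err = get_proper_err(a, b, golden)
    return max(1e-4, 1.5 * rel_err), max(1e-4, 1.5 * abs_err)  # (rtol, atol)

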
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [
    # Non-forloop
    [64, 32, 64, 4, 64, 32, 64, False, False],
    [128, 64, 128, 4, 128, 64, 128, False, False],
    [16, 16, 16, 16, 16, 16, 16, False, False],  # wpt overflow issue
    # K-Forloop
    # [16, 16, 64, 4, 8, 8, 8, False, False],  # Wrap threads
    [32, 32, 64, 4, 32, 32, 32, False, False],  # Single shared encoding
    [16, 16, 128, 4, 16, 16, 16, False, False],  # Single shared encoding and small k
    [64, 32, 128, 4, 64, 32, 64, False, False],
    [128, 16, 128, 4, 128, 16, 32, False, False],
    [32, 16, 128, 4, 32, 16, 32, False, False],
    [32, 64, 128, 4, 32, 64, 32, False, False],
    [32, 128, 256, 4, 32, 128, 64, False, False],
    [64, 128, 64, 4, 64, 128, 32, False, False],
    [64, 64, 128, 4, 64, 64, 32, False, False],
    [128, 128, 64, 4, 128, 128, 32, False, False],
    [128, 128, 128, 4, 128, 128, 32, False, False],
    [128, 128, 256, 4, 128, 128, 64, False, False],
    [128, 256, 128, 4, 128, 256, 32, False, False],
    [256, 128, 64, 4, 256, 128, 16, False, False],
    [128, 64, 128, 4, 128, 64, 32, False, False],
    [16, 16, 64, 4, 16, 16, 16, False, False],
    [32, 32, 64, 4, 32, 32, 32, False, False],
    # trans
    [128, 64, 128, 4, 128, 64, 32, True, False],
    [128, 64, 128, 4, 128, 64, 32, False, True],
])
def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
    if TRANS_A:
        a = torch.randn((SIZE_K, SIZE_M), device='cuda', dtype=torch.float16).T
    else:
        a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)

    if TRANS_B:
        b = torch.randn((SIZE_N, SIZE_K), device='cuda', dtype=torch.float16).T
    else:
        b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)

    c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)
    grid = lambda META: (1, )
    matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
                        stride_am=a.stride(0), stride_ak=a.stride(1),
                        stride_bk=b.stride(0), stride_bn=b.stride(1),
                        stride_cm=c.stride(0), stride_cn=c.stride(1),
                        M=a.shape[0], N=b.shape[1], K=a.shape[1],
                        BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
                        num_warps=NUM_WARPS)
    golden = torch.matmul(a, b)
    golden_abs_err, golden_rel_err = get_proper_err(a, b, golden)
    torch.set_printoptions(profile="full")
    assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err), check_dtype=False)


@pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K', [
    [32, 32, 16, 4, 32, 32, 16],
    [32, 16, 16, 4, 32, 32, 16],
    [128, 8, 8, 4, 32, 32, 16],
    # TODO[Superjomn]: fix it later
    # [127, 41, 43, 4, 32, 32, 16],
])
def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
    @triton.jit
    def matmul_kernel(
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        # num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        pid_m = pid // num_pid_n
        pid_n = pid % num_pid_n

        offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        offs_k = tl.arange(0, BLOCK_SIZE_K)
        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for k in range(0, K, BLOCK_SIZE_K):
            a_mask = (offs_am[:, None] < M) & (offs_k[None, :] < K)
            b_mask = (offs_k[:, None] < K) & (offs_bn[None, :] < N)
            a = tl.load(a_ptrs, a_mask)
            b = tl.load(b_ptrs, b_mask)
            # NOTE: allow_tf32 must be False to force the dot op to lower to
            # FMA instructions (see the TF32-vs-FMA sketch after this test).
            accumulator += tl.dot(a, b, allow_tf32=False)
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk
            offs_k += BLOCK_SIZE_K

        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        tl.store(c_ptrs, accumulator, c_mask)

    a = torch.randn((M, K), device='cuda', dtype=torch.float32)
    b = torch.randn((K, N), device='cuda', dtype=torch.float32)
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)

    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
    matmul_kernel[grid](a, b, c,
                        M, N, K,
                        stride_am=a.stride(0), stride_ak=a.stride(1),
                        stride_bk=b.stride(0), stride_bn=b.stride(1),
                        stride_cm=c.stride(0), stride_cn=c.stride(1),
                        BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, BLOCK_SIZE_K=block_K)

    golden = torch.matmul(a, b)
    golden_abs_err, golden_rel_err = get_proper_err(a, b, golden)
    torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))


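# Background for the allow_tf32 note in the kernel above: on Ampere and newer
# GPUs, fp32 matmul may be lowered to TF32 tensor cores, which truncate the
# fp32 mantissa to 10 bits. This plain-PyTorch sketch (illustrative only, not
# part of the test suite) shows the difference between the two paths:
def _tf32_vs_fma_demo():
    a = torch.randn((128, 128), device='cuda', dtype=torch.float32)
    b = torch.randn((128, 128), device='cuda', dtype=torch.float32)
    torch.backends.cuda.matmul.allow_tf32 = False
    exact = torch.matmul(a, b)   # full-precision fp32 (FMA) path
    torch.backends.cuda.matmul.allow_tf32 = True
    approx = torch.matmul(a, b)  # may use TF32 tensor cores
    return (exact - approx).abs().max().item()  # typically nonzero on Ampere+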