[Triton-MLIR][BACKEND] Support $c from mma layout in dot (#798)
This PR does the following:

1. Support the case where $c holds an mma layout; this should be useful for the for-loop along the k-axis in GEMM.
2. Fix the `unrealized_conversion_cast` in ConvertLayout[shared->dot_op].

Known issues:

1. There is an IO conflict in GEMM with a k-for-loop. It is temporarily worked around by [adding a barrier](https://github.com/openai/triton/pull/798/files#diff-8a9a5a7f4a025fb1299af29d190d5626bd9000406d3ea47c49679272d3d6abe9R3028) in the dot conversion; we are still working on it and will land a more generic fix in a follow-up PR.
2. The parallel pass results in a buggy instruction result type:

```mlir
%1049 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "cp.async.commit_group ;", ""  : () -> !llvm.void
%1050 = builtin.unrealized_conversion_cast %1049 : !llvm.void to !llvm.ptr<f16, 3>
```

So we temporarily disable it.
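For context, here is a minimal sketch (not part of this PR's diff; the kernel name and signature are illustrative) of the k-loop GEMM pattern this case covers: the accumulator returned by `tl.dot` is carried across loop iterations, so from the second iteration on the $c operand of `dot` holds an mma layout rather than a splat constant.

```python
import triton
import triton.language as tl


@triton.jit
def gemm_k_loop_kernel(a_ptr, b_ptr, c_ptr,
                       stride_am, stride_ak,
                       stride_bk, stride_bn,
                       stride_cm, stride_cn,
                       K,
                       BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    # The accumulator starts as a splat of zeros, but after the first iteration
    # it holds the mma-layout result of the previous tl.dot; feeding it back as
    # the $c operand of dot is the case this PR adds support for.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, K, BLOCK_K):
        # No masking: this sketch assumes K is a multiple of BLOCK_K.
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc)
```

The `matmul_kernel` exercised by the K-Forloop cases in the test diff below follows essentially this shape.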
```
@@ -35,6 +35,9 @@ def matmul_no_scf_kernel(
    [256, 128, 16, 4],
    [128, 16, 32, 4],
    [32, 128, 64, 4],
    [128, 128, 64, 4],
    [64, 128, 128, 4],
    [64, 128, 128, 2],
])
def test_gemm_no_scf(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
    a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
```
```
@@ -78,24 +81,39 @@ def matmul_kernel(
    tl.store(c_ptrs, accumulator)

# TODO: DotConversion in TritonGPUToLLVM cannot support non-splat C for the moment
# @pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K', [
#     [128, 256, 128, 4, 128, 256, 32],
#     # [256, 128, 64, 4, 256, 128, 16],
#     # [128, 16, 128, 4, 128, 16, 32],
#     # [32, 128, 256, 4, 32, 128, 64],
# ])
# def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K):
#     a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
#     b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
#     c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)
#     grid = lambda META: (1, )
#     matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
#                         stride_am=a.stride(0), stride_ak=a.stride(1),
#                         stride_bk=b.stride(0), stride_bn=b.stride(1),
#                         stride_cm=c.stride(0), stride_cn=c.stride(1),
#                         M=a.shape[0], N=b.shape[1], K=a.shape[1],
#                         BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
#                         num_warps=NUM_WARPS)
#     golden = torch.matmul(a, b)
#     torch.set_printoptions(profile="full")
#     assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False)


@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K', [
    # Non-forloop
    [64, 32, 64, 4, 64, 32, 64],
    [128, 64, 128, 4, 128, 64, 128],
    # K-Forloop
    [64, 32, 128, 4, 64, 32, 64],
    [128, 16, 128, 4, 128, 16, 32],
    [32, 16, 128, 4, 32, 16, 32],
    [32, 64, 128, 4, 32, 64, 32],
    [32, 128, 256, 4, 32, 128, 64],
    [64, 128, 64, 4, 64, 128, 32],
    [128, 128, 64, 4, 128, 128, 32],
    [64, 64, 128, 4, 64, 64, 32],
    [128, 128, 128, 4, 128, 128, 32],
    [128, 128, 256, 4, 128, 128, 64],
    [128, 256, 128, 4, 128, 256, 32],
    [256, 128, 64, 4, 256, 128, 16],
    [128, 64, 128, 4, 128, 64, 32],
])
def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K):
    a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
    b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
    c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)
    grid = lambda META: (1, )
    matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
                        stride_am=a.stride(0), stride_ak=a.stride(1),
                        stride_bk=b.stride(0), stride_bn=b.stride(1),
                        stride_cm=c.stride(0), stride_cn=c.stride(1),
                        M=a.shape[0], N=b.shape[1], K=a.shape[1],
                        BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
                        num_warps=NUM_WARPS)
    golden = torch.matmul(a, b)
    torch.set_printoptions(profile="full")
    assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False)
```