[Backend] Hacky fix of missing barrier in ConvertLayout blocked->shared (#803)
Barrier should be set by a separate pass, but it seems like there may be some bugs
This commit is contained in:
@@ -83,6 +83,23 @@ def matmul_kernel(
|
||||
# TODO: DotConversion in TritonGPUToLLVM cannot support non-splat C for the moment
|
||||
|
||||
|
||||
def get_variant_golden(a, b):
|
||||
SIZE_M = a.shape[0]
|
||||
SIZE_K = a.shape[1]
|
||||
SIZE_N = b.shape[1]
|
||||
assert a.shape[1] == b.shape[0]
|
||||
zero_M_K = torch.zeros((SIZE_M, SIZE_K)).cuda()
|
||||
zero_3M_K = torch.zeros((3 * SIZE_M, SIZE_K)).cuda()
|
||||
zero_K_N = torch.zeros((SIZE_K, SIZE_N)).cuda()
|
||||
zero_3K_N = torch.zeros((3 * SIZE_K, SIZE_N)).cuda()
|
||||
a_padded = torch.cat((a, zero_M_K, zero_M_K), 0)
|
||||
a_padded = torch.cat((a_padded, zero_3M_K, zero_3M_K), 1)
|
||||
b_padded = torch.cat((b, zero_K_N, zero_K_N), 0)
|
||||
b_padded = torch.cat((b_padded, zero_3K_N, zero_3K_N), 1)
|
||||
c_padded = torch.matmul(a_padded, b_padded)
|
||||
return c_padded[:SIZE_M, :SIZE_N]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K', [
|
||||
# Non-forloop
|
||||
[64, 32, 64, 4, 64, 32, 64],
|
||||
@@ -94,8 +111,8 @@ def matmul_kernel(
|
||||
[32, 64, 128, 4, 32, 64, 32],
|
||||
[32, 128, 256, 4, 32, 128, 64],
|
||||
[64, 128, 64, 4, 64, 128, 32],
|
||||
[128, 128, 64, 4, 128, 128, 32],
|
||||
[64, 64, 128, 4, 64, 64, 32],
|
||||
[128, 128, 64, 4, 128, 128, 32],
|
||||
[128, 128, 128, 4, 128, 128, 32],
|
||||
[128, 128, 256, 4, 128, 128, 64],
|
||||
[128, 256, 128, 4, 128, 256, 32],
|
||||
@@ -115,5 +132,15 @@ def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLO
|
||||
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||
num_warps=NUM_WARPS)
|
||||
golden = torch.matmul(a, b)
|
||||
|
||||
# It's not easy to get a proper error threshold in different size
|
||||
# Here the gemm calculation is padded to a different size in order to get
|
||||
# a variant version of the golden result. And the error between golden and
|
||||
# golden_variant provide reference on selecting the proper rtol / atol.
|
||||
golden_variant = get_variant_golden(a, b)
|
||||
golden_diff = golden - golden_variant
|
||||
golden_abs_err = torch.max(torch.abs(golden_diff)).item()
|
||||
golden_rel_err = torch.max(torch.abs(golden_diff / golden)).item()
|
||||
|
||||
torch.set_printoptions(profile="full")
|
||||
assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False)
|
||||
assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err), check_dtype=False)
|
||||
|
Reference in New Issue
Block a user