Repro swizzling bug

This commit is contained in:
Phil Tillet
2023-01-02 23:44:25 -08:00
parent 0e8590f1c9
commit 08366b2d59
3 changed files with 83 additions and 24 deletions

View File

@@ -1177,19 +1177,25 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
z_tri = to_triton(z, device=device)
if epilogue == 'trans':
z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
y_tri, y_tri.stride(0), y_tri.stride(1),
w_tri, w_tri.stride(0), w_tri.stride(1),
z_tri, z_tri.stride(0), z_tri.stride(1),
COL_A=col_a, COL_B=col_b,
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
ADD_MATRIX=epilogue == 'add-matrix',
ADD_ROWS=epilogue == 'add-rows',
ADD_COLS=epilogue == 'add-cols',
DO_SOFTMAX=epilogue == 'softmax',
CHAIN_DOT=epilogue == 'chain-dot',
ALLOW_TF32=allow_tf32,
num_warps=num_warps)
# pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
# y_tri, y_tri.stride(0), y_tri.stride(1),
# w_tri, w_tri.stride(0), w_tri.stride(1),
# z_tri, z_tri.stride(0), z_tri.stride(1),
# COL_A=col_a, COL_B=col_b,
# BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
# ADD_MATRIX=epilogue == 'add-matrix',
# ADD_ROWS=epilogue == 'add-rows',
# ADD_COLS=epilogue == 'add-cols',
# DO_SOFTMAX=epilogue == 'softmax',
# CHAIN_DOT=epilogue == 'chain-dot',
# ALLOW_TF32=allow_tf32,
# num_warps=num_warps)
kernel = triton.compile("./chain-dot.ttgir", num_warps=num_warps)
pgm = kernel[(1, 1, 1)](x_tri.data_ptr(), x_tri.stride(0),
y_tri.data_ptr(), y_tri.stride(0),
w_tri.data_ptr(), w_tri.stride(0),
z_tri.data_ptr(), z_tri.stride(0))
# torch result
if dtype == 'int8':
z_ref = np.matmul(x.astype(np.float32),
@@ -1217,15 +1223,15 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
else:
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
# make sure ld/st are vectorized
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
if dtype == 'float32' and allow_tf32:
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
elif dtype == 'float32' and allow_tf32:
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
elif dtype == 'int8':
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
# ptx = pgm.asm['ptx']
# assert 'ld.global.v4' in ptx
# assert 'st.global.v4' in ptx
# if dtype == 'float32' and allow_tf32:
# assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
# elif dtype == 'float32' and allow_tf32:
# assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
# elif dtype == 'int8':
# assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
def test_dot_without_load():