Reproduce swizzling bug
This commit is contained in:
@@ -1177,19 +1177,25 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
|
||||
z_tri = to_triton(z, device=device)
|
||||
if epilogue == 'trans':
|
||||
z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
|
||||
pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
|
||||
y_tri, y_tri.stride(0), y_tri.stride(1),
|
||||
w_tri, w_tri.stride(0), w_tri.stride(1),
|
||||
z_tri, z_tri.stride(0), z_tri.stride(1),
|
||||
COL_A=col_a, COL_B=col_b,
|
||||
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
|
||||
ADD_MATRIX=epilogue == 'add-matrix',
|
||||
ADD_ROWS=epilogue == 'add-rows',
|
||||
ADD_COLS=epilogue == 'add-cols',
|
||||
DO_SOFTMAX=epilogue == 'softmax',
|
||||
CHAIN_DOT=epilogue == 'chain-dot',
|
||||
ALLOW_TF32=allow_tf32,
|
||||
num_warps=num_warps)
|
||||
# pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
|
||||
# y_tri, y_tri.stride(0), y_tri.stride(1),
|
||||
# w_tri, w_tri.stride(0), w_tri.stride(1),
|
||||
# z_tri, z_tri.stride(0), z_tri.stride(1),
|
||||
# COL_A=col_a, COL_B=col_b,
|
||||
# BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
|
||||
# ADD_MATRIX=epilogue == 'add-matrix',
|
||||
# ADD_ROWS=epilogue == 'add-rows',
|
||||
# ADD_COLS=epilogue == 'add-cols',
|
||||
# DO_SOFTMAX=epilogue == 'softmax',
|
||||
# CHAIN_DOT=epilogue == 'chain-dot',
|
||||
# ALLOW_TF32=allow_tf32,
|
||||
# num_warps=num_warps)
|
||||
kernel = triton.compile("./chain-dot.ttgir", num_warps=num_warps)
|
||||
pgm = kernel[(1, 1, 1)](x_tri.data_ptr(), x_tri.stride(0),
|
||||
y_tri.data_ptr(), y_tri.stride(0),
|
||||
w_tri.data_ptr(), w_tri.stride(0),
|
||||
z_tri.data_ptr(), z_tri.stride(0))
|
||||
|
||||
# torch result
|
||||
if dtype == 'int8':
|
||||
z_ref = np.matmul(x.astype(np.float32),
|
||||
@@ -1217,15 +1223,15 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
|
||||
else:
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
||||
# make sure ld/st are vectorized
|
||||
ptx = pgm.asm['ptx']
|
||||
assert 'ld.global.v4' in ptx
|
||||
assert 'st.global.v4' in ptx
|
||||
if dtype == 'float32' and allow_tf32:
|
||||
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
|
||||
elif dtype == 'float32' and allow_tf32:
|
||||
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
|
||||
elif dtype == 'int8':
|
||||
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
|
||||
# ptx = pgm.asm['ptx']
|
||||
# assert 'ld.global.v4' in ptx
|
||||
# assert 'st.global.v4' in ptx
|
||||
# if dtype == 'float32' and allow_tf32:
|
||||
# assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
|
||||
# elif dtype == 'float32' and allow_tf32:
|
||||
# assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
|
||||
# elif dtype == 'int8':
|
||||
# assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
|
||||
|
||||
|
||||
def test_dot_without_load():
|
||||
|
Reference in New Issue
Block a user