[FRONTEND][BACKEND] Clean-up transpositions (#953)

This commit is contained in:
Philippe Tillet
2022-12-06 09:32:13 -08:00
committed by GitHub
parent 16e973edf2
commit 532e10cf87
12 changed files with 31 additions and 53 deletions

View File

@@ -1126,7 +1126,7 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
if CHAIN_DOT:
# tl.store(Zs, z)
# tl.debug_barrier()
z = tl.dot(tl.trans(z.to(tl.float16)), tl.load(Ws))
z = tl.dot(z.to(tl.float16), tl.load(Ws))
tl.store(Zs, z)
# input
rs = RandomState(17)
@@ -1173,7 +1173,7 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
denom = np.sum(num, axis=-1, keepdims=True)
z_ref = num / denom
if epilogue == 'chain-dot':
z_ref = np.matmul(z_ref.T, w)
z_ref = np.matmul(z_ref, w)
# compare
# print(z_ref[:,0], z_tri[:,0])
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)