[FRONTEND][BACKEND] Clean-up transpositions (#953)
This commit is contained in:
@@ -1126,7 +1126,7 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
|
||||
if CHAIN_DOT:
|
||||
# tl.store(Zs, z)
|
||||
# tl.debug_barrier()
|
||||
z = tl.dot(tl.trans(z.to(tl.float16)), tl.load(Ws))
|
||||
z = tl.dot(z.to(tl.float16), tl.load(Ws))
|
||||
tl.store(Zs, z)
|
||||
# input
|
||||
rs = RandomState(17)
|
||||
@@ -1173,7 +1173,7 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
|
||||
denom = np.sum(num, axis=-1, keepdims=True)
|
||||
z_ref = num / denom
|
||||
if epilogue == 'chain-dot':
|
||||
z_ref = np.matmul(z_ref.T, w)
|
||||
z_ref = np.matmul(z_ref, w)
|
||||
# compare
|
||||
# print(z_ref[:,0], z_tri[:,0])
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
||||
|
Reference in New Issue
Block a user