fix 6/7 dot tests

This commit is contained in:
Michael Melesse
2022-11-01 14:18:06 +00:00
parent 4f3e2d6ed7
commit 4fb9d4904e
3 changed files with 15 additions and 7 deletions

View File

@@ -1067,7 +1067,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
[(epilogue, allow_tf32, dtype)
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
for allow_tf32 in [True, False]
for dtype in ['float32', 'float16']
for dtype in ['float16']
if not (allow_tf32 and (dtype in ['float16']))])
def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
if torch.version.hip is not None: