update test_dot to use float 32
This commit is contained in:
@@ -1067,15 +1067,18 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
|||||||
[(epilogue, allow_tf32, dtype)
|
[(epilogue, allow_tf32, dtype)
|
||||||
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
|
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
|
||||||
for allow_tf32 in [True, False]
|
for allow_tf32 in [True, False]
|
||||||
for dtype in ['float16']
|
for dtype in ['float32', 'float16']
|
||||||
if not (allow_tf32 and (dtype in ['float16']))])
|
if not (allow_tf32 and (dtype in ['float16']))])
|
||||||
def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
|
def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
|
||||||
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
|
if torch.version.hip is not None:
|
||||||
if cc < 80:
|
pass
|
||||||
if dtype == 'int8':
|
else:
|
||||||
pytest.skip("Only test int8 on devices with sm >= 80")
|
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
|
||||||
elif dtype == 'float32' and allow_tf32:
|
if cc < 80:
|
||||||
pytest.skip("Only test tf32 on devices with sm >= 80")
|
if dtype == 'int8':
|
||||||
|
pytest.skip("Only test int8 on devices with sm >= 80")
|
||||||
|
elif dtype == 'float32' and allow_tf32:
|
||||||
|
pytest.skip("Only test tf32 on devices with sm >= 80")
|
||||||
|
|
||||||
M, N, K = 128, 128, 64
|
M, N, K = 128, 128, 64
|
||||||
num_warps = 8
|
num_warps = 8
|
||||||
@@ -1170,15 +1173,18 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
|
|||||||
# print(z_ref[:,0], z_tri[:,0])
|
# print(z_ref[:,0], z_tri[:,0])
|
||||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
||||||
# make sure ld/st are vectorized
|
# make sure ld/st are vectorized
|
||||||
ptx = pgm.asm['ptx']
|
if torch.version.hip is not None:
|
||||||
assert 'ld.global.v4' in ptx
|
pass
|
||||||
assert 'st.global.v4' in ptx
|
else:
|
||||||
if allow_tf32:
|
ptx = pgm.asm['ptx']
|
||||||
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
|
assert 'ld.global.v4' in ptx
|
||||||
elif dtype == 'float32':
|
assert 'st.global.v4' in ptx
|
||||||
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
|
if allow_tf32:
|
||||||
elif dtype == 'int8':
|
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
|
||||||
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
|
elif dtype == 'float32':
|
||||||
|
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
|
||||||
|
elif dtype == 'int8':
|
||||||
|
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
|
||||||
|
|
||||||
|
|
||||||
def test_dot_without_load():
|
def test_dot_without_load():
|
||||||
|
@@ -7,7 +7,7 @@ sudo apt install gdb -y
|
|||||||
|
|
||||||
gdb -ex "set pagination off" \
|
gdb -ex "set pagination off" \
|
||||||
-ex "file python" \
|
-ex "file python" \
|
||||||
-ex 'run -m pytest --capture=tee-sys --verbose "python/test/unit/language/test_core.py::test_bin_op[int32-uint32-+]"' \
|
-ex 'run -m pytest --capture=tee-sys --verbose "python/test/unit/language/test_core.py::test_dot"' \
|
||||||
-ex "backtrace" \
|
-ex "backtrace" \
|
||||||
-ex "set confirm off" \
|
-ex "set confirm off" \
|
||||||
-ex "q" \
|
-ex "q" \
|
||||||
|
@@ -26,7 +26,7 @@ rm -rf /tmp/triton
|
|||||||
# python python/test/test_empty.py
|
# python python/test/test_empty.py
|
||||||
# -ex 'ignore 1 472' \
|
# -ex 'ignore 1 472' \
|
||||||
|
|
||||||
pytest -rfs --verbose python/test/unit/language/test_core.py 2>&1 | tee /dockerx/triton/test_core.log
|
# pytest -rfs --verbose python/test/unit/language/test_core.py 2>&1 | tee /dockerx/triton/test_core.log
|
||||||
# pytest --verbose python/test/unit/language/test_core.py::test_empty_kernel[float32] 2>&1 | tee /dockerx/triton/test_empty_kernel.log
|
# pytest --verbose python/test/unit/language/test_core.py::test_empty_kernel[float32] 2>&1 | tee /dockerx/triton/test_empty_kernel.log
|
||||||
# pytest --verbose python/test/unit/language/test_core.py::test_bin_op[int32-uint32-+] 2>&1 | tee /dockerx/triton/test_bin_op.log
|
# pytest --verbose python/test/unit/language/test_core.py::test_bin_op[int32-uint32-+] 2>&1 | tee /dockerx/triton/test_bin_op.log
|
||||||
# pytest --verbose python/test/unit/language/test_core.py::test_atomic_rmw 2>&1 | tee /dockerx/triton/test_atomic_rmw.log
|
# pytest --verbose python/test/unit/language/test_core.py::test_atomic_rmw 2>&1 | tee /dockerx/triton/test_atomic_rmw.log
|
||||||
@@ -55,7 +55,7 @@ pytest -rfs --verbose python/test/unit/language/test_core.py 2>&1 | tee /dockerx
|
|||||||
# pytest --capture=tee-sys --verbose python/test/unit/language/test_core.py::test_num_programs[float32]
|
# pytest --capture=tee-sys --verbose python/test/unit/language/test_core.py::test_num_programs[float32]
|
||||||
# pytest --verbose python/test/unit/language/test_core.py::test_unary_op
|
# pytest --verbose python/test/unit/language/test_core.py::test_unary_op
|
||||||
# pytest --verbose python/test/unit/language/test_core.py::test_bin_op
|
# pytest --verbose python/test/unit/language/test_core.py::test_bin_op
|
||||||
# pytest --verbose "python/test/unit/language/test_core.py::test_dot"
|
pytest --verbose "python/test/unit/language/test_core.py::test_dot"
|
||||||
# pytest --verbose python/test/unit/language/test_core.py::test_cast
|
# pytest --verbose python/test/unit/language/test_core.py::test_cast
|
||||||
# pytest --verbose python/test/unit/language/test_core.py::test_reduce1d
|
# pytest --verbose python/test/unit/language/test_core.py::test_reduce1d
|
||||||
# pytest --verbose python/test/unit/language/test_core.py::test_reduce2d
|
# pytest --verbose python/test/unit/language/test_core.py::test_reduce2d
|
||||||
|
Reference in New Issue
Block a user