[BACKEND] Added Int8 mma (#440)

Author: daadaada
Date: 2022-01-28 01:12:44 +08:00
Committed by: GitHub
Parent: 3a23c1dd33
Commit: 59d371c6eb
11 changed files with 232 additions and 115 deletions


@@ -661,11 +661,20 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
 # ---------------
-@pytest.mark.parametrize("epilogue, allow_tf32",
-                         [(epilogue, allow_tf32)
+@pytest.mark.parametrize("epilogue, allow_tf32, dtype",
+                         [(epilogue, allow_tf32, dtype)
                           for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']
-                          for allow_tf32 in [True, False]])
-def test_dot(epilogue, allow_tf32, device='cuda'):
+                          for allow_tf32 in [True, False]
+                          for dtype in ['float32', 'int8']
+                          if not (allow_tf32 and (dtype == 'int8'))])
+def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
+    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+    if cc < 80:
+        if dtype == 'int8':
+            pytest.skip("Only test int8 on devices with sm >= 80")
+        elif dtype == 'float32' and allow_tf32:
+            pytest.skip("Only test tf32 on devices with sm >= 80")
     # triton kernel
     @triton.jit
     def kernel(X, stride_xm, stride_xk,
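
Aside (not part of the diff): the filtered comprehension above is a compact way to build a cartesian parameter grid in pytest while excluding an invalid combination (tf32 only applies to float32 inputs). A standalone sketch reproducing the grid:

# Sketch only: reproduces the parameter grid from the decorator above.
cases = [(epilogue, allow_tf32, dtype)
         for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']
         for allow_tf32 in [True, False]
         for dtype in ['float32', 'int8']
         if not (allow_tf32 and (dtype == 'int8'))]
# 5 epilogues x 3 valid (allow_tf32, dtype) combos; (True, 'int8') is filtered out.
assert len(cases) == 15
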
@@ -693,18 +702,15 @@ def test_dot(epilogue, allow_tf32, device='cuda'):
     # input
     M, N, K = 64, 64, 32
     rs = RandomState(17)
-    x = numpy_random((M, K), dtype_str='float32', rs=rs)
-    y = numpy_random((K, N), dtype_str='float32', rs=rs)
+    x = numpy_random((M, K), dtype_str=dtype, rs=rs)
+    y = numpy_random((K, N), dtype_str=dtype, rs=rs)
     if allow_tf32:
-        cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
-        if cc < 80:
-            pytest.skip("Only test tf32 on devices with sm >= 80")
         x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
         y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
     x_tri = to_triton(x, device=device)
     y_tri = to_triton(y, device=device)
     # triton result
-    z = numpy_random((M, N), dtype_str='float32', rs=rs)
+    z = numpy_random((M, N), dtype_str=dtype, rs=rs)
     z_tri = to_triton(z, device=device)
     if epilogue == 'trans':
         z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
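
Aside (not part of the diff): the 0xffffe000 mask above keeps the sign bit, the 8 exponent bits, and the top 10 mantissa bits of each float32 word, zeroing the low 13 mantissa bits. Since tf32 carries only 10 explicit mantissa bits, this makes the host-side inputs exactly representable in tf32, so the numpy reference can match the tensor-core result. A minimal sketch of the same trick:

import numpy as np

# Sketch only: truncate float32 values to tf32 precision via bit masking.
def to_tf32(a: np.ndarray) -> np.ndarray:
    assert a.dtype == np.float32
    # keep sign (1) + exponent (8) + top 10 mantissa bits; zero the low 13
    return (a.view('uint32') & np.uint32(0xffffe000)).view('float32')

x = np.float32(1.0) + np.float32(2.0) ** -20   # perturb a low mantissa bit
print(to_tf32(np.array([x])))                  # low bits zeroed -> [1.]
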
@@ -732,8 +738,10 @@ def test_dot(epilogue, allow_tf32, device='cuda'):
     assert 'st.global.v4' in ptx
     if allow_tf32:
         assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
-    else:
+    elif dtype == 'float32':
         assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
+    elif dtype == 'int8':
+        assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx


 def test_dot_without_load():
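
Aside (not part of the diff): the new assertion checks that the generated PTX uses the sm >= 80 int8 Tensor Core instruction, which multiplies s8 operand tiles into an s32 accumulator over a 16x8x32 tile; as I read the PTX ISA, .satfinite clamps integer overflow to the int32 range rather than wrapping. A hedged numpy sketch of those semantics (the helper name int8_mma_reference is hypothetical, not Triton API):

import numpy as np

# Sketch only: reference semantics for an s8 x s8 -> s32 mma with .satfinite.
def int8_mma_reference(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray:
    assert a.dtype == np.int8 and b.dtype == np.int8 and c.dtype == np.int32
    # widen to int64 so the accumulation itself cannot overflow ...
    acc = a.astype(np.int64) @ b.astype(np.int64) + c
    # ... then clamp to the int32 range, mimicking .satfinite
    i32 = np.iinfo(np.int32)
    return np.clip(acc, i32.min, i32.max).astype(np.int32)

rs = np.random.RandomState(17)
a = rs.randint(-128, 128, (16, 32), dtype=np.int8)   # one m16n8k32 tile
b = rs.randint(-128, 128, (32, 8), dtype=np.int8)
c = np.zeros((16, 8), dtype=np.int32)
print(int8_mma_reference(a, b, c).dtype)             # int32
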