[BACKEND] Added Int8 mma (#440)
@@ -661,11 +661,20 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
 # ---------------
 
 
-@pytest.mark.parametrize("epilogue, allow_tf32",
-                         [(epilogue, allow_tf32)
+@pytest.mark.parametrize("epilogue, allow_tf32, dtype",
+                         [(epilogue, allow_tf32, dtype)
                           for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']
-                          for allow_tf32 in [True, False]])
-def test_dot(epilogue, allow_tf32, device='cuda'):
+                          for allow_tf32 in [True, False]
+                          for dtype in ['float32', 'int8']
+                          if not (allow_tf32 and (dtype == 'int8'))])
+def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
+    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+    if cc < 80:
+        if dtype == 'int8':
+            pytest.skip("Only test int8 on devices with sm >= 80")
+        elif dtype == 'float32' and allow_tf32:
+            pytest.skip("Only test tf32 on devices with sm >= 80")
+
     # triton kernel
     @triton.jit
     def kernel(X, stride_xm, stride_xk,
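Note on the hunk above: the new dtype axis multiplies the existing parametrization, and the "if not (allow_tf32 and (dtype == 'int8'))" guard drops the unsupported tf32 + int8 pairs. A minimal sketch of the resulting case list, using only plain Python and illustrative variable names (not part of the test):

    epilogues = ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']
    cases = [(epilogue, allow_tf32, dtype)
             for epilogue in epilogues
             for allow_tf32 in [True, False]
             for dtype in ['float32', 'int8']
             if not (allow_tf32 and dtype == 'int8')]
    # 15 cases: 10 float32 (tf32 on/off) + 5 int8 (tf32 always off)
    print(len(cases))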
@@ -693,18 +702,15 @@ def test_dot(epilogue, allow_tf32, device='cuda'):
     # input
     M, N, K = 64, 64, 32
     rs = RandomState(17)
-    x = numpy_random((M, K), dtype_str='float32', rs=rs)
-    y = numpy_random((K, N), dtype_str='float32', rs=rs)
+    x = numpy_random((M, K), dtype_str=dtype, rs=rs)
+    y = numpy_random((K, N), dtype_str=dtype, rs=rs)
     if allow_tf32:
-        cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
-        if cc < 80:
-            pytest.skip("Only test tf32 on devices with sm >= 80")
         x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
         y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
     x_tri = to_triton(x, device=device)
     y_tri = to_triton(y, device=device)
     # triton result
-    z = numpy_random((M, N), dtype_str='float32', rs=rs)
+    z = numpy_random((M, N), dtype_str=dtype, rs=rs)
     z_tri = to_triton(z, device=device)
     if epilogue == 'trans':
         z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
@@ -732,8 +738,10 @@ def test_dot(epilogue, allow_tf32, device='cuda'):
     assert 'st.global.v4' in ptx
     if allow_tf32:
         assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
-    else:
+    elif dtype == 'float32':
         assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
+    elif dtype == 'int8':
+        assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
 
 
 def test_dot_without_load():
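Note on the tf32 masking kept in the second hunk: tf32 stores only 10 explicit mantissa bits, so when allow_tf32 is set the test clears the 13 low mantissa bits of the float32 inputs (mask 0xffffe000) before building the numpy reference, which makes it match the reduced-precision mma.sync tf32 result. A self-contained sketch of that trick, assuming only numpy (names are illustrative, not from the test):

    import numpy as np

    def truncate_to_tf32(a):
        # Reinterpret the float32 bits as uint32, zero the 13 low mantissa bits,
        # then reinterpret back as float32 to mimic tf32 precision.
        return (a.view('uint32') & np.uint32(0xffffe000)).view('float32')

    x = np.random.rand(64, 32).astype('float32')
    y = np.random.rand(32, 64).astype('float32')
    ref = np.dot(truncate_to_tf32(x), truncate_to_tf32(y))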