[TESTS] Added bfloat16 tests (#430)

Author: daadaada
Date: 2022-01-14 15:38:32 +08:00
Committed by: GitHub
Parent: 4c94359199
Commit: 2a944ded53
6 changed files with 22 additions and 6 deletions

View File

@@ -732,6 +732,8 @@ def test_dot(epilogue, allow_tf32, device='cuda'):
     assert 'st.global.v4' in ptx
     if allow_tf32:
         assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
+    else:
+        assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx


 def test_dot_without_load():
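The new else branch asserts that the tf32 MMA instruction is absent whenever allow_tf32 is off. For reference, below is a minimal standalone sketch of the same PTX-inspection technique, not part of this commit. It assumes a recent Triton API in which launching a @triton.jit kernel returns a handle exposing the generated PTX under .asm["ptx"]; the kernel and variable names are made up for illustration.

import torch
import triton
import triton.language as tl

@triton.jit
def tiny_dot(A, B, C, BLOCK: tl.constexpr):
    # load two square BLOCK x BLOCK tiles, multiply them, store the result
    offs = tl.arange(0, BLOCK)
    a = tl.load(A + offs[:, None] * BLOCK + offs[None, :])
    b = tl.load(B + offs[:, None] * BLOCK + offs[None, :])
    tl.store(C + offs[:, None] * BLOCK + offs[None, :], tl.dot(a, b))

a = torch.randn((16, 16), device="cuda")
b = torch.randn((16, 16), device="cuda")
c = torch.empty((16, 16), device="cuda")
pgm = tiny_dot[(1,)](a, b, c, BLOCK=16)
print('mma.sync' in pgm.asm["ptx"])  # True when a tensor-core MMA path was emitted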

View File

@@ -4,6 +4,7 @@ import pytest
 import torch

 import triton
+import triton._C.libtriton.triton as _triton


 @pytest.mark.parametrize(
@@ -48,7 +49,7 @@ import triton
                 (128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE),
                 (128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE),
                 (128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE),
-            ] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True]
+            ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]
         ],
         # n-stage
         *[
@@ -61,11 +62,16 @@ import triton
                 # split-k
                 (64, 64, 16, 8, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
                 (64, 64, 16, 8, 4, STAGES, 1024, 1024, 32, AT, BT, DTYPE),
-            ] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4]
+            ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4]
         ]
     ),
 )
 def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
+    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+    if cc < 80 and DTYPE == "bfloat16":
+        pytest.skip("Only test bfloat16 on devices with sm >= 80")
+    if DTYPE == "bfloat16" and SPLIT_K != 1:
+        pytest.skip("bfloat16 matmuls don't allow split_k for now")
     torch.manual_seed(0)
     # nuke kernel decorators -- will set meta-parameters manually
     kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
@@ -81,7 +87,7 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
     N = BLOCK_N if N is None else N
     K = BLOCK_K * SPLIT_K if K is None else K
     # allocate/transpose inputs
-    DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
+    DTYPE = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[DTYPE]
     a = .1 * torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
     b = .1 * torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
     a = a.t() if AT else a
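The skip logic above queries the compute capability through Triton's internal _triton.runtime.cc. As a hedged sketch of equivalent gating with plain PyTorch (maybe_skip_bf16 is a hypothetical helper, not part of this commit): torch.cuda.get_device_capability() returns (major, minor), so sm >= 80 corresponds to major >= 8.

import pytest
import torch

def maybe_skip_bf16(dtype_name, split_k):
    # mirror the two skips above without Triton's runtime module
    major, _minor = torch.cuda.get_device_capability()
    if dtype_name == "bfloat16" and major < 8:
        pytest.skip("Only test bfloat16 on devices with sm >= 80")
    if dtype_name == "bfloat16" and split_k != 1:
        pytest.skip("bfloat16 matmuls don't allow split_k for now")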

View File

@@ -838,7 +838,7 @@ class Autotuner:
         # prune configs
         pruned_configs = self.configs
         if self.prune_num_stages_by:
-            pruned_configs = self.prune_num_stages_by(self.configs)
+            pruned_configs = self.prune_num_stages_by(self.configs, self.nargs)
         if self.perf_model:
             top_k = self.configs_top_k
             if isinstance(top_k, float) and top_k <= 1.0:
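For context, a minimal standalone sketch of the updated hook contract (select_configs and the example configs are illustrative, not the real Autotuner): the pruning callback now receives the kernel's named arguments alongside the candidate configs, so it can inspect runtime properties such as input dtypes.

def select_configs(configs, named_args, prune_num_stages_by=None):
    # hypothetical stand-in for the Autotuner pruning step shown above
    pruned_configs = configs
    if prune_num_stages_by:
        # old hook signature: f(configs); new hook signature: f(configs, named_args)
        pruned_configs = prune_num_stages_by(configs, named_args)
    return pruned_configs

# a no-op hook keeps every candidate config
print(select_configs([{"num_stages": 2}, {"num_stages": 4}], {"A": None},
                     prune_num_stages_by=lambda cfgs, args: cfgs))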

View File

@@ -87,7 +87,7 @@ def _kernel(A, B, C, M, N, K,
         acc += tl.dot(a, b)
         A += BLOCK_K * SPLIT_K * stride_ak
         B += BLOCK_K * SPLIT_K * stride_bk
-    acc = acc.to(tl.float16)
+    acc = acc.to(C.dtype.element_ty)
     # rematerialize rm and rn to save registers
     rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
     rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
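Instead of hard-coding a float16 cast, the accumulator is now cast to the element type of the output pointer, so the same kernel serves float16, bfloat16, and float32 outputs. Below is a hedged standalone sketch of that pattern (names are mine; assumes a recent Triton API): compute in float32, then cast through ptr.dtype.element_ty on store.

import torch
import triton
import triton.language as tl

@triton.jit
def cast_copy(X, Y, N, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(X + offs, mask=mask).to(tl.float32)           # compute in float32
    tl.store(Y + offs, x.to(Y.dtype.element_ty), mask=mask)   # cast to the output's dtype

x = torch.randn(1024, device="cuda")
y = torch.empty(1024, device="cuda", dtype=torch.bfloat16)
cast_copy[(triton.cdiv(1024, 256),)](x, y, 1024, BLOCK=256)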

View File

@@ -78,12 +78,16 @@ def estimate_matmul_time(
     return total_time_ms


-def prune_num_stages(configs):
+def prune_num_stages(configs, named_args):
     backend = _triton.runtime.backend.CUDA
     device = torch.cuda.current_device()
     cc = _triton.runtime.cc(backend, device)
     # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages

+    # Some dtypes do not allow atomic_add
+    if named_args['A'].dtype == torch.bfloat16:
+        configs = [config for config in configs if config.kwargs['SPLIT_K'] == 1]
+
     # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps)
     configs_map = {}
     for config in configs:
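A quick standalone check of the new pruning behaviour (illustrative, not part of the commit; triton.Config is used only for its .kwargs attribute, and a CPU bfloat16 tensor stands in for the real kernel argument A): any config requesting SPLIT_K > 1 is dropped when A is bfloat16.

import torch
import triton

configs = [
    triton.Config({'BLOCK_M': 64, 'SPLIT_K': 1}, num_warps=4, num_stages=2),
    triton.Config({'BLOCK_M': 64, 'SPLIT_K': 8}, num_warps=4, num_stages=2),
]
named_args = {'A': torch.empty(16, 16, dtype=torch.bfloat16)}  # stand-in for the real kernel arg
if named_args['A'].dtype == torch.bfloat16:
    configs = [config for config in configs if config.kwargs['SPLIT_K'] == 1]
print([config.kwargs['SPLIT_K'] for config in configs])  # -> [1]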

View File

@@ -65,8 +65,12 @@ def mask_tensor(x, mask, block, value=0):
 def assert_almost_equal(x, y, decimal=2, err_msg=''):
     import numpy.testing as npt
     if isinstance(x, torch.Tensor):
+        if x.dtype == torch.bfloat16:
+            x = x.float()
         x = x.cpu().detach().numpy()
     if isinstance(y, torch.Tensor):
+        if y.dtype == torch.bfloat16:
+            y = y.float()
         y = y.cpu().detach().numpy()
     npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal)
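Why the upcast is needed (illustrative sketch, not part of the commit): NumPy has no bfloat16 dtype, so calling .numpy() on a bf16 tensor raises; converting to float32 first keeps assert_almost_equal usable for the new bfloat16 test cases.

import torch

t = torch.randn(4, dtype=torch.bfloat16)
try:
    t.numpy()                 # NumPy has no bfloat16 dtype
except TypeError as exc:
    print("direct conversion fails:", exc)
print(t.float().numpy())      # upcast to float32 first, as the helper above does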