diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index e32622005..64f60e260 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -732,6 +732,8 @@ def test_dot(epilogue, allow_tf32, device='cuda'): assert 'st.global.v4' in ptx if allow_tf32: assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx + else: + assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx def test_dot_without_load(): diff --git a/python/test/unit/operators/test_matmul.py b/python/test/unit/operators/test_matmul.py index 1d413a0e6..514fbab7b 100644 --- a/python/test/unit/operators/test_matmul.py +++ b/python/test/unit/operators/test_matmul.py @@ -4,6 +4,7 @@ import pytest import torch import triton +import triton._C.libtriton.triton as _triton @pytest.mark.parametrize( @@ -48,7 +49,7 @@ import triton (128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE), (128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE), (128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE), - ] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True] + ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] ], # n-stage *[ @@ -61,11 +62,16 @@ import triton # split-k (64, 64, 16, 8, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE), (64, 64, 16, 8, 4, STAGES, 1024, 1024, 32, AT, BT, DTYPE), - ] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4] + ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4] ] ), ) def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE): + cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device()) + if cc < 80 and DTYPE == "bfloat16": + pytest.skip("Only test bfloat16 on devices with sm >= 80") + if DTYPE == "bfloat16" and SPLIT_K != 1: + pytest.skip("bfloat16 matmuls don't allow split_k for now") torch.manual_seed(0) # nuke kernel decorators -- will set meta-parameters manually kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K} @@ -81,7 +87,7 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, N = BLOCK_N if N is None else N K = BLOCK_K * SPLIT_K if K is None else K # allocate/transpose inputs - DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE] + DTYPE = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[DTYPE] a = .1 * torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE) b = .1 * torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE) a = a.t() if AT else a diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 63e3d0aa0..960df4efc 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -838,7 +838,7 @@ class Autotuner: # prune configs pruned_configs = self.configs if self.prune_num_stages_by: - pruned_configs = self.prune_num_stages_by(self.configs) + pruned_configs = self.prune_num_stages_by(self.configs, self.nargs) if self.perf_model: top_k = self.configs_top_k if isinstance(top_k, float) and top_k <= 1.0: diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index d7af57406..9466b9ba7 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -87,7 +87,7 @@ def _kernel(A, B, C, M, N, K, acc += tl.dot(a, b) A += BLOCK_K * SPLIT_K * stride_ak B += BLOCK_K * SPLIT_K * stride_bk - acc = acc.to(tl.float16) + acc = acc.to(C.dtype.element_ty) # rematerialize rm and rn to save registers rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) diff --git a/python/triton/ops/matmul_perf_model.py b/python/triton/ops/matmul_perf_model.py index af4f3eed8..98a85bc85 100644 --- a/python/triton/ops/matmul_perf_model.py +++ b/python/triton/ops/matmul_perf_model.py @@ -78,12 +78,16 @@ def estimate_matmul_time( return total_time_ms -def prune_num_stages(configs): +def prune_num_stages(configs, named_args): backend = _triton.runtime.backend.CUDA device = torch.cuda.current_device() cc = _triton.runtime.cc(backend, device) # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages + # Some dtypes do not allow atomic_add + if named_args['A'].dtype == torch.bfloat16: + configs = [config for config in configs if config.kwargs['SPLIT_K'] == 1] + # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps) configs_map = {} for config in configs: diff --git a/python/triton/testing.py b/python/triton/testing.py index 310e754ed..199226ea1 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -65,8 +65,12 @@ def mask_tensor(x, mask, block, value=0): def assert_almost_equal(x, y, decimal=2, err_msg=''): import numpy.testing as npt if isinstance(x, torch.Tensor): + if x.dtype == torch.bfloat16: + x = x.float() x = x.cpu().detach().numpy() if isinstance(y, torch.Tensor): + if y.dtype == torch.bfloat16: + y = y.float() y = y.cpu().detach().numpy() npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal)