[TESTS] Added bfloat16 tests (#430)
@@ -732,6 +732,8 @@ def test_dot(epilogue, allow_tf32, device='cuda'):
     assert 'st.global.v4' in ptx
     if allow_tf32:
         assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
+    else:
+        assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
 
 
 def test_dot_without_load():
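Note: the new else branch checks the negative case: when allow_tf32 is False, the compiled PTX must not contain the tf32 tensor-core mma.sync instruction. A minimal sketch of the same check factored into a helper (illustrative only; the helper name is hypothetical and not part of this diff):

def assert_tf32_usage(ptx: str, allow_tf32: bool):
    # tf32 MMA shape emitted for fp32 dot products on sm_80+ tensor cores
    tf32_mma = 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32'
    if allow_tf32:
        assert tf32_mma in ptx
    else:
        # without tf32, the dot product should not go through the tf32 MMA path
        assert tf32_mma not in ptx
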
@@ -4,6 +4,7 @@ import pytest
 import torch
 
 import triton
+import triton._C.libtriton.triton as _triton
 
 
 @pytest.mark.parametrize(
@@ -48,7 +49,7 @@ import triton
                 (128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE),
                 (128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE),
                 (128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE),
-            ] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True]
+            ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]
         ],
         # n-stage
         *[
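Note: adding "bfloat16" to the DTYPE sweep grows the generated test matrix by half relative to the two previous dtypes; every (AT, BT, DTYPE) combination is still paired with every shape tuple. A small sketch of how such a grid expands (illustrative; the shape list is a made-up subset):

import itertools

shapes = [(128, 128, 32, 1, 4, 2, 384, 128, 640),
          (128, 128, 32, 1, 4, 2, 107, 233, 256)]
grid = [shape + (AT, BT, DTYPE)
        for DTYPE in ["float16", "bfloat16", "float32"]
        for AT in [False, True]
        for BT in [False, True]
        for shape in shapes]
assert len(grid) == len(shapes) * 3 * 2 * 2  # 24 parameter sets here
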
@@ -61,11 +62,16 @@ import triton
                 # split-k
                 (64, 64, 16, 8, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
                 (64, 64, 16, 8, 4, STAGES, 1024, 1024, 32, AT, BT, DTYPE),
-            ] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4]
+            ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4]
         ]
     ),
 )
 def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
+    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+    if cc < 80 and DTYPE == "bfloat16":
+        pytest.skip("Only test bfloat16 on devices with sm >= 80")
+    if DTYPE == "bfloat16" and SPLIT_K != 1:
+        pytest.skip("bfloat16 matmuls don't allow split_k for now")
     torch.manual_seed(0)
     # nuke kernel decorators -- will set meta-parameters manually
     kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
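Note: bfloat16 tensor-core matmuls require compute capability 8.0 (Ampere) or newer, and the split-K path combines partial results with an atomic add that is not supported for bfloat16 here, hence the two pytest.skip calls. The diff queries the capability through Triton's internal _triton.runtime.cc; a sketch of the same check using only the public PyTorch API (illustrative, not what the diff uses):

import torch

def compute_capability() -> int:
    # e.g. (8, 0) on A100 -> 80, (7, 5) on T4 -> 75
    major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
    return major * 10 + minor

# if compute_capability() < 80 and DTYPE == "bfloat16":
#     pytest.skip("Only test bfloat16 on devices with sm >= 80")
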
@@ -81,7 +87,7 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
     N = BLOCK_N if N is None else N
     K = BLOCK_K * SPLIT_K if K is None else K
     # allocate/transpose inputs
-    DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
+    DTYPE = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[DTYPE]
     a = .1 * torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
     b = .1 * torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
     a = a.t() if AT else a
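Note: the string parameter from pytest is translated into a torch.dtype before the inputs are allocated; the only change is the new "bfloat16" entry. An equivalent sketch using getattr instead of an explicit dictionary (illustrative; the explicit mapping in the diff is arguably clearer for a fixed set of names):

import torch

def to_torch_dtype(name: str) -> torch.dtype:
    # "float16" -> torch.float16, "bfloat16" -> torch.bfloat16, ...
    dtype = getattr(torch, name)
    assert isinstance(dtype, torch.dtype), f"{name!r} is not a torch dtype"
    return dtype

assert to_torch_dtype("bfloat16") is torch.bfloat16
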
@@ -838,7 +838,7 @@ class Autotuner:
                 # prune configs
                 pruned_configs = self.configs
                 if self.prune_num_stages_by:
-                    pruned_configs = self.prune_num_stages_by(self.configs)
+                    pruned_configs = self.prune_num_stages_by(self.configs, self.nargs)
                 if self.perf_model:
                     top_k = self.configs_top_k
                     if isinstance(top_k, float) and top_k <= 1.0:
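Note: the pruning hook now receives the kernel's named arguments alongside the candidate configs, so it can make decisions that depend on the runtime inputs (for example their dtype, as prune_num_stages does below). A minimal sketch of a hook with the new two-argument signature (illustrative; the hook itself is hypothetical, not Triton's API):

# Old contract: hook(configs) -> pruned configs
# New contract: hook(configs, named_args) -> pruned configs, where named_args
# maps kernel argument names (e.g. 'A', 'B', 'C') to the values being launched.
def log_and_keep(configs, named_args):
    # a hook could inspect named_args['A'].dtype or shapes here before pruning
    print(f"pruning over {len(configs)} configs for args {sorted(named_args)}")
    return list(configs)
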
@@ -87,7 +87,7 @@ def _kernel(A, B, C, M, N, K,
         acc += tl.dot(a, b)
         A += BLOCK_K * SPLIT_K * stride_ak
         B += BLOCK_K * SPLIT_K * stride_bk
-    acc = acc.to(tl.float16)
+    acc = acc.to(C.dtype.element_ty)
     # rematerialize rm and rn to save registers
     rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
     rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
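Note: the accumulator stays in fp32 inside the K loop; only the final tile is cast, and casting to C.dtype.element_ty (the element type of the output pointer) instead of a hard-coded tl.float16 lets the same kernel store fp16, bfloat16, or fp32 results. A sketch of the reference semantics in plain PyTorch (illustrative, not the kernel itself; it assumes the output dtype matches the input dtype, as in this matmul op):

import torch

def matmul_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # accumulate in fp32, then cast down to the output dtype, mirroring
    # acc (fp32) -> acc.to(C.dtype.element_ty) in the Triton kernel
    acc = a.float() @ b.float()
    return acc.to(a.dtype)
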
@@ -78,12 +78,16 @@ def estimate_matmul_time(
     return total_time_ms
 
 
-def prune_num_stages(configs):
+def prune_num_stages(configs, named_args):
     backend = _triton.runtime.backend.CUDA
     device = torch.cuda.current_device()
     cc = _triton.runtime.cc(backend, device)
     # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
 
+    # Some dtypes do not allow atomic_add
+    if named_args['A'].dtype == torch.bfloat16:
+        configs = [config for config in configs if config.kwargs['SPLIT_K'] == 1]
+
     # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps)
     configs_map = {}
     for config in configs:
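Note: the new filter drops split-K candidates when A is bfloat16, because the split-K epilogue accumulates partial results with an atomic add and, per the comment above, that is not available for bfloat16 (the same reason the unit test skips SPLIT_K != 1). A self-contained sketch of the filter with a stand-in config type (illustrative; FakeConfig is not Triton's triton.Config):

from dataclasses import dataclass, field

@dataclass
class FakeConfig:
    # stand-in for an autotuner config; only the kwargs dict matters here
    kwargs: dict = field(default_factory=dict)

configs = [FakeConfig({'SPLIT_K': 1}), FakeConfig({'SPLIT_K': 4}), FakeConfig({'SPLIT_K': 8})]
configs = [c for c in configs if c.kwargs['SPLIT_K'] == 1]  # keep non-split-K only
assert len(configs) == 1
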
@@ -65,8 +65,12 @@ def mask_tensor(x, mask, block, value=0):
 def assert_almost_equal(x, y, decimal=2, err_msg=''):
     import numpy.testing as npt
     if isinstance(x, torch.Tensor):
+        if x.dtype == torch.bfloat16:
+            x = x.float()
         x = x.cpu().detach().numpy()
     if isinstance(y, torch.Tensor):
+        if y.dtype == torch.bfloat16:
+            y = y.float()
         y = y.cpu().detach().numpy()
     npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal)
 
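Note: NumPy has no bfloat16 dtype, so .numpy() cannot be called on a bfloat16 tensor directly; widening to float32 first is lossless (every bfloat16 value is exactly representable as a float32) and lets the existing NumPy comparison run unchanged. A minimal sketch of the conversion on its own (illustrative):

import torch

def to_numpy(t: torch.Tensor):
    # NumPy cannot represent bfloat16, so widen to float32 first
    if t.dtype == torch.bfloat16:
        t = t.float()
    return t.cpu().detach().numpy()
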