[BACKEND] Better bf16 support (#588)

Author: daadaada
Date: 2022-07-20 12:22:37 +08:00
Committed by: GitHub
Commit: 9b2bc88d11 (parent: 86cab58d89)
6 changed files with 180 additions and 62 deletions

@@ -2,18 +2,22 @@ import pytest
 import torch
 
 import triton
+import triton._C.libtriton.triton as _triton
 
 
 @pytest.mark.parametrize("M, N, dtype, mode",
                          [
                              (M, N, dtype, mode) for M in [1024, 821]
                              for N in [512, 857, 1871, 2089, 8573, 31000]
-                             for dtype in ['float16', 'float32']
+                             for dtype in ['bfloat16', 'float16', 'float32']
                              for mode in ['forward', 'backward']
                          ]
                          )
 def test_op(M, N, dtype, mode):
-    dtype = {'float16': torch.float16, 'float32': torch.float32}[dtype]
+    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+    if cc < 80 and dtype == "bfloat16":
+        pytest.skip("Only test bfloat16 on devices with sm >= 80")
+    dtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16, 'float32': torch.float32}[dtype]
     # create inputs
     x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True)
     idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda')
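
The added skip logic queries the compute capability through Triton's C bindings. For reference, a minimal sketch of the same sm >= 80 gate written with plain PyTorch instead; the helper name require_bf16 is hypothetical and not part of this commit:

    import pytest
    import torch


    def require_bf16():
        # Hypothetical helper, not from this commit: skip bfloat16 tests on
        # devices older than Ampere (sm >= 80). get_device_capability returns
        # a (major, minor) tuple, e.g. (8, 0) for an A100.
        major, _ = torch.cuda.get_device_capability(torch.cuda.current_device())
        if major < 8:
            pytest.skip("Only test bfloat16 on devices with sm >= 80")

The commit uses _triton.runtime.cc rather than torch.cuda.get_device_capability, presumably so the check reflects the backend Triton will actually compile for.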