[BACKEND] Add bf16 & tf32 mma support (on A100) (#426)

Author: daadaada
Date: 2022-01-12 02:20:31 +08:00
Committed by: GitHub
Parent: efdabe6073
Commit: 94a2e10fe5
17 changed files with 717 additions and 263 deletions


@@ -10,6 +10,7 @@ import torch
 from numpy.random import RandomState
 import triton
+import triton._C.libtriton.triton as _triton
 import triton.language as tl
 from triton.code_gen import TensorWrapper, reinterpret
@@ -660,22 +661,26 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
 # ---------------
-@pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'])
-def test_dot(epilogue, device='cuda'):
+@pytest.mark.parametrize("epilogue, allow_tf32",
+                         [(epilogue, allow_tf32)
+                          for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']
+                          for allow_tf32 in [True, False]])
+def test_dot(epilogue, allow_tf32, device='cuda'):
     # triton kernel
     @triton.jit
     def kernel(X, stride_xm, stride_xk,
                Y, stride_yk, stride_yn,
                Z, stride_zm, stride_zn,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-               ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr):
+               ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
+               ALLOW_TF32: tl.constexpr):
         off_m = tl.arange(0, BLOCK_M)
         off_n = tl.arange(0, BLOCK_N)
         off_k = tl.arange(0, BLOCK_K)
         Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk
         Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
         Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
-        z = tl.dot(tl.load(Xs), tl.load(Ys))
+        z = tl.dot(tl.load(Xs), tl.load(Ys), allow_tf32=ALLOW_TF32)
         if ADD_MATRIX:
             z += tl.load(Zs)
         if ADD_ROWS:
@@ -690,6 +695,12 @@ def test_dot(epilogue, device='cuda'):
     rs = RandomState(17)
     x = numpy_random((M, K), dtype_str='float32', rs=rs)
     y = numpy_random((K, N), dtype_str='float32', rs=rs)
+    if allow_tf32:
+        cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+        if cc < 80:
+            pytest.skip("Only test tf32 on devices with sm >= 80")
+        x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
+        y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
     x_tri = to_triton(x, device=device)
     y_tri = to_triton(y, device=device)
     # triton result
@@ -703,7 +714,8 @@ def test_dot(epilogue, device='cuda'):
                          BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
                          ADD_MATRIX=epilogue == 'add-matrix',
                          ADD_ROWS=epilogue == 'add-rows',
-                         ADD_COLS=epilogue == 'add-cols')
+                         ADD_COLS=epilogue == 'add-cols',
+                         ALLOW_TF32=allow_tf32)
     # torch result
     z_ref = np.matmul(x, y)
     if epilogue == 'add-matrix':
@@ -718,6 +730,8 @@ def test_dot(epilogue, device='cuda'):
     ptx = pgm.asm['ptx']
     assert 'ld.global.v4' in ptx
     assert 'st.global.v4' in ptx
+    if allow_tf32:
+        assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
 
 
 def test_dot_without_load():
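
The masking added above works because tf32 keeps fp32's 8 exponent bits but only the top 10 of its 23 mantissa bits: clearing the low 13 bits (mask 0xffffe000) snaps an fp32 value onto the tf32 grid, so the fp32 reference matmul on the masked inputs matches the tensor-core result, whose mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 instruction the test then asserts in the PTX. A minimal standalone sketch of the same masking (the helper name is hypothetical, not part of the commit):

import numpy as np

def truncate_to_tf32(a):
    # tf32 = 1 sign + 8 exponent + 10 mantissa bits; fp32 carries 23
    # mantissa bits, so clearing the low 13 (mask 0xffffe000) leaves a
    # value that is exactly representable in tf32.
    assert a.dtype == np.float32
    return (a.view(np.uint32) & np.uint32(0xffffe000)).view(np.float32)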


@@ -524,18 +524,18 @@ def reshape(input, shape, _builder=None):
 
 @builtin
-def dot(input, other, _builder=None):
+def dot(input, other, allow_tf32=True, _builder=None):
     """
     Returns the matrix product of two blocks.
 
     The two blocks must be two-dimensional and have compatible inner dimensions.
 
     :param input: The first block to be multiplied.
-    :type input: 2D block of scalar-type in {:code:`float16`, :code:`float32`}
+    :type input: 2D block of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
     :param other: The second block to be multiplied.
-    :type other: 2D block of scalar-type in {:code:`float16`, :code:`float32`}
+    :type other: 2D block of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
     """
-    return frontend.dot(input, other, _builder)
+    return frontend.dot(input, other, allow_tf32, _builder)
 
 
 # -----------------------
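
With the new signature, callers can opt out of tf32 when full fp32 accuracy matters more than speed. A minimal sketch of a kernel using the flag (kernel shape borrowed from the test above; not part of the commit):

import triton
import triton.language as tl

@triton.jit
def matmul_kernel(X, Y, Z,
                  stride_xm, stride_xk,
                  stride_yk, stride_yn,
                  stride_zm, stride_zn,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    off_m = tl.arange(0, BLOCK_M)
    off_n = tl.arange(0, BLOCK_N)
    off_k = tl.arange(0, BLOCK_K)
    x = tl.load(X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk)
    y = tl.load(Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn)
    # allow_tf32 defaults to True; passing False keeps the full-precision
    # fp32 dot even on sm >= 80 devices.
    z = tl.dot(x, y, allow_tf32=False)
    tl.store(Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn, z)

Per the updated docstring, bfloat16 blocks are likewise accepted by tl.dot after this change.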