[BACKEND] Add bf16 & tf32 mma supports (on A100) (#426)
This commit is contained in:
@@ -10,6 +10,7 @@ import torch
|
||||
from numpy.random import RandomState
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
import triton.language as tl
|
||||
from triton.code_gen import TensorWrapper, reinterpret
|
||||
|
||||
@@ -660,22 +661,26 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
# ---------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'])
|
||||
def test_dot(epilogue, device='cuda'):
|
||||
@pytest.mark.parametrize("epilogue, allow_tf32",
|
||||
[(epilogue, allow_tf32)
|
||||
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']
|
||||
for allow_tf32 in [True, False]])
|
||||
def test_dot(epilogue, allow_tf32, device='cuda'):
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
def kernel(X, stride_xm, stride_xk,
|
||||
Y, stride_yk, stride_yn,
|
||||
Z, stride_zm, stride_zn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr):
|
||||
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
|
||||
ALLOW_TF32: tl.constexpr):
|
||||
off_m = tl.arange(0, BLOCK_M)
|
||||
off_n = tl.arange(0, BLOCK_N)
|
||||
off_k = tl.arange(0, BLOCK_K)
|
||||
Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk
|
||||
Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
|
||||
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
|
||||
z = tl.dot(tl.load(Xs), tl.load(Ys))
|
||||
z = tl.dot(tl.load(Xs), tl.load(Ys), allow_tf32=ALLOW_TF32)
|
||||
if ADD_MATRIX:
|
||||
z += tl.load(Zs)
|
||||
if ADD_ROWS:
|
||||
@@ -690,6 +695,12 @@ def test_dot(epilogue, device='cuda'):
|
||||
rs = RandomState(17)
|
||||
x = numpy_random((M, K), dtype_str='float32', rs=rs)
|
||||
y = numpy_random((K, N), dtype_str='float32', rs=rs)
|
||||
if allow_tf32:
|
||||
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
|
||||
if cc < 80:
|
||||
pytest.skip("Only test tf32 on devices with sm >= 80")
|
||||
x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
|
||||
y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
|
||||
x_tri = to_triton(x, device=device)
|
||||
y_tri = to_triton(y, device=device)
|
||||
# triton result
|
||||
@@ -703,7 +714,8 @@ def test_dot(epilogue, device='cuda'):
|
||||
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
|
||||
ADD_MATRIX=epilogue == 'add-matrix',
|
||||
ADD_ROWS=epilogue == 'add-rows',
|
||||
ADD_COLS=epilogue == 'add-cols')
|
||||
ADD_COLS=epilogue == 'add-cols',
|
||||
ALLOW_TF32=allow_tf32)
|
||||
# torch result
|
||||
z_ref = np.matmul(x, y)
|
||||
if epilogue == 'add-matrix':
|
||||
@@ -718,6 +730,8 @@ def test_dot(epilogue, device='cuda'):
|
||||
ptx = pgm.asm['ptx']
|
||||
assert 'ld.global.v4' in ptx
|
||||
assert 'st.global.v4' in ptx
|
||||
if allow_tf32:
|
||||
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
|
||||
|
||||
|
||||
def test_dot_without_load():
|
||||
|
Reference in New Issue
Block a user