[FRONTEND][BACKEND] Added trans instruction; made flash attention bwd pass work (#943)

Author: Philippe Tillet
Date:   2022-12-03 09:58:24 -08:00 (committed by GitHub)
Parent: 4d64589b22
Commit: 8edfe813a5
12 changed files with 310 additions and 143 deletions
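The user-facing change in the test_dot hunk below is that tl.dot no longer takes trans_a/trans_b keyword arguments; operands are now transposed explicitly with the new tl.trans op before the dot. A minimal sketch of the new pattern, assuming contiguous row-major inputs that fit in a single block (the kernel name, layout, and launch parameters are illustrative, not taken from this commit):

# --- illustrative sketch, not part of the diff ---
import torch
import triton
import triton.language as tl


@triton.jit
def dot_trans_kernel(X, Y, Z,
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    off_m = tl.arange(0, BLOCK_M)
    off_n = tl.arange(0, BLOCK_N)
    off_k = tl.arange(0, BLOCK_K)
    # X is laid out as [BLOCK_K, BLOCK_M], row-major; one program covers the whole tensor
    x = tl.load(X + off_k[:, None] * BLOCK_M + off_m[None, :])
    y = tl.load(Y + off_k[:, None] * BLOCK_N + off_n[None, :])
    # transpose the operand explicitly instead of passing trans_a=True to tl.dot
    z = tl.dot(tl.trans(x), y)   # [BLOCK_M, BLOCK_N], fp32 accumulator
    tl.store(Z + off_m[:, None] * BLOCK_N + off_n[None, :], z)


M, N, K = 64, 64, 64
x = torch.randn((K, M), device='cuda', dtype=torch.float16)
y = torch.randn((K, N), device='cuda', dtype=torch.float16)
z = torch.empty((M, N), device='cuda', dtype=torch.float32)
pgm = dot_trans_kernel[(1,)](x, y, z, BLOCK_M=M, BLOCK_N=N, BLOCK_K=K)
torch.testing.assert_close(z, x.T.float() @ y.float(), rtol=1e-2, atol=1e-2)
ptx = pgm.asm['ptx']   # same handle attribute the test below inspects for ld/st vectorization and mma
# --- end sketch ---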


@@ -667,7 +667,6 @@ def test_tensor_atomic_rmw(shape, axis, device="cuda"):
tl.atomic_add(Z + off1, z)
rs = RandomState(17)
x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
print(x)
# reference result
z_ref = np.sum(x, axis=axis, keepdims=False)
# triton result
@@ -1067,122 +1066,126 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
# # ---------------
# @pytest.mark.parametrize("epilogue, allow_tf32, dtype",
# [(epilogue, allow_tf32, dtype)
# for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
# for allow_tf32 in [True, False]
# for dtype in ['float16']
# if not (allow_tf32 and (dtype in ['float16']))])
# def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
# cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
# if cc < 80:
# if dtype == 'int8':
# pytest.skip("Only test int8 on devices with sm >= 80")
# elif dtype == 'float32' and allow_tf32:
# pytest.skip("Only test tf32 on devices with sm >= 80")
@pytest.mark.parametrize("epilogue, allow_tf32, dtype",
[(epilogue, allow_tf32, dtype)
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
for allow_tf32 in [True, False]
for dtype in ['float16']
if not (allow_tf32 and (dtype in ['float16']))])
def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
capability = torch.cuda.get_device_capability()
    if capability[0] < 8:
if dtype == 'int8':
pytest.skip("Only test int8 on devices with sm >= 80")
elif dtype == 'float32' and allow_tf32:
pytest.skip("Only test tf32 on devices with sm >= 80")
# M, N, K = 128, 128, 64
# num_warps = 8
# trans_a, trans_b = False, False
M, N, K = 64, 64, 64
num_warps = 4
trans_a, trans_b = False, False
# # triton kernel
# @triton.jit
# def kernel(X, stride_xm, stride_xk,
# Y, stride_yk, stride_yn,
# W, stride_wn, stride_wl,
# Z, stride_zm, stride_zn,
# BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
# ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
# ALLOW_TF32: tl.constexpr,
# DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
# TRANS_A: tl.constexpr, TRANS_B: tl.constexpr):
# off_m = tl.arange(0, BLOCK_M)
# off_n = tl.arange(0, BLOCK_N)
# off_l = tl.arange(0, BLOCK_N)
# off_k = tl.arange(0, BLOCK_K)
# Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk
# Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
# Ws = W + off_n[:, None] * stride_wn + off_l[None, :] * stride_wl
# Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
# z = tl.dot(tl.load(Xs), tl.load(Ys), trans_a=TRANS_A, trans_b=TRANS_B, allow_tf32=ALLOW_TF32)
# if ADD_MATRIX:
# z += tl.load(Zs)
# if ADD_ROWS:
# ZRs = Z + off_m * stride_zm
# z += tl.load(ZRs)[:, None]
# if ADD_COLS:
# ZCs = Z + off_n * stride_zn
# z += tl.load(ZCs)[None, :]
# if DO_SOFTMAX:
# max = tl.max(z, 1)
# z = z - max[:, None]
# num = tl.exp(z)
# den = tl.sum(num, 1)
# z = num / den[:, None]
# if CHAIN_DOT:
# # tl.store(Zs, z)
# # tl.debug_barrier()
# z = tl.dot(z.to(tl.float16), tl.load(Ws), trans_a=TRANS_A)
# tl.store(Zs, z)
# # input
# rs = RandomState(17)
# x = numpy_random((K, M) if trans_a else (M, K), dtype_str=dtype, rs=rs) * .1
# y = numpy_random((N, K) if trans_b else (K, N), dtype_str=dtype, rs=rs) * .1
# w = numpy_random((N, N), dtype_str=dtype, rs=rs) * .1
# if allow_tf32:
# x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
# y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
# w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32')
# x_tri = to_triton(x, device=device)
# y_tri = to_triton(y, device=device)
# w_tri = to_triton(w, device=device)
# # triton result
# z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1
# z_tri = to_triton(z, device=device)
# if epilogue == 'trans':
# z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
# pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
# y_tri, y_tri.stride(0), y_tri.stride(1),
# w_tri, w_tri.stride(0), w_tri.stride(1),
# z_tri, z_tri.stride(0), z_tri.stride(1),
# TRANS_A=trans_a, TRANS_B=trans_b,
# BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
# ADD_MATRIX=epilogue == 'add-matrix',
# ADD_ROWS=epilogue == 'add-rows',
# ADD_COLS=epilogue == 'add-cols',
# DO_SOFTMAX=epilogue == 'softmax',
# CHAIN_DOT=epilogue == 'chain-dot',
# ALLOW_TF32=allow_tf32,
# num_warps=num_warps)
# # torch result
# x_ref = x.T if trans_a else x
# y_ref = y.T if trans_b else y
# z_ref = np.matmul(x_ref, y_ref)
# if epilogue == 'add-matrix':
# z_ref += z
# if epilogue == 'add-rows':
# z_ref += z[:, 0][:, None]
# if epilogue == 'add-cols':
# z_ref += z[0, :][None, :]
# if epilogue == 'softmax':
# num = np.exp(z_ref - np.max(z_ref, axis=-1, keepdims=True))
# denom = np.sum(num, axis=-1, keepdims=True)
# z_ref = num / denom
# if epilogue == 'chain-dot':
# z_ref = np.matmul(z_ref.T if trans_a else z_ref, w)
# # compare
# # print(z_ref[:,0], z_tri[:,0])
# np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
# # make sure ld/st are vectorized
# ptx = pgm.asm['ptx']
# assert 'ld.global.v4' in ptx
# assert 'st.global.v4' in ptx
# if allow_tf32:
# assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
# elif dtype == 'float32':
# assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
# elif dtype == 'int8':
# assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
# triton kernel
@triton.jit
def kernel(X, stride_xm, stride_xk,
Y, stride_yk, stride_yn,
W, stride_wn, stride_wl,
Z, stride_zm, stride_zn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
ALLOW_TF32: tl.constexpr,
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
TRANS_A: tl.constexpr, TRANS_B: tl.constexpr):
off_m = tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N)
off_l = tl.arange(0, BLOCK_N)
off_k = tl.arange(0, BLOCK_K)
Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk
Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
Ws = W + off_n[:, None] * stride_wn + off_l[None, :] * stride_wl
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
x = tl.load(Xs)
y = tl.load(Ys)
x = tl.trans(x) if TRANS_A else x
y = tl.trans(y) if TRANS_B else y
z = tl.dot(x, y, allow_tf32=ALLOW_TF32)
if ADD_MATRIX:
z += tl.load(Zs)
if ADD_ROWS:
ZRs = Z + off_m * stride_zm
z += tl.load(ZRs)[:, None]
if ADD_COLS:
ZCs = Z + off_n * stride_zn
z += tl.load(ZCs)[None, :]
if DO_SOFTMAX:
max = tl.max(z, 1)
z = z - max[:, None]
num = tl.exp(z)
den = tl.sum(num, 1)
z = num / den[:, None]
if CHAIN_DOT:
# tl.store(Zs, z)
# tl.debug_barrier()
z = tl.dot(tl.trans(z.to(tl.float16)), tl.load(Ws))
tl.store(Zs, z)
# input
rs = RandomState(17)
x = numpy_random((K, M) if trans_a else (M, K), dtype_str=dtype, rs=rs) * .1
y = numpy_random((N, K) if trans_b else (K, N), dtype_str=dtype, rs=rs) * .1
w = numpy_random((N, N), dtype_str=dtype, rs=rs) * .1
if allow_tf32:
x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32')
x_tri = to_triton(x, device=device)
y_tri = to_triton(y, device=device)
w_tri = to_triton(w, device=device)
# triton result
z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1
z_tri = to_triton(z, device=device)
if epilogue == 'trans':
z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
y_tri, y_tri.stride(0), y_tri.stride(1),
w_tri, w_tri.stride(0), w_tri.stride(1),
z_tri, z_tri.stride(0), z_tri.stride(1),
TRANS_A=trans_a, TRANS_B=trans_b,
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
ADD_MATRIX=epilogue == 'add-matrix',
ADD_ROWS=epilogue == 'add-rows',
ADD_COLS=epilogue == 'add-cols',
DO_SOFTMAX=epilogue == 'softmax',
CHAIN_DOT=epilogue == 'chain-dot',
ALLOW_TF32=allow_tf32,
num_warps=num_warps)
# torch result
x_ref = x.T if trans_a else x
y_ref = y.T if trans_b else y
z_ref = np.matmul(x_ref, y_ref)
if epilogue == 'add-matrix':
z_ref += z
if epilogue == 'add-rows':
z_ref += z[:, 0][:, None]
if epilogue == 'add-cols':
z_ref += z[0, :][None, :]
if epilogue == 'softmax':
num = np.exp(z_ref - np.max(z_ref, axis=-1, keepdims=True))
denom = np.sum(num, axis=-1, keepdims=True)
z_ref = num / denom
if epilogue == 'chain-dot':
z_ref = np.matmul(z_ref.T, w)
# compare
# print(z_ref[:,0], z_tri[:,0])
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
# make sure ld/st are vectorized
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
if allow_tf32:
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
elif dtype == 'float32':
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
elif dtype == 'int8':
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
# def test_dot_without_load():