Repro swizzling bug

Phil Tillet
2023-01-02 23:44:25 -08:00
parent 0e8590f1c9
commit 08366b2d59
3 changed files with 83 additions and 24 deletions

python/chain-dot.ttgir (new file, 55 lines)

@@ -0,0 +1,55 @@
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}>
#mma0 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1]}>
#mma1 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
  func public @kernel_0d1d2c3d4d5c6d7d8c9d10d11c(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}) {
    %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma0>
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma1>
    %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
    %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<64x1xi32, #blocked0>
    %2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked0>
    %3 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked0>
    %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
    %5 = tt.expand_dims %4 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x64xi32, #blocked0>
    %6 = tt.broadcast %5 : (tensor<1x64xi32, #blocked0>) -> tensor<64x64xi32, #blocked0>
    %7 = tt.splat %arg3 : (i32) -> tensor<64x1xi32, #blocked0>
    %8 = tt.splat %arg2 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked0>
    %9 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked0>
    %10 = tt.splat %arg4 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked0>
    %11 = tt.splat %arg7 : (i32) -> tensor<64x1xi32, #blocked0>
    %12 = tt.splat %arg6 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked0>
    %13 = arith.muli %1, %2 : tensor<64x1xi32, #blocked0>
    %14 = tt.addptr %3, %13 : tensor<64x1x!tt.ptr<f16>, #blocked0>, tensor<64x1xi32, #blocked0>
    %15 = tt.broadcast %14 : (tensor<64x1x!tt.ptr<f16>, #blocked0>) -> tensor<64x64x!tt.ptr<f16>, #blocked0>
    %16 = tt.addptr %15, %6 : tensor<64x64x!tt.ptr<f16>, #blocked0>, tensor<64x64xi32, #blocked0>
    %17 = tt.load %16 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf16, #blocked0>
    %18 = arith.muli %1, %7 : tensor<64x1xi32, #blocked0>
    %19 = tt.addptr %8, %18 : tensor<64x1x!tt.ptr<f16>, #blocked0>, tensor<64x1xi32, #blocked0>
    %20 = tt.broadcast %19 : (tensor<64x1x!tt.ptr<f16>, #blocked0>) -> tensor<64x64x!tt.ptr<f16>, #blocked0>
    %21 = tt.addptr %20, %6 : tensor<64x64x!tt.ptr<f16>, #blocked0>, tensor<64x64xi32, #blocked0>
    %22 = tt.load %21 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf16, #blocked0>
    %23 = triton_gpu.convert_layout %17 : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
    %24 = triton_gpu.convert_layout %22 : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
    %25 = tt.dot %23, %24, %cst {allowTF32 = false} : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<64x64xf32, #mma0>
    %27 = arith.muli %1, %9 : tensor<64x1xi32, #blocked0>
    %28 = tt.addptr %10, %27 : tensor<64x1x!tt.ptr<f16>, #blocked0>, tensor<64x1xi32, #blocked0>
    %29 = tt.broadcast %28 : (tensor<64x1x!tt.ptr<f16>, #blocked0>) -> tensor<64x64x!tt.ptr<f16>, #blocked0>
    %30 = tt.addptr %29, %6 : tensor<64x64x!tt.ptr<f16>, #blocked0>, tensor<64x64xi32, #blocked0>
    %31 = tt.load %30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf16, #blocked0>
    %32 = arith.truncf %25 : tensor<64x64xf32, #mma0> to tensor<64x64xf16, #mma0>
    %133 = triton_gpu.convert_layout %32 : (tensor<64x64xf16, #mma0>) -> tensor<64x64xf16, #shared>
    %33 = triton_gpu.convert_layout %133 : (tensor<64x64xf16, #shared>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
    %34 = triton_gpu.convert_layout %31 : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
    %35 = tt.dot %33, %34, %cst_0 {allowTF32 = true} : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<64x64xf32, #mma1>
    %36 = triton_gpu.convert_layout %35 : (tensor<64x64xf32, #mma1>) -> tensor<64x64xf32, #blocked0>
    %37 = arith.muli %1, %11 : tensor<64x1xi32, #blocked0>
    %38 = tt.addptr %12, %37 : tensor<64x1x!tt.ptr<f16>, #blocked0>, tensor<64x1xi32, #blocked0>
    %39 = tt.broadcast %38 : (tensor<64x1x!tt.ptr<f16>, #blocked0>) -> tensor<64x64x!tt.ptr<f16>, #blocked0>
    %40 = tt.addptr %39, %6 : tensor<64x64x!tt.ptr<f16>, #blocked0>, tensor<64x64xi32, #blocked0>
    %41 = arith.truncf %36 : tensor<64x64xf32, #blocked0> to tensor<64x64xf16, #blocked0>
    tt.store %40, %41 : tensor<64x64xf16, #blocked0>
    return
  }
}
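
The test change below exercises this file by compiling it directly with triton.compile, bypassing the @triton.jit frontend. A minimal standalone sketch of that repro path, assuming a CUDA-capable GPU, PyTorch, and this chain-dot.ttgir in the working directory (tensor names, data and the final checks are illustrative, not part of the commit):

import torch
import triton

# 64x64 fp16 operands matching the tensor shapes hard-coded in the TTGIR above.
x = torch.randn((64, 64), dtype=torch.float16, device='cuda')
y = torch.randn((64, 64), dtype=torch.float16, device='cuda')
w = torch.randn((64, 64), dtype=torch.float16, device='cuda')
z = torch.empty((64, 64), dtype=torch.float16, device='cuda')

# Compile the hand-written TTGIR directly; the kernel prototype expects four
# (pointer, row-stride) pairs: A, B, W, and the output.
kernel = triton.compile("./chain-dot.ttgir", num_warps=4)
kernel[(1, 1, 1)](x.data_ptr(), x.stride(0),
                  y.data_ptr(), y.stride(0),
                  w.data_ptr(), w.stride(0),
                  z.data_ptr(), z.stride(0))

# The TTGIR computes dot(dot(x, y), w) with an fp16 truncation in between,
# so a chained matmul serves as the reference.
ref = torch.matmul(torch.matmul(x, y), w)
print((z - ref).abs().max())

# The PTX assertions disabled in the test below could be re-checked by hand,
# assuming the compiled object exposes the same asm dict as the frontend path:
print('ld.global.v4' in kernel.asm['ptx'], 'st.global.v4' in kernel.asm['ptx'])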

@@ -1177,19 +1177,25 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
    z_tri = to_triton(z, device=device)
    if epilogue == 'trans':
        z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
    pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
                         y_tri, y_tri.stride(0), y_tri.stride(1),
                         w_tri, w_tri.stride(0), w_tri.stride(1),
                         z_tri, z_tri.stride(0), z_tri.stride(1),
                         COL_A=col_a, COL_B=col_b,
                         BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
                         ADD_MATRIX=epilogue == 'add-matrix',
                         ADD_ROWS=epilogue == 'add-rows',
                         ADD_COLS=epilogue == 'add-cols',
                         DO_SOFTMAX=epilogue == 'softmax',
                         CHAIN_DOT=epilogue == 'chain-dot',
                         ALLOW_TF32=allow_tf32,
                         num_warps=num_warps)
    # pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
    #                      y_tri, y_tri.stride(0), y_tri.stride(1),
    #                      w_tri, w_tri.stride(0), w_tri.stride(1),
    #                      z_tri, z_tri.stride(0), z_tri.stride(1),
    #                      COL_A=col_a, COL_B=col_b,
    #                      BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
    #                      ADD_MATRIX=epilogue == 'add-matrix',
    #                      ADD_ROWS=epilogue == 'add-rows',
    #                      ADD_COLS=epilogue == 'add-cols',
    #                      DO_SOFTMAX=epilogue == 'softmax',
    #                      CHAIN_DOT=epilogue == 'chain-dot',
    #                      ALLOW_TF32=allow_tf32,
    #                      num_warps=num_warps)
    kernel = triton.compile("./chain-dot.ttgir", num_warps=num_warps)
    pgm = kernel[(1, 1, 1)](x_tri.data_ptr(), x_tri.stride(0),
                            y_tri.data_ptr(), y_tri.stride(0),
                            w_tri.data_ptr(), w_tri.stride(0),
                            z_tri.data_ptr(), z_tri.stride(0))
    # torch result
    if dtype == 'int8':
        z_ref = np.matmul(x.astype(np.float32),
@@ -1217,15 +1223,15 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
    else:
        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
    # make sure ld/st are vectorized
    ptx = pgm.asm['ptx']
    assert 'ld.global.v4' in ptx
    assert 'st.global.v4' in ptx
    if dtype == 'float32' and allow_tf32:
        assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
    elif dtype == 'float32' and allow_tf32:
        assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
    elif dtype == 'int8':
        assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
    # ptx = pgm.asm['ptx']
    # assert 'ld.global.v4' in ptx
    # assert 'st.global.v4' in ptx
    # if dtype == 'float32' and allow_tf32:
    #     assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
    # elif dtype == 'float32' and allow_tf32:
    #     assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
    # elif dtype == 'int8':
    #     assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx


def test_dot_without_load():

@@ -1469,9 +1469,7 @@ def compile(fn, **kwargs):
        import re
        match = re.search(prototype_pattern[ir], src, re.MULTILINE)
        name, signature = match.group(1), match.group(2)
        print(name, signature)
        types = re.findall(arg_type_pattern[ir], signature)
        print(types)
        param_tys = [convert_type_repr(ty) for ty in types]
        signature = {k: v for k, v in enumerate(param_tys)}
        first_stage = list(stages.keys()).index(ir)
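
For context, this is where compile() recovers the kernel name and argument signature from the IR text when it is given a file path rather than a JITFunction; the commit only drops the two debug print calls. A rough, illustrative approximation of that parsing step (the real prototype_pattern, arg_type_pattern and convert_type_repr live in compiler.py and may differ in detail):

import re

# First line of the hand-written TTGIR above, truncated here to two arguments.
src = 'func public @kernel_0d1d2c3d4d5c6d7d8c9d10d11c(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}) {'

# Extract the function name and the raw argument list.
match = re.search(r'^func\s+public\s+@(\w+)\((.*)\)\s*\{', src, re.MULTILINE)
name, args = match.group(1), match.group(2)

# Pull each argument's MLIR type and map it to a driver-level spelling
# (pointers become something like '*f16'); this stands in for convert_type_repr.
types = re.findall(r'%\w+:\s*([^\s,{]+)', args)
signature = {i: ('*f16' if t == '!tt.ptr<f16>' else t) for i, t in enumerate(types)}
print(name, signature)  # kernel_0d1d2c3d4d5c6d7d8c9d10d11c {0: '*f16', 1: 'i32'}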