diff --git a/python/chain-dot.ttgir b/python/chain-dot.ttgir new file mode 100644 index 000000000..8f14ad4b2 --- /dev/null +++ b/python/chain-dot.ttgir @@ -0,0 +1,55 @@ +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}> +#mma0 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1]}> +#mma1 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2]}> +#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}> +module attributes {"triton_gpu.num-warps" = 4 : i32} { + func public @kernel_0d1d2c3d4d5c6d7d8c9d10d11c(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma0> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma1> + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>> + %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<64x1xi32, #blocked0> + %2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked0> + %3 = tt.splat %arg0 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked0> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x64xi32, #blocked0> + %6 = tt.broadcast %5 : (tensor<1x64xi32, #blocked0>) -> tensor<64x64xi32, #blocked0> + %7 = tt.splat %arg3 : (i32) -> tensor<64x1xi32, #blocked0> + %8 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked0> + %9 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked0> + %10 = tt.splat %arg4 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked0> + %11 = tt.splat %arg7 : (i32) -> tensor<64x1xi32, #blocked0> + %12 = tt.splat %arg6 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked0> + %13 = arith.muli %1, %2 : tensor<64x1xi32, #blocked0> + %14 = tt.addptr %3, %13 : tensor<64x1x!tt.ptr, #blocked0>, tensor<64x1xi32, #blocked0> + %15 = tt.broadcast %14 : (tensor<64x1x!tt.ptr, #blocked0>) -> tensor<64x64x!tt.ptr, #blocked0> + %16 = tt.addptr %15, %6 : tensor<64x64x!tt.ptr, #blocked0>, tensor<64x64xi32, #blocked0> + %17 = tt.load %16 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf16, #blocked0> + %18 = arith.muli %1, %7 : tensor<64x1xi32, #blocked0> + %19 = tt.addptr %8, %18 : tensor<64x1x!tt.ptr, #blocked0>, tensor<64x1xi32, #blocked0> + %20 = tt.broadcast %19 : (tensor<64x1x!tt.ptr, #blocked0>) -> tensor<64x64x!tt.ptr, #blocked0> + %21 = tt.addptr %20, %6 : tensor<64x64x!tt.ptr, #blocked0>, tensor<64x64xi32, #blocked0> + %22 = tt.load %21 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf16, #blocked0> + %23 = triton_gpu.convert_layout %17 : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> + %24 = triton_gpu.convert_layout %22 : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> + %25 = tt.dot %23, %24, %cst {allowTF32 = false} : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<64x64xf32, #mma0> + %27 = arith.muli %1, %9 : tensor<64x1xi32, #blocked0> + %28 = tt.addptr %10, %27 : tensor<64x1x!tt.ptr, #blocked0>, tensor<64x1xi32, #blocked0> + %29 = tt.broadcast %28 : (tensor<64x1x!tt.ptr, #blocked0>) -> tensor<64x64x!tt.ptr, #blocked0> + %30 = tt.addptr %29, %6 : tensor<64x64x!tt.ptr, #blocked0>, tensor<64x64xi32, #blocked0> + %31 = tt.load %30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf16, #blocked0> + %32 = arith.truncf %25 : tensor<64x64xf32, #mma0> to tensor<64x64xf16, #mma0> + %133 = triton_gpu.convert_layout %32 : (tensor<64x64xf16, #mma0>) -> tensor<64x64xf16, #shared> + %33 = triton_gpu.convert_layout %133 : (tensor<64x64xf16, #shared>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> + %34 = triton_gpu.convert_layout %31 : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> + %35 = tt.dot %33, %34, %cst_0 {allowTF32 = true} : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<64x64xf32, #mma1> + %36 = triton_gpu.convert_layout %35 : (tensor<64x64xf32, #mma1>) -> tensor<64x64xf32, #blocked0> + %37 = arith.muli %1, %11 : tensor<64x1xi32, #blocked0> + %38 = tt.addptr %12, %37 : tensor<64x1x!tt.ptr, #blocked0>, tensor<64x1xi32, #blocked0> + %39 = tt.broadcast %38 : (tensor<64x1x!tt.ptr, #blocked0>) -> tensor<64x64x!tt.ptr, #blocked0> + %40 = tt.addptr %39, %6 : tensor<64x64x!tt.ptr, #blocked0>, tensor<64x64xi32, #blocked0> + %41 = arith.truncf %36 : tensor<64x64xf32, #blocked0> to tensor<64x64xf16, #blocked0> + tt.store %40, %41 : tensor<64x64xf16, #blocked0> + return + } +} \ No newline at end of file diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 862687f6f..359b847e7 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1177,19 +1177,25 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi z_tri = to_triton(z, device=device) if epilogue == 'trans': z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1]) - pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), - y_tri, y_tri.stride(0), y_tri.stride(1), - w_tri, w_tri.stride(0), w_tri.stride(1), - z_tri, z_tri.stride(0), z_tri.stride(1), - COL_A=col_a, COL_B=col_b, - BLOCK_M=M, BLOCK_K=K, BLOCK_N=N, - ADD_MATRIX=epilogue == 'add-matrix', - ADD_ROWS=epilogue == 'add-rows', - ADD_COLS=epilogue == 'add-cols', - DO_SOFTMAX=epilogue == 'softmax', - CHAIN_DOT=epilogue == 'chain-dot', - ALLOW_TF32=allow_tf32, - num_warps=num_warps) + # pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), + # y_tri, y_tri.stride(0), y_tri.stride(1), + # w_tri, w_tri.stride(0), w_tri.stride(1), + # z_tri, z_tri.stride(0), z_tri.stride(1), + # COL_A=col_a, COL_B=col_b, + # BLOCK_M=M, BLOCK_K=K, BLOCK_N=N, + # ADD_MATRIX=epilogue == 'add-matrix', + # ADD_ROWS=epilogue == 'add-rows', + # ADD_COLS=epilogue == 'add-cols', + # DO_SOFTMAX=epilogue == 'softmax', + # CHAIN_DOT=epilogue == 'chain-dot', + # ALLOW_TF32=allow_tf32, + # num_warps=num_warps) + kernel = triton.compile("./chain-dot.ttgir", num_warps=num_warps) + pgm = kernel[(1, 1, 1)](x_tri.data_ptr(), x_tri.stride(0), + y_tri.data_ptr(), y_tri.stride(0), + w_tri.data_ptr(), w_tri.stride(0), + z_tri.data_ptr(), z_tri.stride(0)) + # torch result if dtype == 'int8': z_ref = np.matmul(x.astype(np.float32), @@ -1217,15 +1223,15 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi else: np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) # make sure ld/st are vectorized - ptx = pgm.asm['ptx'] - assert 'ld.global.v4' in ptx - assert 'st.global.v4' in ptx - if dtype == 'float32' and allow_tf32: - assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx - elif dtype == 'float32' and allow_tf32: - assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx - elif dtype == 'int8': - assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx + # ptx = pgm.asm['ptx'] + # assert 'ld.global.v4' in ptx + # assert 'st.global.v4' in ptx + # if dtype == 'float32' and allow_tf32: + # assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx + # elif dtype == 'float32' and allow_tf32: + # assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx + # elif dtype == 'int8': + # assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx def test_dot_without_load(): diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 96ebfc3cf..810aff478 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -1469,9 +1469,7 @@ def compile(fn, **kwargs): import re match = re.search(prototype_pattern[ir], src, re.MULTILINE) name, signature = match.group(1), match.group(2) - print(name, signature) types = re.findall(arg_type_pattern[ir], signature) - print(types) param_tys = [convert_type_repr(ty) for ty in types] signature = {k: v for k, v in enumerate(param_tys)} first_stage = list(stages.keys()).index(ir)