Added TTGIR kernel

Phil Tillet
2022-12-27 21:49:28 -08:00
parent 0d6e6cf578
commit eefc9d1274
2 changed files with 179 additions and 13 deletions

flash-attention.ttgir

@@ -0,0 +1,157 @@
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}>
#mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1]}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}>
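// Layouts: the #blocked encodings are thread-mapped layouts for global loads/stores,
// #mma is the tensor-core (MMA v2) accumulator layout, and #shared0/#shared1 are
// swizzled shared-memory layouts for the async-copy staging buffers.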
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func public @_fwd_kernel_0d1d2d34d5d6d7d8d9d10c11d12d13d14c15d16d17d18c19d20d21d22c2324d25d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32 {tt.divisibility = 16 : i32}) {
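// Arguments: Q, K, V pointers, sm_scale, the L and m statistics pointers, the output
// pointer, then three strides per tensor and the (Z, H, N_CTX) shape.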
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant dense<1.000000e+00> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
%cst_1 = arith.constant dense<0xFF800000> : tensor<128x128xf32, #mma>
%cst_2 = arith.constant dense<0xFF800000> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
%cst_4 = arith.constant dense<0.000000e+00> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%c1_i32 = arith.constant 1 : i32
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
%c128_i32 = arith.constant 128 : i32
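// program_id(0) selects the 128-row query block; program_id(1) selects the
// flattened (batch, head) pair used to offset all tensor pointers.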
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = tt.get_program_id {axis = 1 : i32} : i32
%2 = arith.muli %0, %c128_i32 : i32
%3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked0>
%4 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%5 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%6 = tt.splat %2 : (i32) -> tensor<128xi32, #blocked0>
%7 = tt.splat %2 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%8 = tt.splat %2 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%9 = arith.addi %6, %3 : tensor<128xi32, #blocked0>
%10 = arith.muli %1, %arg8 : i32
%11 = arith.addi %7, %4 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%12 = arith.addi %8, %5 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%13 = tt.splat %arg9 : (i32) -> tensor<128x1xi32, #blocked1>
%14 = tt.splat %10 : (i32) -> tensor<128x1xi32, #blocked1>
%15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%16 = tt.expand_dims %15 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
%17 = tt.broadcast %16 : (tensor<1x64xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
%18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%19 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>>
%20 = tt.expand_dims %19 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>>) -> tensor<1x128xi32, #mma>
%21 = tt.splat %arg12 : (i32) -> tensor<1x128xi32, #blocked2>
%22 = tt.splat %10 : (i32) -> tensor<1x128xi32, #blocked2>
%23 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%24 = tt.expand_dims %23 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<64x1xi32, #blocked2>
%25 = tt.expand_dims %18 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi32, #blocked2>
%26 = arith.muli %25, %21 : tensor<1x128xi32, #blocked2>
%27 = arith.addi %22, %26 : tensor<1x128xi32, #blocked2>
%28 = tt.broadcast %27 : (tensor<1x128xi32, #blocked2>) -> tensor<64x128xi32, #blocked2>
%29 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%30 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<64x128x!tt.ptr<f16>, #blocked2>
%31 = tt.splat %arg2 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%32 = tt.expand_dims %11 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1>
%33 = arith.muli %32, %13 : tensor<128x1xi32, #blocked1>
%34 = arith.addi %14, %33 : tensor<128x1xi32, #blocked1>
%35 = tt.broadcast %34 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
%36 = arith.addi %35, %17 : tensor<128x64xi32, #blocked1>
%37 = tt.addptr %29, %36 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%38 = arith.addi %0, %c1_i32 : i32
%39 = arith.muli %38, %c128_i32 : i32
%40 = arith.index_cast %39 : i32 to index
%41 = tt.splat %arg3 : (f32) -> tensor<128x128xf32, #mma>
%42 = tt.expand_dims %12 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xi32, #mma>
%43 = tt.broadcast %42 : (tensor<128x1xi32, #mma>) -> tensor<128x128xi32, #mma>
%44 = arith.muli %arg12, %c128_i32 : i32
%45 = tt.splat %44 : (i32) -> tensor<64x128xi32, #blocked2>
%46 = arith.muli %arg15, %c128_i32 : i32
%47 = tt.splat %46 : (i32) -> tensor<128x64xi32, #blocked1>
%48 = tt.broadcast %24 : (tensor<64x1xi32, #blocked2>) -> tensor<64x128xi32, #blocked2>
%49 = arith.addi %28, %48 : tensor<64x128xi32, #blocked2>
%50 = tt.addptr %30, %49 : tensor<64x128x!tt.ptr<f16>, #blocked2>, tensor<64x128xi32, #blocked2>
%51 = tt.expand_dims %4 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1>
%52 = arith.muli %51, %13 : tensor<128x1xi32, #blocked1>
%53 = arith.addi %14, %52 : tensor<128x1xi32, #blocked1>
%54 = tt.broadcast %53 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
%55 = arith.addi %54, %17 : tensor<128x64xi32, #blocked1>
%56 = tt.addptr %31, %55 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
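// The Q tile is staged into shared memory once, before the loop, since it is
// reused by every iteration.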
%79 = triton_gpu.alloc_tensor : tensor<1x128x64xf16, #shared0>
// TODO: Load should be transformed into `insert_slice_async + extract_slice` at the very end of the optimization pass so it benefits from LICM
%80 = triton_gpu.insert_slice_async %37, %79, %c0_i32 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64x!tt.ptr<f16>, #blocked1> -> tensor<1x128x64xf16, #shared0>
triton_gpu.async_wait {num = 0 : i32}
%81 = tensor.extract_slice %80[0, 0, 0] [1, 128, 64] [1, 1, 1] : tensor<1x128x64xf16, #shared0> to tensor<128x64xf16, #shared0>
%82 = triton_gpu.convert_layout %81 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
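// Loop-carried online-softmax state: %arg23 = l_i (softmax denominator),
// %arg24 = acc (output accumulator), %arg25 = m_i (running row max),
// %arg26 / %arg27 = current K / V tile pointers.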
%57:5 = scf.for %arg22 = %c0 to %40 step %c128 iter_args(%arg23 = %cst_4, %arg24 = %cst_3, %arg25 = %cst_2, %arg26 = %50, %arg27 = %56) -> (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128x64xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<64x128x!tt.ptr<f16>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>) {
%78 = arith.index_cast %arg22 : index to i32
%83 = triton_gpu.alloc_tensor : tensor<1x64x128xf16, #shared1>
%84 = triton_gpu.insert_slice_async %arg26, %83, %c0_i32 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x128x!tt.ptr<f16>, #blocked2> -> tensor<1x64x128xf16, #shared1>
triton_gpu.async_wait {num = 0 : i32}
%85 = tensor.extract_slice %84[0, 0, 0] [1, 64, 128] [1, 1, 1] : tensor<1x64x128xf16, #shared1> to tensor<64x128xf16, #shared1>
%86 = triton_gpu.convert_layout %85 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>
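// qk = q @ k for this tile: (128x64) x (64x128) -> 128x128 scores.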
%87 = tt.dot %82, %86, %cst_0 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> -> tensor<128x128xf32, #mma>
%88 = tt.splat %78 : (i32) -> tensor<1x128xi32, #mma>
%89 = arith.addi %88, %20 : tensor<1x128xi32, #mma>
%90 = tt.broadcast %89 : (tensor<1x128xi32, #mma>) -> tensor<128x128xi32, #mma>
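// Scale by sm_scale and apply the causal mask: cmpi predicate 5 is "sge", so a
// score survives only where query index >= key index; the rest become -inf.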
%91 = arith.mulf %87, %41 : tensor<128x128xf32, #mma>
%92 = "triton_gpu.cmpi"(%43, %90) {predicate = 5 : i64} : (tensor<128x128xi32, #mma>, tensor<128x128xi32, #mma>) -> tensor<128x128xi1, #mma>
%93 = "triton_gpu.select"(%92, %91, %cst_1) : (tensor<128x128xi1, #mma>, tensor<128x128xf32, #mma>, tensor<128x128xf32, #mma>) -> tensor<128x128xf32, #mma>
%94 = tt.reduce %93 {axis = 1 : i32, redOp = 12 : i32} : tensor<128x128xf32, #mma> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%95 = "triton_gpu.cmpf"(%94, %arg25) {predicate = 2 : i64} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%96 = "triton_gpu.select"(%95, %94, %arg25) : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%97 = tt.expand_dims %96 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xf32, #mma>
%98 = tt.broadcast %97 : (tensor<128x1xf32, #mma>) -> tensor<128x128xf32, #mma>
%99 = arith.subf %93, %98 : tensor<128x128xf32, #mma>
%100 = math.exp %99 : tensor<128x128xf32, #mma>
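// l_new = rowsum(exp(qk - m_new)) + exp(m_i - m_new) * l_i (redOp 2 is fadd);
// %106 = 1 / l_new.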
%101 = tt.reduce %100 {axis = 1 : i32, redOp = 2 : i32} : tensor<128x128xf32, #mma> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%102 = arith.subf %arg25, %96 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%103 = math.exp %102 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%104 = arith.mulf %arg23, %103 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%105 = arith.addf %101, %104 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%106 = arith.divf %cst, %105 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
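// Rescale the old accumulator by l_i * exp(m_i - m_new) / l_new and the new
// probabilities by 1 / l_new before the second dot.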
%107 = arith.mulf %104, %106 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%108 = tt.expand_dims %107 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xf32, #mma>
%109 = tt.broadcast %108 : (tensor<128x1xf32, #mma>) -> tensor<128x64xf32, #mma>
%110 = arith.mulf %arg24, %109 : tensor<128x64xf32, #mma>
%111 = tt.expand_dims %106 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xf32, #mma>
%112 = tt.broadcast %111 : (tensor<128x1xf32, #mma>) -> tensor<128x128xf32, #mma>
%113 = arith.mulf %100, %112 : tensor<128x128xf32, #mma>
%114 = arith.truncf %113 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma>
%115 = triton_gpu.convert_layout %114 : (tensor<128x128xf16, #mma>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
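// Stage the V tile through shared memory and accumulate: acc = p @ v + rescaled acc.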
%116 = triton_gpu.alloc_tensor : tensor<1x128x64xf16, #shared0>
%117 = triton_gpu.insert_slice_async %arg27, %116, %c0_i32 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64x!tt.ptr<f16>, #blocked1> -> tensor<1x128x64xf16, #shared0>
triton_gpu.async_wait {num = 0 : i32}
%118 = tensor.extract_slice %117[0, 0, 0] [1, 128, 64] [1, 1, 1] : tensor<1x128x64xf16, #shared0> to tensor<128x64xf16, #shared0>
%119 = triton_gpu.convert_layout %118 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>
%120 = tt.dot %115, %119, %110 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> -> tensor<128x64xf32, #mma>
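// Advance the K and V pointers by one 128-element block along the sequence axis.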
%121 = tt.addptr %arg26, %45 : tensor<64x128x!tt.ptr<f16>, #blocked2>, tensor<64x128xi32, #blocked2>
%122 = tt.addptr %arg27, %47 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
scf.yield %105, %120, %96, %121, %122 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128x64xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<64x128x!tt.ptr<f16>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>
}
%58 = triton_gpu.convert_layout %57#2 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf32, #blocked0>
%60 = triton_gpu.convert_layout %57#0 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf32, #blocked0>
%61 = arith.muli %1, %arg21 : i32
%62 = tt.addptr %arg4, %61 : !tt.ptr<f32>, i32
%63 = tt.splat %62 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #blocked0>
%64 = tt.addptr %63, %9 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
%65 = tt.addptr %arg5, %61 : !tt.ptr<f32>, i32
%66 = tt.splat %65 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #blocked0>
%67 = tt.addptr %66, %9 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
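// Store the per-row softmax statistics l (denominator) and m (max) for the backward pass.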
tt.store %64, %60 : tensor<128xf32, #blocked0>
tt.store %67, %58 : tensor<128xf32, #blocked0>
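// Epilogue: build the output pointers and store the accumulator, truncated to f16.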
%68 = arith.muli %1, %arg17 : i32
%69 = tt.splat %arg18 : (i32) -> tensor<128x1xi32, #blocked1>
%70 = tt.splat %68 : (i32) -> tensor<128x1xi32, #blocked1>
%71 = tt.splat %arg6 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%72 = arith.muli %32, %69 : tensor<128x1xi32, #blocked1>
%73 = arith.addi %70, %72 : tensor<128x1xi32, #blocked1>
%74 = tt.broadcast %73 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
%75 = arith.addi %74, %17 : tensor<128x64xi32, #blocked1>
%76 = tt.addptr %71, %75 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%77 = arith.truncf %57#1 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma>
// TODO: conversion should be here, not right after the loop
%78 = triton_gpu.convert_layout %77 : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #blocked1>
tt.store %76, %78 : tensor<128x64xf16, #blocked1>
return
}
}


@@ -46,7 +46,6 @@ def _fwd_kernel(
    q = tl.load(q_ptrs)
    # loop over k, v and update accumulator
    for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(k_ptrs)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
@@ -192,6 +191,7 @@ def _bwd_kernel(
    tl.store(dv_ptrs, dv)
    tl.store(dk_ptrs, dk)
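# Compile the hand-written TTGIR forward kernel ahead of time; it replaces the
# @triton.jit version of _fwd_kernel above.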
_fwd_kernel = triton.compile("./flash-attention.ttgir", num_warps=4)
empty = torch.empty(128, device="cuda")
@@ -210,19 +210,28 @@ class _attention(torch.autograd.Function):
        m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        num_warps = 4 if Lk <= 64 else 8
        # _fwd_kernel[grid](
        #     q, k, v, sm_scale,
        #     L, m,
        #     o,
        #     q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        #     k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        #     v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        #     o.stride(0), o.stride(1), o.stride(2), o.stride(3),
        #     q.shape[0], q.shape[1], q.shape[2],
        #     BLOCK_M=BLOCK, BLOCK_N=BLOCK,
        #     BLOCK_DMODEL=Lk, num_warps=num_warps,
        #     num_stages=1,
        # )
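        # The precompiled TTGIR kernel takes raw device pointers and integer strides;
        # block sizes and num_warps are baked into the IR, so no meta-parameters are passed.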
        _fwd_kernel[grid](
            q.data_ptr(), k.data_ptr(), v.data_ptr(), sm_scale,
            L.data_ptr(), m.data_ptr(),
            o.data_ptr(),
            q.stride(0), q.stride(1), q.stride(2),
            k.stride(0), k.stride(1), k.stride(2),
            v.stride(0), v.stride(1), v.stride(2),
            o.stride(0), o.stride(1), o.stride(2),
            q.shape[0], q.shape[1], q.shape[2])
        ctx.save_for_backward(q, k, v, o, L, m)
        ctx.BLOCK = BLOCK