diff --git a/python/flash-attention.ttgir b/python/flash-attention.ttgir new file mode 100644 index 000000000..7ce40fbc7 --- /dev/null +++ b/python/flash-attention.ttgir @@ -0,0 +1,157 @@ +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1]}> +#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}> +module attributes {"triton_gpu.num-warps" = 4 : i32} { + func public @_fwd_kernel_0d1d2d34d5d6d7d8d9d10c11d12d13d14c15d16d17d18c19d20d21d22c2324d25d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32 {tt.divisibility = 16 : i32}) { + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<1.000000e+00> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %cst_1 = arith.constant dense<0xFF800000> : tensor<128x128xf32, #mma> + %cst_2 = arith.constant dense<0xFF800000> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %cst_4 = arith.constant dense<0.000000e+00> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %c1_i32 = arith.constant 1 : i32 + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c128_i32 = arith.constant 128 : i32 + %0 = tt.get_program_id {axis = 0 : i32} : i32 + %1 = tt.get_program_id {axis = 1 : i32} : i32 + %2 = arith.muli %0, %c128_i32 : i32 + %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked0> + %4 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %5 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %6 = tt.splat %2 : (i32) -> tensor<128xi32, #blocked0> + %7 = tt.splat %2 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %8 = tt.splat %2 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %9 = arith.addi %6, %3 : tensor<128xi32, #blocked0> + %10 = arith.muli %1, %arg8 : i32 + %11 = arith.addi %7, %4 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %12 = arith.addi %8, %5 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %13 = tt.splat %arg9 : (i32) -> tensor<128x1xi32, #blocked1> + %14 = tt.splat %10 : (i32) -> tensor<128x1xi32, #blocked1> + %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %16 = tt.expand_dims %15 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1> + %17 = tt.broadcast %16 : (tensor<1x64xi32, #blocked1>) -> tensor<128x64xi32, #blocked1> + %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %19 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %20 = tt.expand_dims %19 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>>) -> tensor<1x128xi32, #mma> + %21 = tt.splat %arg12 : (i32) -> tensor<1x128xi32, #blocked2> + %22 = tt.splat %10 : (i32) -> tensor<1x128xi32, #blocked2> + %23 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %24 = tt.expand_dims %23 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<64x1xi32, #blocked2> + %25 = tt.expand_dims %18 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi32, #blocked2> + %26 = arith.muli %25, %21 : tensor<1x128xi32, #blocked2> + %27 = arith.addi %22, %26 : tensor<1x128xi32, #blocked2> + %28 = tt.broadcast %27 : (tensor<1x128xi32, #blocked2>) -> tensor<64x128xi32, #blocked2> + %29 = tt.splat %arg0 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked1> + %30 = tt.splat %arg1 : (!tt.ptr) -> tensor<64x128x!tt.ptr, #blocked2> + %31 = tt.splat %arg2 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked1> + %32 = tt.expand_dims %11 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1> + %33 = arith.muli %32, %13 : tensor<128x1xi32, #blocked1> + %34 = arith.addi %14, %33 : tensor<128x1xi32, #blocked1> + %35 = tt.broadcast %34 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1> + %36 = arith.addi %35, %17 : tensor<128x64xi32, #blocked1> + %37 = tt.addptr %29, %36 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %38 = arith.addi %0, %c1_i32 : i32 + %39 = arith.muli %38, %c128_i32 : i32 + %40 = arith.index_cast %39 : i32 to index + %41 = tt.splat %arg3 : (f32) -> tensor<128x128xf32, #mma> + %42 = tt.expand_dims %12 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xi32, #mma> + %43 = tt.broadcast %42 : (tensor<128x1xi32, #mma>) -> tensor<128x128xi32, #mma> + %44 = arith.muli %arg12, %c128_i32 : i32 + %45 = tt.splat %44 : (i32) -> tensor<64x128xi32, #blocked2> + %46 = arith.muli %arg15, %c128_i32 : i32 + %47 = tt.splat %46 : (i32) -> tensor<128x64xi32, #blocked1> + %48 = tt.broadcast %24 : (tensor<64x1xi32, #blocked2>) -> tensor<64x128xi32, #blocked2> + %49 = arith.addi %28, %48 : tensor<64x128xi32, #blocked2> + %50 = tt.addptr %30, %49 : tensor<64x128x!tt.ptr, #blocked2>, tensor<64x128xi32, #blocked2> + %51 = tt.expand_dims %4 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1> + %52 = arith.muli %51, %13 : tensor<128x1xi32, #blocked1> + %53 = arith.addi %14, %52 : tensor<128x1xi32, #blocked1> + %54 = tt.broadcast %53 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1> + %55 = arith.addi %54, %17 : tensor<128x64xi32, #blocked1> + %56 = tt.addptr %31, %55 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %79 = triton_gpu.alloc_tensor : tensor<1x128x64xf16, #shared0> + + // TODO: Load should be transformed into `insert_slice_async + extract_slice` at the very end of the optimization pass so it benefits from LICM + %80 = triton_gpu.insert_slice_async %37, %79, %c0_i32 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64x!tt.ptr, #blocked1> -> tensor<1x128x64xf16, #shared0> + triton_gpu.async_wait {num = 0 : i32} + %81 = tensor.extract_slice %80[0, 0, 0] [1, 128, 64] [1, 1, 1] : tensor<1x128x64xf16, #shared0> to tensor<128x64xf16, #shared0> + %82 = triton_gpu.convert_layout %81 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> + + %57:5 = scf.for %arg22 = %c0 to %40 step %c128 iter_args(%arg23 = %cst_4, %arg24 = %cst_3, %arg25 = %cst_2, %arg26 = %50, %arg27 = %56) -> (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128x64xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<64x128x!tt.ptr, #blocked2>, tensor<128x64x!tt.ptr, #blocked1>) { + %78 = arith.index_cast %arg22 : index to i32 + %83 = triton_gpu.alloc_tensor : tensor<1x64x128xf16, #shared1> + %84 = triton_gpu.insert_slice_async %arg26, %83, %c0_i32 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x128x!tt.ptr, #blocked2> -> tensor<1x64x128xf16, #shared1> + triton_gpu.async_wait {num = 0 : i32} + %85 = tensor.extract_slice %84[0, 0, 0] [1, 64, 128] [1, 1, 1] : tensor<1x64x128xf16, #shared1> to tensor<64x128xf16, #shared1> + %86 = triton_gpu.convert_layout %85 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> + %87 = tt.dot %82, %86, %cst_0 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> -> tensor<128x128xf32, #mma> + %88 = tt.splat %78 : (i32) -> tensor<1x128xi32, #mma> + %89 = arith.addi %88, %20 : tensor<1x128xi32, #mma> + %90 = tt.broadcast %89 : (tensor<1x128xi32, #mma>) -> tensor<128x128xi32, #mma> + %91 = arith.mulf %87, %41 : tensor<128x128xf32, #mma> + %92 = "triton_gpu.cmpi"(%43, %90) {predicate = 5 : i64} : (tensor<128x128xi32, #mma>, tensor<128x128xi32, #mma>) -> tensor<128x128xi1, #mma> + %93 = "triton_gpu.select"(%92, %91, %cst_1) : (tensor<128x128xi1, #mma>, tensor<128x128xf32, #mma>, tensor<128x128xf32, #mma>) -> tensor<128x128xf32, #mma> + %94 = tt.reduce %93 {axis = 1 : i32, redOp = 12 : i32} : tensor<128x128xf32, #mma> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %95 = "triton_gpu.cmpf"(%94, %arg25) {predicate = 2 : i64} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %96 = "triton_gpu.select"(%95, %94, %arg25) : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %97 = tt.expand_dims %96 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xf32, #mma> + %98 = tt.broadcast %97 : (tensor<128x1xf32, #mma>) -> tensor<128x128xf32, #mma> + %99 = arith.subf %93, %98 : tensor<128x128xf32, #mma> + %100 = math.exp %99 : tensor<128x128xf32, #mma> + %101 = tt.reduce %100 {axis = 1 : i32, redOp = 2 : i32} : tensor<128x128xf32, #mma> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %102 = arith.subf %arg25, %96 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %103 = math.exp %102 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %104 = arith.mulf %arg23, %103 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %105 = arith.addf %101, %104 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %106 = arith.divf %cst, %105 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %107 = arith.mulf %104, %106 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %108 = tt.expand_dims %107 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xf32, #mma> + %109 = tt.broadcast %108 : (tensor<128x1xf32, #mma>) -> tensor<128x64xf32, #mma> + %110 = arith.mulf %arg24, %109 : tensor<128x64xf32, #mma> + %111 = tt.expand_dims %106 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xf32, #mma> + %112 = tt.broadcast %111 : (tensor<128x1xf32, #mma>) -> tensor<128x128xf32, #mma> + %113 = arith.mulf %100, %112 : tensor<128x128xf32, #mma> + %114 = arith.truncf %113 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> + %115 = triton_gpu.convert_layout %114 : (tensor<128x128xf16, #mma>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> + %116 = triton_gpu.alloc_tensor : tensor<1x128x64xf16, #shared0> + %117 = triton_gpu.insert_slice_async %arg27, %116, %c0_i32 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64x!tt.ptr, #blocked1> -> tensor<1x128x64xf16, #shared0> + triton_gpu.async_wait {num = 0 : i32} + %118 = tensor.extract_slice %117[0, 0, 0] [1, 128, 64] [1, 1, 1] : tensor<1x128x64xf16, #shared0> to tensor<128x64xf16, #shared0> + %119 = triton_gpu.convert_layout %118 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> + %120 = tt.dot %115, %119, %110 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> -> tensor<128x64xf32, #mma> + %121 = tt.addptr %arg26, %45 : tensor<64x128x!tt.ptr, #blocked2>, tensor<64x128xi32, #blocked2> + %122 = tt.addptr %arg27, %47 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + scf.yield %105, %120, %96, %121, %122 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128x64xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<64x128x!tt.ptr, #blocked2>, tensor<128x64x!tt.ptr, #blocked1> + } + %58 = triton_gpu.convert_layout %57#2 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf32, #blocked0> + %60 = triton_gpu.convert_layout %57#0 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf32, #blocked0> + %61 = arith.muli %1, %arg21 : i32 + %62 = tt.addptr %arg4, %61 : !tt.ptr, i32 + %63 = tt.splat %62 : (!tt.ptr) -> tensor<128x!tt.ptr, #blocked0> + %64 = tt.addptr %63, %9 : tensor<128x!tt.ptr, #blocked0>, tensor<128xi32, #blocked0> + %65 = tt.addptr %arg5, %61 : !tt.ptr, i32 + %66 = tt.splat %65 : (!tt.ptr) -> tensor<128x!tt.ptr, #blocked0> + %67 = tt.addptr %66, %9 : tensor<128x!tt.ptr, #blocked0>, tensor<128xi32, #blocked0> + tt.store %64, %60 : tensor<128xf32, #blocked0> + tt.store %67, %58 : tensor<128xf32, #blocked0> + %68 = arith.muli %1, %arg17 : i32 + %69 = tt.splat %arg18 : (i32) -> tensor<128x1xi32, #blocked1> + %70 = tt.splat %68 : (i32) -> tensor<128x1xi32, #blocked1> + %71 = tt.splat %arg6 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked1> + %72 = arith.muli %32, %69 : tensor<128x1xi32, #blocked1> + %73 = arith.addi %70, %72 : tensor<128x1xi32, #blocked1> + %74 = tt.broadcast %73 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1> + %75 = arith.addi %74, %17 : tensor<128x64xi32, #blocked1> + %76 = tt.addptr %71, %75 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %77 = arith.truncf %57#1 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> + // TODO: conversion should be here, not right after the loop + %78 = triton_gpu.convert_layout %77 : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #blocked1> + tt.store %76, %78 : tensor<128x64xf16, #blocked1> + return + } +} diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index fa8562166..d10932959 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -46,7 +46,6 @@ def _fwd_kernel( q = tl.load(q_ptrs) # loop over k, v and update accumulator for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- k = tl.load(k_ptrs) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) @@ -192,6 +191,7 @@ def _bwd_kernel( tl.store(dv_ptrs, dv) tl.store(dk_ptrs, dk) +_fwd_kernel = triton.compile("./flash-attention.ttgir", num_warps=4) empty = torch.empty(128, device="cuda") @@ -210,19 +210,28 @@ class _attention(torch.autograd.Function): m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) num_warps = 4 if Lk <= 64 else 8 + # _fwd_kernel[grid]( + # q, k, v, sm_scale, + # L, m, + # o, + # q.stride(0), q.stride(1), q.stride(2), q.stride(3), + # k.stride(0), k.stride(1), k.stride(2), k.stride(3), + # v.stride(0), v.stride(1), v.stride(2), v.stride(3), + # o.stride(0), o.stride(1), o.stride(2), o.stride(3), + # q.shape[0], q.shape[1], q.shape[2], + # BLOCK_M=BLOCK, BLOCK_N=BLOCK, + # BLOCK_DMODEL=Lk, num_warps=num_warps, + # num_stages=1, + # ) _fwd_kernel[grid]( - q, k, v, sm_scale, - L, m, - o, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - o.stride(0), o.stride(1), o.stride(2), o.stride(3), - q.shape[0], q.shape[1], q.shape[2], - BLOCK_M=BLOCK, BLOCK_N=BLOCK, - BLOCK_DMODEL=Lk, num_warps=num_warps, - num_stages=1, - ) + q.data_ptr(), k.data_ptr(), v.data_ptr(), sm_scale, + L.data_ptr(), m.data_ptr(), + o.data_ptr(), + q.stride(0), q.stride(1), q.stride(2), + k.stride(0), k.stride(1), k.stride(2), + v.stride(0), v.stride(1), v.stride(2), + o.stride(0), o.stride(1), o.stride(2), + q.shape[0], q.shape[1], q.shape[2]) ctx.save_for_backward(q, k, v, o, L, m) ctx.BLOCK = BLOCK