diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index 13d6dac8a..2b3aa239d 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -483,6 +483,7 @@ public: return op->getBlock() == cvt->getBlock() && !(isa(op) && !op->getResult(0).getType().isa()) && + !isa(op) && !isa(op); }; mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter); diff --git a/python/bwd.ttgir b/python/bwd.ttgir index ffad904f2..874e9711c 100644 --- a/python/bwd.ttgir +++ b/python/bwd.ttgir @@ -1,16 +1,20 @@ -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> -#mma0 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2]}> -#mma1 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1]}> +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#mma0 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1]}> +#mma1 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2]}> #shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> -#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}> +// TODO: swizzle +#shared1 = #triton_gpu.shared<{vec = 2, perPhase = 1, maxPhase = 1, order = [1, 0]}> + module attributes {"triton_gpu.num-warps" = 8 : i32} { func public @_bwd_kernel_0d1d2d34d5d6d7d8d9d10d11d12d13d14d15c16d17d18d19c20d21d22d23c2425d26d27(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg8: !tt.ptr {tt.divisibility = 16 : i32}, %arg9: !tt.ptr {tt.divisibility = 16 : i32}, %arg10: !tt.ptr {tt.divisibility = 16 : i32}, %arg11: !tt.ptr {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32) { - %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma1> - %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma0> + %cst = arith.constant dense<0xFF800000> : tensor<128x128xf32, #mma0> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma0> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1> %c128 = arith.constant 128 : index - %c0 = arith.constant 0 : index %c128_i32 = arith.constant 128 : i32 + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index %0 = tt.get_program_id {axis = 0 : i32} : i32 %1 = arith.divsi %0, %arg22 : i32 %2 = arith.remsi %0, %arg22 : i32 @@ -24,88 +28,136 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} { %10 = tt.addptr %arg6, %5 : !tt.ptr, i32 %11 = tt.addptr %arg7, %5 : !tt.ptr, i32 %12 = tt.addptr %arg8, %5 : !tt.ptr, i32 - %13 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>> - %14 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %15 = tt.expand_dims %13 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<128x1xi32, #blocked0> - %16 = tt.expand_dims %14 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1> - %17 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked0> - %18 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked1> - %19 = arith.muli %15, %17 : tensor<128x1xi32, #blocked0> - %20 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>> - %21 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %22 = tt.broadcast %19 : (tensor<128x1xi32, #blocked0>) -> tensor<128x64xi32, #blocked0> - %23 = tt.expand_dims %20 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x64xi32, #blocked0> - %24 = tt.broadcast %23 : (tensor<1x64xi32, #blocked0>) -> tensor<128x64xi32, #blocked0> - %25 = tt.expand_dims %21 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1> - %26 = tt.broadcast %25 : (tensor<1x64xi32, #blocked1>) -> tensor<128x64xi32, #blocked1> - %27 = arith.addi %22, %24 : tensor<128x64xi32, #blocked0> - %28 = tt.splat %6 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked1> - %29 = tt.splat %arg17 : (i32) -> tensor<128x1xi32, #blocked1> - %30 = tt.splat %7 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked1> - %31 = tt.splat %8 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked1> - %32 = tt.splat %9 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked1> - %33 = tt.splat %10 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked0> - %34 = tt.addptr %33, %27 : tensor<128x64x!tt.ptr, #blocked0>, tensor<128x64xi32, #blocked0> - %35 = arith.muli %16, %29 : tensor<128x1xi32, #blocked1> - %36 = tt.broadcast %35 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1> - %37 = arith.addi %36, %26 : tensor<128x64xi32, #blocked1> - %38 = tt.addptr %30, %37 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %39 = tt.load %38 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> - %40 = arith.muli %16, %18 : tensor<128x1xi32, #blocked1> - %41 = tt.broadcast %40 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1> - %42 = arith.addi %41, %26 : tensor<128x64xi32, #blocked1> - %43 = tt.addptr %31, %42 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %44 = tt.load %43 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> - %45 = arith.muli %arg24, %c128_i32 : i32 - %46 = arith.index_cast %45 : i32 to index - %47 = triton_gpu.convert_layout %39 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0> - %48 = tt.trans %47 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1> - %49 = tt.splat %arg3 : (f32) -> tensor<128x128xf32, #mma1> - %50 = triton_gpu.convert_layout %44 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0> - %51 = tt.trans %50 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1> - %52 = arith.muli %arg14, %c128_i32 : i32 - %53 = tt.splat %52 : (i32) -> tensor<128x64xi32, #blocked0> - %54 = tt.splat %52 : (i32) -> tensor<128x64xi32, #blocked1> - %55 = tt.addptr %28, %42 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %56 = tt.addptr %32, %42 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %57 = triton_gpu.convert_layout %48 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> - %58 = triton_gpu.convert_layout %51 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> - %59:5 = scf.for %arg25 = %c0 to %46 step %c128 iter_args(%arg26 = %cst_0, %arg27 = %cst_0, %arg28 = %34, %arg29 = %55, %arg30 = %56) -> (tensor<128x64xf32, #mma0>, tensor<128x64xf32, #mma0>, tensor<128x64x!tt.ptr, #blocked0>, tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64x!tt.ptr, #blocked1>) { - %68 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> - %69 = triton_gpu.convert_layout %68 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> - %70 = tt.dot %69, %57, %cst {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x128xf32, #mma1> - %73 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> - %74 = arith.truncf %70 : tensor<128x128xf32, #mma1> to tensor<128x128xf16, #mma1> - %75 = triton_gpu.convert_layout %74 : (tensor<128x128xf16, #mma1>) -> tensor<128x128xf16, #shared1> - %76 = tt.trans %75 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #shared0> - %77 = triton_gpu.convert_layout %76 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> - %78 = triton_gpu.convert_layout %73 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> - %79 = tt.dot %77, %78, %arg26 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x64xf32, #mma0> - %80 = triton_gpu.convert_layout %73 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> - %81 = tt.dot %80, %58, %cst {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x128xf32, #mma1> - %83 = arith.mulf %70, %81 : tensor<128x128xf32, #mma1> - %84 = arith.mulf %83, %49 : tensor<128x128xf32, #mma1> - %85 = arith.truncf %84 : tensor<128x128xf32, #mma1> to tensor<128x128xf16, #mma1> - %86 = triton_gpu.convert_layout %85 : (tensor<128x128xf16, #mma1>) -> tensor<128x128xf16, #shared1> - %87 = tt.trans %86 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #shared0> - %88 = triton_gpu.convert_layout %87 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> - %89 = triton_gpu.convert_layout %68 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> - %90 = tt.dot %88, %89, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x64xf32, #mma0> - %91 = tt.addptr %arg28, %53 : tensor<128x64x!tt.ptr, #blocked0>, tensor<128x64xi32, #blocked0> - %92 = tt.addptr %arg29, %54 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %93 = tt.addptr %arg30, %54 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - scf.yield %79, %arg27, %arg28, %arg29, %arg30 : tensor<128x64xf32, #mma0>, tensor<128x64xf32, #mma0>, tensor<128x64x!tt.ptr, #blocked0>, tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64x!tt.ptr, #blocked1> + %13 = arith.index_cast %arg24 : i32 to index + %14 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>> + %15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>> + %16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>> + %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> + %19 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> + %20 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked0> + %21 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked1> + %22 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>> + %23 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %24 = tt.expand_dims %22 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x64xi32, #blocked0> + %25 = tt.broadcast %24 : (tensor<1x64xi32, #blocked0>) -> tensor<128x64xi32, #blocked0> + %26 = tt.expand_dims %23 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1> + %27 = tt.broadcast %26 : (tensor<1x64xi32, #blocked1>) -> tensor<128x64xi32, #blocked1> + %28 = tt.splat %6 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked0> + %29 = tt.splat %arg17 : (i32) -> tensor<128x1xi32, #blocked0> + %30 = tt.splat %7 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked0> + %31 = tt.splat %8 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked0> + %32 = tt.splat %9 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked0> + %33 = tt.splat %10 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked1> + %34 = arith.muli %arg24, %c128_i32 : i32 + %35 = arith.index_cast %34 : i32 to index + %36 = tt.splat %arg3 : (f32) -> tensor<128x128xf32, #mma0> + %37 = tt.splat %arg3 : (f32) -> tensor<128x128xf32, #mma0> + %38 = tt.splat %arg3 : (f32) -> tensor<128x128xf32, #mma0> + %39 = arith.muli %arg14, %c128_i32 : i32 + %40 = tt.splat %39 : (i32) -> tensor<128x64xi32, #blocked0> + %41 = tt.splat %39 : (i32) -> tensor<128x64xi32, #blocked1> + %42 = tt.splat %12 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked0> + %43 = tt.splat %11 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked0> + scf.for %arg25 = %c0 to %13 step %c1 { + %44 = arith.index_cast %arg25 : index to i32 + %45 = arith.muli %44, %c128_i32 : i32 + %46 = tt.splat %45 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>> + %47 = tt.splat %45 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>> + %48 = tt.splat %45 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>> + %49 = tt.splat %45 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %50 = arith.addi %46, %14 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>> + %51 = arith.addi %49, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %52 = tt.expand_dims %50 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<128x1xi32, #blocked0> + %53 = tt.expand_dims %51 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1> + %54 = arith.muli %52, %29 : tensor<128x1xi32, #blocked0> + %55 = tt.broadcast %54 : (tensor<128x1xi32, #blocked0>) -> tensor<128x64xi32, #blocked0> + %56 = arith.addi %55, %25 : tensor<128x64xi32, #blocked0> + %57 = tt.addptr %30, %56 : tensor<128x64x!tt.ptr, #blocked0>, tensor<128x64xi32, #blocked0> + %58 = tt.load %57 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked0> + %59 = arith.muli %52, %20 : tensor<128x1xi32, #blocked0> + %60 = tt.broadcast %59 : (tensor<128x1xi32, #blocked0>) -> tensor<128x64xi32, #blocked0> + %61 = arith.addi %60, %25 : tensor<128x64xi32, #blocked0> + %62 = tt.addptr %31, %61 : tensor<128x64x!tt.ptr, #blocked0>, tensor<128x64xi32, #blocked0> + %63 = tt.load %62 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked0> + %64 = arith.index_cast %45 : i32 to index + %65 = triton_gpu.convert_layout %58 : (tensor<128x64xf16, #blocked0>) -> tensor<128x64xf16, #shared0> + %66 = tt.trans %65 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1> + %67 = arith.addi %47, %15 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>> + %68 = tt.expand_dims %67 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>) -> tensor<1x128xi32, #mma0> + %69 = tt.broadcast %68 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0> + %70 = arith.addi %48, %16 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>> + %71 = tt.expand_dims %70 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>) -> tensor<1x128xi32, #mma0> + %72 = tt.broadcast %71 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0> + %73 = triton_gpu.convert_layout %63 : (tensor<128x64xf16, #blocked0>) -> tensor<128x64xf16, #shared0> + %74 = tt.trans %73 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1> + %75 = arith.muli %53, %21 : tensor<128x1xi32, #blocked1> + %76 = tt.broadcast %75 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1> + %77 = arith.addi %76, %27 : tensor<128x64xi32, #blocked1> + %78 = tt.addptr %33, %77 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %79 = tt.addptr %28, %61 : tensor<128x64x!tt.ptr, #blocked0>, tensor<128x64xi32, #blocked0> + %80 = tt.addptr %32, %61 : tensor<128x64x!tt.ptr, #blocked0>, tensor<128x64xi32, #blocked0> + %81:5 = scf.for %arg26 = %64 to %35 step %c128 iter_args(%arg27 = %cst_1, %arg28 = %cst_1, %arg29 = %78, %arg30 = %79, %arg31 = %80) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64x!tt.ptr, #blocked0>, tensor<128x64x!tt.ptr, #blocked0>) { + %88 = arith.index_cast %arg26 : index to i32 + %89 = tt.splat %88 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> + %90 = tt.splat %88 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> + %91 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked0> + %92 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #blocked0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> + %93 = triton_gpu.convert_layout %66 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> + %94 = tt.dot %92, %93, %cst_0 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0> + %95 = arith.addi %89, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> + %96 = tt.expand_dims %95 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0> + %97 = tt.broadcast %96 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0> + %98 = "triton_gpu.cmpi"(%97, %69) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0> + %99 = "triton_gpu.select"(%98, %94, %cst) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0> + %100 = arith.mulf %99, %36 : tensor<128x128xf32, #mma0> + %101 = math.exp %100 : tensor<128x128xf32, #mma0> + %102 = arith.addi %90, %19 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> + %103 = tt.expand_dims %102 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0> + %104 = tt.broadcast %103 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0> + %105 = "triton_gpu.cmpi"(%104, %72) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0> + %106 = "triton_gpu.select"(%105, %94, %cst) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0> + %107 = arith.mulf %106, %37 : tensor<128x128xf32, #mma0> + %108 = math.exp %107 : tensor<128x128xf32, #mma0> + %109 = tt.load %arg31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked0> + %110 = arith.truncf %101 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0> + %111 = triton_gpu.convert_layout %110 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared1> + %112 = tt.trans %111 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #shared0> + %113 = triton_gpu.convert_layout %112 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> + %114 = triton_gpu.convert_layout %109 : (tensor<128x64xf16, #blocked0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> + %115 = tt.dot %113, %114, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> + %116 = triton_gpu.convert_layout %109 : (tensor<128x64xf16, #blocked0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> + %117 = triton_gpu.convert_layout %74 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> + %118 = tt.dot %116, %117, %cst_0 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0> + %119 = arith.mulf %108, %118 : tensor<128x128xf32, #mma0> + %120 = arith.mulf %119, %38 : tensor<128x128xf32, #mma0> + %121 = arith.truncf %120 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0> + %122 = triton_gpu.convert_layout %121 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared1> + %123 = tt.trans %122 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #shared0> + %124 = triton_gpu.convert_layout %123 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> + %125 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #blocked0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> + %126 = tt.dot %124, %125, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> + %127 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked1> + %128 = triton_gpu.convert_layout %127 : (tensor<128x64xf32, #blocked1>) -> tensor<128x64xf32, #mma1> + %129 = triton_gpu.convert_layout %121 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> + %130 = triton_gpu.convert_layout %58 : (tensor<128x64xf16, #blocked0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> + %131 = tt.dot %129, %130, %128 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> + %132 = triton_gpu.convert_layout %131 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked1> + tt.store %arg29, %132 : tensor<128x64xf32, #blocked1> + %133 = tt.addptr %arg29, %41 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %134 = tt.addptr %arg30, %40 : tensor<128x64x!tt.ptr, #blocked0>, tensor<128x64xi32, #blocked0> + %135 = tt.addptr %arg31, %40 : tensor<128x64x!tt.ptr, #blocked0>, tensor<128x64xi32, #blocked0> + scf.yield %115, %126, %133, %134, %135 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64x!tt.ptr, #blocked0>, tensor<128x64x!tt.ptr, #blocked0> + } + %82 = triton_gpu.convert_layout %81#1 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked0> + %83 = triton_gpu.convert_layout %81#0 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked0> + %84 = tt.addptr %42, %61 : tensor<128x64x!tt.ptr, #blocked0>, tensor<128x64xi32, #blocked0> + %85 = arith.truncf %83 : tensor<128x64xf32, #blocked0> to tensor<128x64xf16, #blocked0> + tt.store %84, %85 : tensor<128x64xf16, #blocked0> + %86 = tt.addptr %43, %56 : tensor<128x64x!tt.ptr, #blocked0>, tensor<128x64xi32, #blocked0> + %87 = arith.truncf %82 : tensor<128x64xf32, #blocked0> to tensor<128x64xf16, #blocked0> + tt.store %86, %87 : tensor<128x64xf16, #blocked0> } - %60 = triton_gpu.convert_layout %59#1 : (tensor<128x64xf32, #mma0>) -> tensor<128x64xf32, #blocked1> - %61 = triton_gpu.convert_layout %59#0 : (tensor<128x64xf32, #mma0>) -> tensor<128x64xf32, #blocked1> - %62 = tt.splat %12 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked1> - %63 = tt.splat %11 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked1> - %64 = tt.addptr %62, %42 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %65 = arith.truncf %61 : tensor<128x64xf32, #blocked1> to tensor<128x64xf16, #blocked1> - tt.store %64, %65 : tensor<128x64xf16, #blocked1> - %66 = tt.addptr %63, %37 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %67 = arith.truncf %60 : tensor<128x64xf32, #blocked1> to tensor<128x64xf16, #blocked1> - tt.store %66, %67 : tensor<128x64xf16, #blocked1> return } } \ No newline at end of file diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index 0097da9ec..36a4687e3 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -133,68 +133,67 @@ def _bwd_kernel( DQ += off_z * stride_qz + off_h * stride_qh DK += off_z * stride_qz + off_h * stride_qh DV += off_z * stride_qz + off_h * stride_qh - # for start_n in range(0, num_block): - start_n = 0 - lo = start_n * BLOCK_M - # initialize row/col offsets - offs_qm = lo + tl.arange(0, BLOCK_M) - offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) - offs_m = tl.arange(0, BLOCK_N) - offs_k = tl.arange(0, BLOCK_DMODEL) - # initialize pointers to value-like data - q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) - k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) - v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) - do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) - dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) - # pointer to row-wise quantities in value-like data - # D_ptrs = D + off_hz * N_CTX - # m_ptrs = M + off_hz * N_CTX - # initialize dv amd dk - dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # k and v stay in SRAM throughout - k = tl.load(k_ptrs) - v = tl.load(v_ptrs) - # loop over rows - for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): - # offs_m_curr = start_m + offs_m - # load q, k, v, do on-chip - q = tl.load(q_ptrs) - # recompute p = softmax(qk, dim=-1).T - # NOTE: `do` is pre-divided by `l`; no normalization here - qk = tl.dot(q, tl.trans(k)) - # qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) - # m = tl.load(m_ptrs + offs_m_curr) - # p = tl.exp(qk * sm_scale - m[:, None]) - p = qk * sm_scale - # compute dv - do = tl.load(do_ptrs) - dv += tl.dot(tl.trans(p.to(tl.float16)), do) - # # compute dp = dot(v, do) - # Di = tl.load(D_ptrs + offs_m_curr) - # dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] - dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - dp += tl.dot(do, tl.trans(v)) - # compute ds = p * (dp - delta[:, None]) - ds = p * dp * sm_scale - # # compute dk = dot(ds.T, q) - dk += tl.dot(tl.trans(ds.to(tl.float16)), q) - # # compute dq - # dq = tl.load(dq_ptrs) - # dq += tl.dot(ds.to(tl.float16), k) - # tl.store(dq_ptrs, dq) - # increment pointers - dq_ptrs += BLOCK_M * stride_qm - q_ptrs += BLOCK_M * stride_qm - do_ptrs += BLOCK_M * stride_qm - # write-back - dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) - dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) - tl.store(dv_ptrs, dv) - tl.store(dk_ptrs, dk) + for start_n in range(0, num_block): + lo = start_n * BLOCK_M + # initialize row/col offsets + offs_qm = lo + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_DMODEL) + # initialize pointers to value-like data + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) + do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + # pointer to row-wise quantities in value-like data + D_ptrs = D + off_hz * N_CTX + m_ptrs = M + off_hz * N_CTX + # initialize dv amd dk + dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # k and v stay in SRAM throughout + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + # loop over rows + for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + q = tl.load(q_ptrs) + # recompute p = softmax(qk, dim=-1).T + # NOTE: `do` is pre-divided by `l`; no normalization here + qk = tl.dot(q, tl.trans(k)) + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) + # m = tl.load(m_ptrs + offs_m_curr) + # p = tl.exp(qk * sm_scale - m[:, None]) + p = tl.exp(qk * sm_scale) + # compute dv + do = tl.load(do_ptrs) + dv += tl.dot(tl.trans(p.to(tl.float16)), do) + # compute dp = dot(v, do) + # Di = tl.load(D_ptrs + offs_m_curr) + # dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + dp += tl.dot(do, tl.trans(v)) + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + # compute dk = dot(ds.T, q) + dk += tl.dot(tl.trans(ds.to(tl.float16)), q) + # compute dq + dq = tl.load(dq_ptrs) + dq += tl.dot(ds.to(tl.float16), k) + tl.store(dq_ptrs, dq) + # increment pointers + dq_ptrs += BLOCK_M * stride_qm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_qm + # write-back + dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) -_bwd_kernel = triton.compile("./bwd.ptx", num_warps=8, shared=32768) +_bwd_kernel = triton.compile("./bwd.ttgir", num_warps=8) # _fwd_kernel = triton.compile("./fails.ptx", num_warps=4, shared=18432) empty = torch.empty(128, device="cuda") @@ -288,7 +287,7 @@ class _attention(torch.autograd.Function): # num_stages=1, # ) # print(pgm.asm["ttgir"]) - exit(1) + # exit(1) return dq, dk, dv, None @@ -327,8 +326,8 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16): # compare triton.testing.assert_almost_equal(ref_out, tri_out) triton.testing.assert_almost_equal(ref_dv, tri_dv) - triton.testing.assert_almost_equal(ref_dk, tri_dk) - triton.testing.assert_almost_equal(ref_dq, tri_dq) + # triton.testing.assert_almost_equal(ref_dk, tri_dk) + # triton.testing.assert_almost_equal(ref_dq, tri_dq) BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 # vary seq length for fixed head and batch=4 @@ -379,4 +378,4 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep) return ms -bench_flash_attention.run(save_path='.', print_data=True) \ No newline at end of file +# bench_flash_attention.run(save_path='.', print_data=True) \ No newline at end of file diff --git a/test/lib/Analysis/TestAllocation.cpp b/test/lib/Analysis/TestAllocation.cpp index a29465630..baff2ce73 100644 --- a/test/lib/Analysis/TestAllocation.cpp +++ b/test/lib/Analysis/TestAllocation.cpp @@ -28,7 +28,7 @@ struct TestAllocationPass if (scratchBufferId != Allocation::InvalidBufferId) { size_t offset = allocation.getOffset(scratchBufferId); size_t size = allocation.getAllocatedSize(scratchBufferId); - os << "scratch offset = " << offset << ", size = " << size << "\n"; + os << " scratch offset = " << offset << ", size = " << size << "\n"; } if (op->getNumResults() < 1) return; @@ -37,6 +37,7 @@ struct TestAllocationPass if (bufferId != Allocation::InvalidBufferId) { size_t offset = allocation.getOffset(bufferId); size_t size = allocation.getAllocatedSize(bufferId); + os << result << "\n"; os << "offset = " << offset << ", size = " << size << "\n"; } }