From 6c750b6856ad25e5252f8f209463a695f69ce183 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 8 Jan 2023 14:29:17 -0800 Subject: [PATCH] Added verifier for trans --- include/triton/Dialect/Triton/IR/Dialect.h | 4 + include/triton/Dialect/Triton/IR/TritonOps.td | 2 +- .../TritonToTritonGPUPass.cpp | 15 +- lib/Dialect/Triton/IR/Ops.cpp | 30 +++ lib/Dialect/TritonGPU/IR/Dialect.cpp | 16 ++ python/being-optimized.ttgir | 172 +++++++++--------- python/triton/compiler.py | 1 + python/tutorials/06-fused-attention.py | 59 +++--- python/unoptimized.ttgir | 144 ++++++++------- 9 files changed, 243 insertions(+), 200 deletions(-) diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h index fb4a64607..84b005cf7 100644 --- a/include/triton/Dialect/Triton/IR/Dialect.h +++ b/include/triton/Dialect/Triton/IR/Dialect.h @@ -25,6 +25,10 @@ class DialectInferLayoutInterface public: DialectInferLayoutInterface(Dialect *dialect) : Base(dialect) {} + virtual LogicalResult + inferTransOpEncoding(Attribute operandEncoding, + Attribute &resultEncoding) const = 0; + virtual LogicalResult inferReduceOpEncoding(Attribute operandEncoding, unsigned axis, Attribute &resultEncoding) const = 0; diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index ef9597318..3668e8dcf 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -289,7 +289,7 @@ def TT_CatOp : TT_Op<"cat", [NoSideEffect, } def TT_TransOp : TT_Op<"trans", [NoSideEffect, - SameOperandsAndResultElementType]> { + DeclareOpInterfaceMethods]> { let summary = "transpose a tensor"; diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index ce5698289..8afd99210 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -319,20 +319,7 @@ struct TritonTransPattern : public OpConversionPattern { src = rewriter.create(src.getLoc(), srcType, src); } - auto srcSharedEncoding = - srcEncoding.cast(); - SmallVector retOrder(srcSharedEncoding.getOrder().begin(), - srcSharedEncoding.getOrder().end()); - SmallVector retShapes(srcType.getShape().begin(), - srcType.getShape().end()); - std::reverse(retOrder.begin(), retOrder.end()); - std::reverse(retShapes.begin(), retShapes.end()); - auto retEncoding = - triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, retOrder); - auto retType = - RankedTensorType::get(retShapes, srcType.getElementType(), retEncoding); - - rewriter.replaceOpWithNewOp(op, retType, src); + rewriter.replaceOpWithNewOp(op, src); return success(); } }; diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index cba9e8b6b..99e752f2f 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -206,6 +206,36 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, state.addTypes({resultType}); } +//-- TransOp -- +mlir::LogicalResult mlir::triton::TransOp::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // type is the same as the input + auto argTy = operands[0].getType().cast(); + SmallVector retShape(argTy.getShape().begin(), + argTy.getShape().end()); + std::reverse(retShape.begin(), retShape.end()); + auto retEltTy = argTy.getElementType(); + 
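+  // If the operand carries a layout encoding, delegate to the owning dialect's
+  // DialectInferLayoutInterface to infer the encoding of the transposed result.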
Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = + dyn_cast(&dialect); + if (inferLayoutInterface + ->inferTransOpEncoding(argEncoding, retEncoding) + .failed()) { + llvm::report_fatal_error("failed to infer layout for ReduceOp"); + return mlir::failure(); + } + } + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, retEltTy, retEncoding)); + return mlir::success(); + +} + //-- DotOp -- mlir::LogicalResult mlir::triton::DotOp::inferReturnTypes( MLIRContext *context, Optional location, ValueRange operands, diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 958e3a1ee..c3ee3e50a 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -737,6 +737,22 @@ struct TritonGPUInferLayoutInterface return success(); } + LogicalResult + inferTransOpEncoding(Attribute operandEncoding, Attribute &resultEncoding) const { + SharedEncodingAttr sharedEncoding = operandEncoding.dyn_cast(); + if(!sharedEncoding) + return failure(); + SmallVector retOrder(sharedEncoding.getOrder().begin(), + sharedEncoding.getOrder().end()); + std::reverse(retOrder.begin(), retOrder.end()); + resultEncoding = SharedEncodingAttr::get(getDialect()->getContext(), + sharedEncoding.getVec(), + sharedEncoding.getPerPhase(), + sharedEncoding.getMaxPhase(), + retOrder); + return mlir::success(); + } + LogicalResult inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis, Attribute &resultEncoding, diff --git a/python/being-optimized.ttgir b/python/being-optimized.ttgir index 466e82c55..9bb54690b 100644 --- a/python/being-optimized.ttgir +++ b/python/being-optimized.ttgir @@ -3,15 +3,14 @@ // TODO: reuse %128 in %137 = triton_gpu.convert_layout %127 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> // don't convert loaded value to mma for accumulation - #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> #mma0 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1]}> #mma1 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2]}> #shared0 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}> -#shared2 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}> +#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#shared2 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}> module attributes {"triton_gpu.num-warps" = 8 : i32} { func public @_bwd_kernel_0d1d2d34d5d6d7d8d9d10d11d12d13d14d15c16d17d18d19c20d21d22d23c2425d26d27(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg8: !tt.ptr {tt.divisibility = 16 : i32}, %arg9: !tt.ptr {tt.divisibility = 16 : i32}, %arg10: !tt.ptr 
{tt.divisibility = 16 : i32}, %arg11: !tt.ptr {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32) { %c0 = arith.constant 0 : index @@ -21,7 +20,6 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} { %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1> %cst_0 = arith.constant dense<0xFF800000> : tensor<128x128xf32, #mma0> %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma0> - %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1> %0 = tt.get_program_id {axis = 0 : i32} : i32 %1 = arith.divsi %0, %arg22 : i32 %2 = arith.remsi %0, %arg22 : i32 @@ -84,94 +82,94 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} { %58 = tt.addptr %29, %57 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %59 = tt.load %58 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> %60 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0> - %61 = arith.muli %53, %19 : tensor<128x1xi32, #blocked1> - %62 = tt.broadcast %61 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1> - %63 = arith.addi %62, %24 : tensor<128x64xi32, #blocked1> - %64 = tt.addptr %30, %63 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %65 = tt.load %64 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> - %66 = triton_gpu.convert_layout %65 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0> - %67 = arith.index_cast %47 : i32 to index - %68 = tt.trans %60 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1> - %69 = arith.addi %50, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>> - %70 = tt.expand_dims %69 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>) -> tensor<1x128xi32, #mma0> - %71 = tt.broadcast %70 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0> - %72 = tt.trans %66 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1> - %73 = arith.muli %54, %20 : tensor<128x1xi32, #blocked2> - %74 = tt.broadcast %73 : (tensor<128x1xi32, #blocked2>) -> tensor<128x64xi32, #blocked2> - %75 = arith.addi %74, %26 : tensor<128x64xi32, #blocked2> - %76 = tt.addptr %32, %75 : tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64xi32, #blocked2> - %77 = tt.addptr %27, %63 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %78 = tt.addptr %31, %63 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %79:5 = scf.for %arg26 = %67 to %37 step %c128 iter_args(%arg27 = %cst, %arg28 = %cst, %arg29 = %76, %arg30 = %77, %arg31 = %78) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64x!tt.ptr, #blocked1>) { - %86 = arith.index_cast %arg26 : index to i32 - %87 = tt.splat %86 : (i32) -> tensor<128xi32, #blocked0> - %88 = tt.splat %86 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> - %89 = arith.addi %87, %14 : tensor<128xi32, #blocked0> - %90 = tt.load 
%arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> - %91 = triton_gpu.convert_layout %90 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared2> - %92 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> - %93 = triton_gpu.convert_layout %68 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> - %94 = tt.dot %92, %93, %cst_1 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0> - %95 = arith.addi %88, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> - %96 = tt.expand_dims %95 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0> - %97 = tt.broadcast %96 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0> - %98 = "triton_gpu.cmpi"(%97, %71) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0> - %99 = "triton_gpu.select"(%98, %94, %cst_0) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0> - %100 = tt.addptr %38, %89 : tensor<128x!tt.ptr, #blocked0>, tensor<128xi32, #blocked0> - %101 = tt.load %100 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0> - %102 = triton_gpu.convert_layout %101 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> - %103 = arith.mulf %99, %39 : tensor<128x128xf32, #mma0> - %104 = tt.expand_dims %102 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0> - %105 = tt.broadcast %104 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0> - %106 = arith.subf %103, %105 : tensor<128x128xf32, #mma0> - %107 = math.exp %106 : tensor<128x128xf32, #mma0> - %108 = tt.load %arg31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> - %109 = arith.truncf %107 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0> - %110 = triton_gpu.convert_layout %109 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0> - %111 = tt.trans %110 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1> - %112 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared2> - %113 = triton_gpu.convert_layout %112 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> - %114 = triton_gpu.convert_layout %111 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> - %115 = tt.dot %114, %113, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> - %116 = tt.addptr %40, %89 : tensor<128x!tt.ptr, #blocked0>, tensor<128xi32, #blocked0> - %117 = tt.load %116 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0> - %118 = triton_gpu.convert_layout %117 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> - %119 = tt.expand_dims %118 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0> - %120 = 
tt.broadcast %119 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0> - %121 = arith.subf %cst_1, %120 : tensor<128x128xf32, #mma0> - %122 = triton_gpu.convert_layout %112 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> - %123 = triton_gpu.convert_layout %72 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> - %124 = tt.dot %122, %123, %121 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0> - %125 = arith.mulf %107, %124 : tensor<128x128xf32, #mma0> - %126 = arith.mulf %125, %39 : tensor<128x128xf32, #mma0> - %127 = arith.truncf %126 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0> - %128 = triton_gpu.convert_layout %127 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0> - %129 = tt.trans %128 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1> - %130 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> - %131 = triton_gpu.convert_layout %129 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> - %132 = tt.dot %131, %130, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> - %133 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2> - %135 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared2> - %136 = triton_gpu.convert_layout %135 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> - %137 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> - %138 = tt.dot %137, %136, %cst_2 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> + %61 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared1> + %62 = arith.muli %53, %19 : tensor<128x1xi32, #blocked1> + %63 = tt.broadcast %62 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1> + %64 = arith.addi %63, %24 : tensor<128x64xi32, #blocked1> + %65 = tt.addptr %30, %64 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %66 = tt.load %65 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> + %67 = triton_gpu.convert_layout %66 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared1> + %68 = arith.index_cast %47 : i32 to index + %69 = tt.trans %61 : (tensor<128x64xf16, #shared1>) -> tensor<64x128xf16, #shared2> + %70 = arith.addi %50, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>> + %71 = tt.expand_dims %70 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>) -> tensor<1x128xi32, #mma0> + %72 = tt.broadcast %71 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0> + %73 = tt.trans %67 : (tensor<128x64xf16, #shared1>) -> tensor<64x128xf16, #shared2> + %74 = arith.muli %54, %20 : tensor<128x1xi32, #blocked2> + %75 = tt.broadcast %74 : (tensor<128x1xi32, 
#blocked2>) -> tensor<128x64xi32, #blocked2> + %76 = arith.addi %75, %26 : tensor<128x64xi32, #blocked2> + %77 = tt.addptr %32, %76 : tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64xi32, #blocked2> + %78 = tt.addptr %27, %64 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %79 = tt.addptr %31, %64 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %80:5 = scf.for %arg26 = %68 to %37 step %c128 iter_args(%arg27 = %cst, %arg28 = %cst, %arg29 = %77, %arg30 = %78, %arg31 = %79) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64x!tt.ptr, #blocked1>) { + %87 = arith.index_cast %arg26 : index to i32 + %88 = tt.splat %87 : (i32) -> tensor<128xi32, #blocked0> + %89 = tt.splat %87 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> + %90 = arith.addi %88, %14 : tensor<128xi32, #blocked0> + %91 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> + %92 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0> + %93 = triton_gpu.convert_layout %69 : (tensor<64x128xf16, #shared2>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> + %94 = triton_gpu.convert_layout %92 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> + %95 = tt.dot %94, %93, %cst_1 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0> + %96 = arith.addi %89, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> + %97 = tt.expand_dims %96 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0> + %98 = tt.broadcast %97 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0> + %99 = "triton_gpu.cmpi"(%98, %72) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0> + %100 = "triton_gpu.select"(%99, %95, %cst_0) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0> + %101 = tt.addptr %38, %90 : tensor<128x!tt.ptr, #blocked0>, tensor<128xi32, #blocked0> + %102 = tt.load %101 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0> + %103 = triton_gpu.convert_layout %102 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> + %104 = arith.mulf %100, %39 : tensor<128x128xf32, #mma0> + %105 = tt.expand_dims %103 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0> + %106 = tt.broadcast %105 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0> + %107 = arith.subf %104, %106 : tensor<128x128xf32, #mma0> + %108 = math.exp %107 : tensor<128x128xf32, #mma0> + %109 = tt.load %arg31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> + %110 = triton_gpu.convert_layout %109 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0> + %111 = arith.truncf %108 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0> + %112 = triton_gpu.convert_layout %111 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared1> + %113 = tt.trans %112 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #shared2> + %114 = triton_gpu.convert_layout %113 : 
(tensor<128x128xf16, #shared2>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> + %115 = triton_gpu.convert_layout %110 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> + %116 = tt.dot %114, %115, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> + %117 = tt.addptr %40, %90 : tensor<128x!tt.ptr, #blocked0>, tensor<128xi32, #blocked0> + %118 = tt.load %117 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0> + %119 = triton_gpu.convert_layout %118 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> + %120 = tt.expand_dims %119 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0> + %121 = tt.broadcast %120 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0> + %122 = arith.subf %cst_1, %121 : tensor<128x128xf32, #mma0> + %123 = triton_gpu.convert_layout %110 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> + %124 = triton_gpu.convert_layout %73 : (tensor<64x128xf16, #shared2>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> + %125 = tt.dot %123, %124, %122 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0> + %126 = arith.mulf %108, %125 : tensor<128x128xf32, #mma0> + %127 = arith.mulf %126, %39 : tensor<128x128xf32, #mma0> + %128 = arith.truncf %127 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0> + %129 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared1> + %130 = tt.trans %129 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #shared2> + %131 = triton_gpu.convert_layout %92 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> + %132 = triton_gpu.convert_layout %130 : (tensor<128x128xf16, #shared2>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> + %133 = tt.dot %132, %131, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> + %134 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2> + %135 = triton_gpu.convert_layout %134 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma1> + %136 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> + %137 = triton_gpu.convert_layout %60 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> + %138 = tt.dot %136, %137, %135 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> %139 = triton_gpu.convert_layout %138 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2> - %1000 = arith.addf %133, %139: tensor<128x64xf32, #blocked2> - tt.store %arg29, %133 : tensor<128x64xf32, #blocked2> + tt.store %arg29, %139 : tensor<128x64xf32, #blocked2> %140 = tt.addptr %arg29, %43 : 
tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64xi32, #blocked2> %141 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %142 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - scf.yield %115, %132, %140, %141, %142 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64x!tt.ptr, #blocked1> + scf.yield %116, %133, %140, %141, %142 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64x!tt.ptr, #blocked1> } - %80 = arith.truncf %79#0 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1> - %81 = tt.addptr %44, %63 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %82 = triton_gpu.convert_layout %80 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1> - tt.store %81, %82 : tensor<128x64xf16, #blocked1> - %83 = arith.truncf %79#1 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1> - %84 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %85 = triton_gpu.convert_layout %83 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1> - tt.store %84, %85 : tensor<128x64xf16, #blocked1> + %81 = arith.truncf %80#0 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1> + %82 = tt.addptr %44, %64 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %83 = triton_gpu.convert_layout %81 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1> + tt.store %82, %83 : tensor<128x64xf16, #blocked1> + %84 = arith.truncf %80#1 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1> + %85 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %86 = triton_gpu.convert_layout %84 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1> + tt.store %85, %86 : tensor<128x64xf16, #blocked1> } return } -} \ No newline at end of file +} diff --git a/python/triton/compiler.py b/python/triton/compiler.py index e27184736..557df65a6 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -908,6 +908,7 @@ def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability): # pm.add_tritongpu_optimize_load_convert_pass() pm.add_tritongpu_sink_conversions_from_shared_pass() pm.add_tritongpu_decompose_conversions_to_dot_operand_pass() + pm.add_cse_pass() pm.run(mod) return mod diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index 97b2482b0..cbf7187f6 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -191,7 +191,8 @@ def _bwd_kernel( tl.store(dv_ptrs, dv) tl.store(dk_ptrs, dk) -_bwd_kernel = triton.compile("./being-optimized.ttgir", num_warps=8) +# _bwd_kernel = triton.compile("./being-optimized.ttgir", num_warps=8) +# _bwd_kernel = triton.compile("./unoptimized.ttgir", num_warps=8) # _bwd_kernel = triton.compile("./bwd.ttgir", num_warps=8) # _fwd_kernel = triton.compile("./fails.ptx", num_warps=4, shared=18432) @@ -259,36 +260,36 @@ class _attention(torch.autograd.Function): BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL, ) - _bwd_kernel[(ctx.grid[1],1,1)]( - q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale, - o.data_ptr(), do_scaled.data_ptr(), - dq.data_ptr(), dk.data_ptr(), dv.data_ptr(), - l.data_ptr(), m.data_ptr(), - delta.data_ptr(), - q.stride(0), q.stride(1), q.stride(2), - k.stride(0), k.stride(1), k.stride(2), - 
v.stride(0), v.stride(1), v.stride(2), - q.shape[0], q.shape[1], q.shape[2], - ctx.grid[0] - ) - - # pgm = _bwd_kernel[(ctx.grid[1],)]( - # q, k, v, ctx.sm_scale, - # o, do_scaled, - # dq, dk, dv, - # l, m, - # delta, - # q.stride(0), q.stride(1), q.stride(2), q.stride(3), - # k.stride(0), k.stride(1), k.stride(2), k.stride(3), - # v.stride(0), v.stride(1), v.stride(2), v.stride(3), + # _bwd_kernel[(ctx.grid[1],1,1)]( + # q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale, + # o.data_ptr(), do_scaled.data_ptr(), + # dq.data_ptr(), dk.data_ptr(), dv.data_ptr(), + # l.data_ptr(), m.data_ptr(), + # delta.data_ptr(), + # q.stride(0), q.stride(1), q.stride(2), + # k.stride(0), k.stride(1), k.stride(2), + # v.stride(0), v.stride(1), v.stride(2), # q.shape[0], q.shape[1], q.shape[2], - # ctx.grid[0], - # BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK, - # BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8, - # num_stages=1, + # ctx.grid[0] # ) - # print(pgm.asm["ttgir"]) - # exit() + + pgm = _bwd_kernel[(ctx.grid[1],)]( + q, k, v, ctx.sm_scale, + o, do_scaled, + dq, dk, dv, + l, m, + delta, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + q.shape[0], q.shape[1], q.shape[2], + ctx.grid[0], + BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8, + num_stages=1, + ) + print(pgm.asm["ttgir"]) + exit() return dq, dk, dv, None diff --git a/python/unoptimized.ttgir b/python/unoptimized.ttgir index 567e7a573..4b1361378 100644 --- a/python/unoptimized.ttgir +++ b/python/unoptimized.ttgir @@ -11,6 +11,7 @@ #mma1 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2]}> #shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> #shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}> +#shared2 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}> module attributes {"triton_gpu.num-warps" = 8 : i32} { func public @_bwd_kernel_0d1d2d34d5d6d7d8d9d10d11d12d13d14d15c16d17d18d19c20d21d22d23c2425d26d27(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg8: !tt.ptr {tt.divisibility = 16 : i32}, %arg9: !tt.ptr {tt.divisibility = 16 : i32}, %arg10: !tt.ptr {tt.divisibility = 16 : i32}, %arg11: !tt.ptr {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32) { %c0 = arith.constant 0 : index @@ -81,91 +82,96 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} { %57 = arith.addi %56, %24 : tensor<128x64xi32, #blocked1> %58 = tt.addptr %29, %57 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %59 = tt.load %58 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> - %60 = arith.muli %53, %19 : tensor<128x1xi32, 
#blocked1> - %61 = tt.broadcast %60 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1> - %62 = arith.addi %61, %24 : tensor<128x64xi32, #blocked1> - %63 = tt.addptr %30, %62 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %64 = tt.load %63 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> - %65 = arith.index_cast %47 : i32 to index - %66 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0> - %67 = tt.trans %66 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1> - %68 = arith.addi %50, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>> - %69 = tt.expand_dims %68 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>) -> tensor<1x128xi32, #mma0> - %70 = tt.broadcast %69 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0> - %71 = triton_gpu.convert_layout %64 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0> - %72 = tt.trans %71 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1> + %60 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0> + %61 = arith.muli %53, %19 : tensor<128x1xi32, #blocked1> + %62 = tt.broadcast %61 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1> + %63 = arith.addi %62, %24 : tensor<128x64xi32, #blocked1> + %64 = tt.addptr %30, %63 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %65 = tt.load %64 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> + %66 = triton_gpu.convert_layout %65 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0> + %67 = arith.index_cast %47 : i32 to index + %68 = tt.trans %60 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1> + %69 = arith.addi %50, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>> + %70 = tt.expand_dims %69 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>) -> tensor<1x128xi32, #mma0> + %71 = tt.broadcast %70 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0> + %72 = tt.trans %66 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1> %73 = arith.muli %54, %20 : tensor<128x1xi32, #blocked2> %74 = tt.broadcast %73 : (tensor<128x1xi32, #blocked2>) -> tensor<128x64xi32, #blocked2> %75 = arith.addi %74, %26 : tensor<128x64xi32, #blocked2> %76 = tt.addptr %32, %75 : tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64xi32, #blocked2> - %77 = tt.addptr %27, %62 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %78 = tt.addptr %31, %62 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %79 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> - %80:5 = scf.for %arg26 = %65 to %37 step %c128 iter_args(%arg27 = %cst, %arg28 = %cst, %arg29 = %76, %arg30 = %77, %arg31 = %78) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64x!tt.ptr, #blocked1>) { - %87 = arith.index_cast %arg26 : index to i32 - %88 = tt.splat %87 : (i32) -> tensor<128xi32, #blocked0> - %89 = tt.splat %87 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> - %90 = arith.addi %88, %14 : tensor<128xi32, #blocked0> - %91 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> - %92 = 
triton_gpu.convert_layout %91 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> - %93 = triton_gpu.convert_layout %67 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> - %94 = tt.dot %92, %93, %cst_1 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0> - %95 = arith.addi %89, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> + %77 = tt.addptr %27, %63 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %78 = tt.addptr %31, %63 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %79:5 = scf.for %arg26 = %67 to %37 step %c128 iter_args(%arg27 = %cst, %arg28 = %cst, %arg29 = %76, %arg30 = %77, %arg31 = %78) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64x!tt.ptr, #blocked1>) { + %86 = arith.index_cast %arg26 : index to i32 + %87 = tt.splat %86 : (i32) -> tensor<128xi32, #blocked0> + %88 = tt.splat %86 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> + %89 = arith.addi %87, %14 : tensor<128xi32, #blocked0> + %90 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> + %91 = triton_gpu.convert_layout %68 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> + %92 = triton_gpu.convert_layout %90 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared2> + %93 = triton_gpu.convert_layout %92 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> + %94 = tt.dot %93, %91, %cst_1 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0> + %95 = arith.addi %88, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> %96 = tt.expand_dims %95 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0> %97 = tt.broadcast %96 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0> - %98 = "triton_gpu.cmpi"(%97, %70) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0> + %98 = "triton_gpu.cmpi"(%97, %71) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0> %99 = "triton_gpu.select"(%98, %94, %cst_0) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0> - %100 = tt.addptr %38, %90 : tensor<128x!tt.ptr, #blocked0>, tensor<128xi32, #blocked0> + %100 = tt.addptr %38, %89 : tensor<128x!tt.ptr, #blocked0>, tensor<128xi32, #blocked0> %101 = tt.load %100 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0> - %102 = arith.mulf %99, %39 : tensor<128x128xf32, #mma0> - %103 = triton_gpu.convert_layout %101 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> - %104 = tt.expand_dims %103 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0> + %102 = triton_gpu.convert_layout %101 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = 
#mma0}>> + %103 = arith.mulf %99, %39 : tensor<128x128xf32, #mma0> + %104 = tt.expand_dims %102 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0> %105 = tt.broadcast %104 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0> - %106 = arith.subf %102, %105 : tensor<128x128xf32, #mma0> + %106 = arith.subf %103, %105 : tensor<128x128xf32, #mma0> %107 = math.exp %106 : tensor<128x128xf32, #mma0> %108 = tt.load %arg31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> %109 = arith.truncf %107 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0> %110 = triton_gpu.convert_layout %109 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0> %111 = tt.trans %110 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1> - %112 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> - %113 = triton_gpu.convert_layout %111 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> - %114 = tt.dot %113, %112, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> - %115 = tt.addptr %40, %90 : tensor<128x!tt.ptr, #blocked0>, tensor<128xi32, #blocked0> - %116 = tt.load %115 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0> - %117 = triton_gpu.convert_layout %116 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> - %118 = tt.expand_dims %117 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0> - %119 = tt.broadcast %118 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0> - %120 = arith.subf %cst_1, %119 : tensor<128x128xf32, #mma0> - %121 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> + %112 = triton_gpu.convert_layout %111 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> + %113 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared2> + %114 = triton_gpu.convert_layout %113 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> + %115 = tt.dot %112, %114, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> + %116 = tt.addptr %40, %89 : tensor<128x!tt.ptr, #blocked0>, tensor<128xi32, #blocked0> + %117 = tt.load %116 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0> + %118 = triton_gpu.convert_layout %117 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> + %119 = tt.expand_dims %118 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0> + %120 = tt.broadcast %119 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0> + %121 = arith.subf %cst_1, %120 : tensor<128x128xf32, #mma0> %122 = triton_gpu.convert_layout %72 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> - %123 = tt.dot %121, %122, %120 {allowTF32 = true} 
: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0> - %124 = arith.mulf %107, %123 : tensor<128x128xf32, #mma0> - %125 = arith.mulf %124, %39 : tensor<128x128xf32, #mma0> - %126 = arith.truncf %125 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0> - %127 = triton_gpu.convert_layout %126 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0> - %128 = tt.trans %127 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1> - %129 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> - %130 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> - %131 = tt.dot %130, %129, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> - %132 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2> - %133 = triton_gpu.convert_layout %132 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma1> - %134 = triton_gpu.convert_layout %126 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> - %135 = tt.dot %134, %79, %133 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> - %136 = triton_gpu.convert_layout %135 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2> - tt.store %arg29, %136 : tensor<128x64xf32, #blocked2> - %137 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64xi32, #blocked2> - %138 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %139 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - scf.yield %114, %131, %137, %138, %139 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64x!tt.ptr, #blocked1> + %123 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared2> + %124 = triton_gpu.convert_layout %123 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> + %125 = tt.dot %124, %122, %121 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0> + %126 = arith.mulf %107, %125 : tensor<128x128xf32, #mma0> + %127 = arith.mulf %126, %39 : tensor<128x128xf32, #mma0> + %128 = arith.truncf %127 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0> + %129 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0> + %130 = tt.trans %129 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1> + %131 = triton_gpu.convert_layout %90 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared2> + %132 = triton_gpu.convert_layout %131 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> + %133 = triton_gpu.convert_layout %130 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, 
#triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> + %134 = tt.dot %133, %132, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> + %135 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2> + %136 = triton_gpu.convert_layout %135 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma1> + %137 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> + %138 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared2> + %139 = triton_gpu.convert_layout %138 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> + %140 = tt.dot %137, %139, %136 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> + %141 = triton_gpu.convert_layout %140 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2> + tt.store %arg29, %141 : tensor<128x64xf32, #blocked2> + %142 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64xi32, #blocked2> + %143 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %144 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + scf.yield %115, %134, %142, %143, %144 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64x!tt.ptr, #blocked1> } - %81 = arith.truncf %80#0 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1> - %82 = tt.addptr %44, %62 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %83 = triton_gpu.convert_layout %81 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1> - tt.store %82, %83 : tensor<128x64xf16, #blocked1> - %84 = arith.truncf %80#1 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1> - %85 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %86 = triton_gpu.convert_layout %84 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1> - tt.store %85, %86 : tensor<128x64xf16, #blocked1> + %80 = arith.truncf %79#0 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1> + %81 = tt.addptr %44, %63 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %82 = triton_gpu.convert_layout %80 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1> + tt.store %81, %82 : tensor<128x64xf16, #blocked1> + %83 = arith.truncf %79#1 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1> + %84 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %85 = triton_gpu.convert_layout %83 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1> + tt.store %84, %85 : tensor<128x64xf16, #blocked1> } return }
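
Note: the following is a minimal, standalone Python sketch, not the Triton/MLIR C++ API; SharedEncoding below is a hypothetical stand-in for #triton_gpu.shared. It models what the new inferTransOpEncoding / TransOp::inferReturnTypes pair computes: tt.trans keeps the element type, reverses the tensor shape, and for a shared operand layout reverses its order while preserving vec/perPhase/maxPhase.

# Standalone model of the trans layout/type inference added in this patch.
# Names here (SharedEncoding, infer_trans_encoding, infer_trans_type) are
# illustrative only and do not exist in the Triton codebase.
from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class SharedEncoding:
    vec: int
    per_phase: int
    max_phase: int
    order: Tuple[int, ...]

def infer_trans_encoding(enc: SharedEncoding) -> SharedEncoding:
    # Same swizzling parameters, reversed dimension order.
    return SharedEncoding(enc.vec, enc.per_phase, enc.max_phase, enc.order[::-1])

def infer_trans_type(shape: Tuple[int, ...], enc: SharedEncoding):
    # Result shape is the reversed operand shape; element type is unchanged.
    return shape[::-1], infer_trans_encoding(enc)

if __name__ == "__main__":
    shared0 = SharedEncoding(vec=8, per_phase=1, max_phase=8, order=(1, 0))
    print(infer_trans_type((128, 64), shared0))
    # -> ((64, 128), SharedEncoding(vec=8, per_phase=1, max_phase=8, order=(0, 1)))

Because the op now infers its own result type and encoding, the TritonToTritonGPU conversion pattern no longer has to build the transposed RankedTensorType by hand, which is why that logic is deleted from TritonToTritonGPUPass.cpp above.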