From 8ebb593bbb7243ca4d44b24b562557e6c58a7226 Mon Sep 17 00:00:00 2001
From: Phil Tillet
Date: Mon, 9 Jan 2023 15:45:06 -0800
Subject: [PATCH] more work

---
 lib/Dialect/TritonGPU/Transforms/Combine.cpp  |  54 ++++++
 .../DecomposeConversionsToDotOperand.cpp      |   9 +-
 python/being-optimized.ttgir                  | 182 +++++++++---------
 python/triton/compiler.py                     |   1 +
 python/tutorials/06-fused-attention.py        |  58 +++---
 5 files changed, 179 insertions(+), 125 deletions(-)

diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
index 2c3489c76..5d165c09e 100644
--- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
@@ -1255,6 +1255,59 @@ public:
   }
 };
 
+// Convert + trans + convert
+// x = convert_layout distributed -> #shared_x
+// y = trans x -> #shared_y
+// z = convert_layout y -> #dot_operand
+class ConvertTransConvert : public mlir::RewritePattern {
+
+public:
+  ConvertTransConvert(mlir::MLIRContext *context)
+      : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
+                             1, context) {}
+
+  LogicalResult matchAndRewrite(mlir::Operation* op,
+                                mlir::PatternRewriter &rewriter) const override {
+    auto dstOp = cast<triton::gpu::ConvertLayoutOp>(op);
+    auto tmpOp =
+        dyn_cast_or_null<triton::TransOp>(dstOp.src().getDefiningOp());
+    if(!tmpOp)
+      return mlir::failure();
+    auto srcOp = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(
+        tmpOp.src().getDefiningOp());
+    if(!srcOp)
+      return mlir::failure();
+    auto arg = srcOp.src();
+    auto X = tmpOp.src();
+    auto Y = dstOp.src();
+    // types
+    auto argType = arg.getType().cast<RankedTensorType>();
+    auto XType = X.getType().cast<RankedTensorType>();
+    auto YType = Y.getType().cast<RankedTensorType>();
+    auto ZType = dstOp.getResult().getType().cast<RankedTensorType>();
+    // encodings
+    auto argEncoding = argType.getEncoding();
+    auto XEncoding =
+        XType.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
+    auto YEncoding =
+        YType.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
+    auto ZEncoding =
+        ZType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
+    if(!ZEncoding)
+      return mlir::failure();
+    // new X encoding
+    auto newXOrder = triton::gpu::getOrder(argEncoding);
+    auto newXEncoding = triton::gpu::SharedEncodingAttr::get(
+        getContext(), ZEncoding, XType.getShape(), newXOrder,
+        XType.getElementType());
+    auto newXType = RankedTensorType::get(XType.getShape(),
+                                          XType.getElementType(), newXEncoding);
+    if(XEncoding == newXEncoding)
+      return mlir::failure();
+
+    auto newX = rewriter.create<triton::gpu::ConvertLayoutOp>(srcOp.getLoc(),
+                                                              newXType, arg);
+    auto newY = rewriter.create<triton::TransOp>(tmpOp.getLoc(), newX);
+    rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(dstOp, ZType, newY);
+    return mlir::success();
+  }
+};
+
 // Correct the versionMinor field in MmaEncodingAttr for Volta.
 class UpdateMMAVersionMinorForVolta : public mlir::RewritePattern {
   const DenseMap &mmaToUpdate;
@@ -1423,6 +1476,7 @@ public:
   patterns.add(context);
   patterns.add(context);
   patterns.add(context, computeCapability);
+  patterns.add<ConvertTransConvert>(context);
   if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
     signalPassFailure();
diff --git a/lib/Dialect/TritonGPU/Transforms/DecomposeConversionsToDotOperand.cpp b/lib/Dialect/TritonGPU/Transforms/DecomposeConversionsToDotOperand.cpp
index d7d571019..29fc123c5 100644
--- a/lib/Dialect/TritonGPU/Transforms/DecomposeConversionsToDotOperand.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/DecomposeConversionsToDotOperand.cpp
@@ -34,16 +34,17 @@ public:
       OpBuilder builder(cvtOp);
       auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
       auto dstType = cvtOp.getType().cast<RankedTensorType>();
-      auto srcBlocked =
-          srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
+      auto srcEncoding = srcType.getEncoding();
+      if(srcEncoding.isa<triton::gpu::SharedEncodingAttr>())
+        return;
       auto dstDotOp = dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
-      if (srcBlocked && dstDotOp) {
+      if (dstDotOp) {
        auto tmpType = RankedTensorType::get(
            dstType.getShape(), dstType.getElementType(),
            triton::gpu::SharedEncodingAttr::get(
                mod.getContext(), dstDotOp, srcType.getShape(),
-                getOrder(srcBlocked), srcType.getElementType()));
+                triton::gpu::getOrder(srcEncoding), srcType.getElementType()));
        auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
            cvtOp.getLoc(), tmpType, cvtOp.getOperand());
        auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
diff --git a/python/being-optimized.ttgir b/python/being-optimized.ttgir
index 9bb54690b..270fdcf5a 100644
--- a/python/being-optimized.ttgir
+++ b/python/being-optimized.ttgir
@@ -9,17 +9,16 @@
 #mma0 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1]}>
 #mma1 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2]}>
 #shared0 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
-#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
-#shared2 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
+#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}>
 module attributes {"triton_gpu.num-warps" = 8 : i32} {
   func public @_bwd_kernel_0d1d2d34d5d6d7d8d9d10d11d12d13d14d15c16d17d18d19c20d21d22d23c2425d26d27(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg8: !tt.ptr {tt.divisibility = 16 : i32}, %arg9: !tt.ptr {tt.divisibility = 16 : i32}, %arg10: !tt.ptr {tt.divisibility = 16 : i32}, %arg11: !tt.ptr {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32) {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %c128_i32 = arith.constant 128 : i32
-    %c128 = arith.constant 128 : index
-    %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1>
+    %cst = arith.constant
dense<0.000000e+00> : tensor<128x128xf32, #mma0> %cst_0 = arith.constant dense<0xFF800000> : tensor<128x128xf32, #mma0> - %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma0> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1> + %c128 = arith.constant 128 : index + %c128_i32 = arith.constant 128 : i32 + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index %0 = tt.get_program_id {axis = 0 : i32} : i32 %1 = arith.divsi %0, %arg22 : i32 %2 = arith.remsi %0, %arg22 : i32 @@ -82,93 +81,92 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} { %58 = tt.addptr %29, %57 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %59 = tt.load %58 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> %60 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0> - %61 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared1> - %62 = arith.muli %53, %19 : tensor<128x1xi32, #blocked1> - %63 = tt.broadcast %62 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1> - %64 = arith.addi %63, %24 : tensor<128x64xi32, #blocked1> - %65 = tt.addptr %30, %64 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %66 = tt.load %65 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> - %67 = triton_gpu.convert_layout %66 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared1> - %68 = arith.index_cast %47 : i32 to index - %69 = tt.trans %61 : (tensor<128x64xf16, #shared1>) -> tensor<64x128xf16, #shared2> - %70 = arith.addi %50, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>> - %71 = tt.expand_dims %70 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>) -> tensor<1x128xi32, #mma0> - %72 = tt.broadcast %71 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0> - %73 = tt.trans %67 : (tensor<128x64xf16, #shared1>) -> tensor<64x128xf16, #shared2> - %74 = arith.muli %54, %20 : tensor<128x1xi32, #blocked2> - %75 = tt.broadcast %74 : (tensor<128x1xi32, #blocked2>) -> tensor<128x64xi32, #blocked2> - %76 = arith.addi %75, %26 : tensor<128x64xi32, #blocked2> - %77 = tt.addptr %32, %76 : tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64xi32, #blocked2> - %78 = tt.addptr %27, %64 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %79 = tt.addptr %31, %64 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %80:5 = scf.for %arg26 = %68 to %37 step %c128 iter_args(%arg27 = %cst, %arg28 = %cst, %arg29 = %77, %arg30 = %78, %arg31 = %79) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64x!tt.ptr, #blocked1>) { - %87 = arith.index_cast %arg26 : index to i32 - %88 = tt.splat %87 : (i32) -> tensor<128xi32, #blocked0> - %89 = tt.splat %87 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> - %90 = arith.addi %88, %14 : tensor<128xi32, #blocked0> - %91 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> - %92 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0> - %93 = triton_gpu.convert_layout %69 : (tensor<64x128xf16, #shared2>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> - %94 = triton_gpu.convert_layout %92 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, 
#triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> - %95 = tt.dot %94, %93, %cst_1 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0> - %96 = arith.addi %89, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> - %97 = tt.expand_dims %96 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0> - %98 = tt.broadcast %97 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0> - %99 = "triton_gpu.cmpi"(%98, %72) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0> - %100 = "triton_gpu.select"(%99, %95, %cst_0) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0> - %101 = tt.addptr %38, %90 : tensor<128x!tt.ptr, #blocked0>, tensor<128xi32, #blocked0> - %102 = tt.load %101 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0> - %103 = triton_gpu.convert_layout %102 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> - %104 = arith.mulf %100, %39 : tensor<128x128xf32, #mma0> - %105 = tt.expand_dims %103 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0> - %106 = tt.broadcast %105 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0> - %107 = arith.subf %104, %106 : tensor<128x128xf32, #mma0> - %108 = math.exp %107 : tensor<128x128xf32, #mma0> - %109 = tt.load %arg31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> - %110 = triton_gpu.convert_layout %109 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0> - %111 = arith.truncf %108 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0> - %112 = triton_gpu.convert_layout %111 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared1> - %113 = tt.trans %112 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #shared2> - %114 = triton_gpu.convert_layout %113 : (tensor<128x128xf16, #shared2>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> - %115 = triton_gpu.convert_layout %110 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> - %116 = tt.dot %114, %115, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> - %117 = tt.addptr %40, %90 : tensor<128x!tt.ptr, #blocked0>, tensor<128xi32, #blocked0> - %118 = tt.load %117 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0> - %119 = triton_gpu.convert_layout %118 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> - %120 = tt.expand_dims %119 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0> - %121 = tt.broadcast %120 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0> - %122 = arith.subf %cst_1, %121 : tensor<128x128xf32, #mma0> - %123 = triton_gpu.convert_layout %110 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> - %124 = triton_gpu.convert_layout %73 : (tensor<64x128xf16, #shared2>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> - %125 = tt.dot 
%123, %124, %122 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0> - %126 = arith.mulf %108, %125 : tensor<128x128xf32, #mma0> - %127 = arith.mulf %126, %39 : tensor<128x128xf32, #mma0> - %128 = arith.truncf %127 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0> - %129 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared1> - %130 = tt.trans %129 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #shared2> - %131 = triton_gpu.convert_layout %92 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> - %132 = triton_gpu.convert_layout %130 : (tensor<128x128xf16, #shared2>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> - %133 = tt.dot %132, %131, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> - %134 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2> - %135 = triton_gpu.convert_layout %134 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma1> - %136 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> - %137 = triton_gpu.convert_layout %60 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> - %138 = tt.dot %136, %137, %135 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> - %139 = triton_gpu.convert_layout %138 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2> - tt.store %arg29, %139 : tensor<128x64xf32, #blocked2> - %140 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64xi32, #blocked2> - %141 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %142 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - scf.yield %116, %133, %140, %141, %142 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64x!tt.ptr, #blocked1> + %61 = arith.muli %53, %19 : tensor<128x1xi32, #blocked1> + %62 = tt.broadcast %61 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1> + %63 = arith.addi %62, %24 : tensor<128x64xi32, #blocked1> + %64 = tt.addptr %30, %63 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %65 = tt.load %64 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> + %66 = triton_gpu.convert_layout %65 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0> + %67 = arith.index_cast %47 : i32 to index + %68 = arith.addi %50, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>> + %69 = tt.expand_dims %68 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>) -> tensor<1x128xi32, #mma0> + %70 = tt.broadcast %69 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0> + %71 = arith.muli %54, %20 : tensor<128x1xi32, #blocked2> + %72 = tt.broadcast %71 : (tensor<128x1xi32, #blocked2>) -> tensor<128x64xi32, #blocked2> + %73 = arith.addi %72, %26 : 
tensor<128x64xi32, #blocked2> + %74 = tt.addptr %32, %73 : tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64xi32, #blocked2> + %75 = tt.addptr %27, %63 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %76 = tt.addptr %31, %63 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %77:5 = scf.for %arg26 = %67 to %37 step %c128 iter_args(%arg27 = %cst_1, %arg28 = %cst_1, %arg29 = %74, %arg30 = %75, %arg31 = %76) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64x!tt.ptr, #blocked1>) { + %84 = arith.index_cast %arg26 : index to i32 + %85 = tt.splat %84 : (i32) -> tensor<128xi32, #blocked0> + %86 = tt.splat %84 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> + %87 = arith.addi %85, %14 : tensor<128xi32, #blocked0> + %88 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> + %89 = triton_gpu.convert_layout %88 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0> + %90 = tt.trans %60 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1> + %91 = triton_gpu.convert_layout %89 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> + %92 = triton_gpu.convert_layout %90 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> + %93 = tt.dot %91, %92, %cst {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0> + %94 = arith.addi %86, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> + %95 = tt.expand_dims %94 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0> + %96 = tt.broadcast %95 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0> + %97 = "triton_gpu.cmpi"(%96, %70) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0> + %98 = "triton_gpu.select"(%97, %93, %cst_0) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0> + %99 = tt.addptr %38, %87 : tensor<128x!tt.ptr, #blocked0>, tensor<128xi32, #blocked0> + %100 = tt.load %99 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0> + %101 = triton_gpu.convert_layout %100 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> + %102 = arith.mulf %98, %39 : tensor<128x128xf32, #mma0> + %103 = tt.expand_dims %101 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0> + %104 = tt.broadcast %103 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0> + %105 = arith.subf %102, %104 : tensor<128x128xf32, #mma0> + %106 = math.exp %105 : tensor<128x128xf32, #mma0> + %107 = tt.load %arg31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> + %108 = triton_gpu.convert_layout %107 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0> + %109 = arith.truncf %106 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0> + %110 = triton_gpu.convert_layout %109 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0> + %111 = tt.trans %110 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1> + %112 = triton_gpu.convert_layout 
%108 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> + %113 = triton_gpu.convert_layout %111 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> + %114 = tt.dot %113, %112, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> + %115 = tt.addptr %40, %87 : tensor<128x!tt.ptr, #blocked0>, tensor<128xi32, #blocked0> + %116 = tt.load %115 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0> + %117 = triton_gpu.convert_layout %116 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> + %118 = tt.expand_dims %117 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0> + %119 = tt.broadcast %118 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0> + %120 = arith.subf %cst, %119 : tensor<128x128xf32, #mma0> + %121 = tt.trans %66 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1> + %122 = triton_gpu.convert_layout %121 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> + %123 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> + %124 = tt.dot %123, %122, %120 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0> + %125 = arith.mulf %106, %124 : tensor<128x128xf32, #mma0> + %126 = arith.mulf %125, %39 : tensor<128x128xf32, #mma0> + %127 = arith.truncf %126 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0> + %128 = triton_gpu.convert_layout %127 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0> + %129 = tt.trans %128 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1> + %130 = triton_gpu.convert_layout %89 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> + %131 = triton_gpu.convert_layout %129 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> + %132 = tt.dot %131, %130, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> + %133 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2> + %134 = triton_gpu.convert_layout %133 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma1> + %135 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> + %136 = triton_gpu.convert_layout %60 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> + %137 = tt.dot %135, %136, %134 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> + %138 = triton_gpu.convert_layout %137 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2> + tt.store %arg29, %138 : tensor<128x64xf32, #blocked2> + %139 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr, 
#blocked2>, tensor<128x64xi32, #blocked2> + %140 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %141 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + scf.yield %114, %132, %139, %140, %141 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64x!tt.ptr, #blocked1> } - %81 = arith.truncf %80#0 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1> - %82 = tt.addptr %44, %64 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %78 = arith.truncf %77#0 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1> + %79 = tt.addptr %44, %63 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %80 = triton_gpu.convert_layout %78 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1> + tt.store %79, %80 : tensor<128x64xf16, #blocked1> + %81 = arith.truncf %77#1 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1> + %82 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %83 = triton_gpu.convert_layout %81 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1> tt.store %82, %83 : tensor<128x64xf16, #blocked1> - %84 = arith.truncf %80#1 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1> - %85 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %86 = triton_gpu.convert_layout %84 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1> - tt.store %85, %86 : tensor<128x64xf16, #blocked1> } return } diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 557df65a6..14f87d7a0 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -909,6 +909,7 @@ def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability): pm.add_tritongpu_sink_conversions_from_shared_pass() pm.add_tritongpu_decompose_conversions_to_dot_operand_pass() pm.add_cse_pass() + pm.add_symbol_dce_pass() pm.run(mod) return mod diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index cbf7187f6..5a9a1d72f 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -191,7 +191,7 @@ def _bwd_kernel( tl.store(dv_ptrs, dv) tl.store(dk_ptrs, dk) -# _bwd_kernel = triton.compile("./being-optimized.ttgir", num_warps=8) +_bwd_kernel = triton.compile("./being-optimized.ttgir", num_warps=8) # _bwd_kernel = triton.compile("./unoptimized.ttgir", num_warps=8) # _bwd_kernel = triton.compile("./bwd.ttgir", num_warps=8) # _fwd_kernel = triton.compile("./fails.ptx", num_warps=4, shared=18432) @@ -260,36 +260,36 @@ class _attention(torch.autograd.Function): BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL, ) - # _bwd_kernel[(ctx.grid[1],1,1)]( - # q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale, - # o.data_ptr(), do_scaled.data_ptr(), - # dq.data_ptr(), dk.data_ptr(), dv.data_ptr(), - # l.data_ptr(), m.data_ptr(), - # delta.data_ptr(), - # q.stride(0), q.stride(1), q.stride(2), - # k.stride(0), k.stride(1), k.stride(2), - # v.stride(0), v.stride(1), v.stride(2), - # q.shape[0], q.shape[1], q.shape[2], - # ctx.grid[0] - # ) - - pgm = _bwd_kernel[(ctx.grid[1],)]( - q, k, v, ctx.sm_scale, - o, do_scaled, - dq, dk, dv, - l, m, - delta, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), + _bwd_kernel[(ctx.grid[1],1,1)]( + 
q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale, + o.data_ptr(), do_scaled.data_ptr(), + dq.data_ptr(), dk.data_ptr(), dv.data_ptr(), + l.data_ptr(), m.data_ptr(), + delta.data_ptr(), + q.stride(0), q.stride(1), q.stride(2), + k.stride(0), k.stride(1), k.stride(2), + v.stride(0), v.stride(1), v.stride(2), q.shape[0], q.shape[1], q.shape[2], - ctx.grid[0], - BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK, - BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8, - num_stages=1, + ctx.grid[0] ) - print(pgm.asm["ttgir"]) - exit() + + # pgm = _bwd_kernel[(ctx.grid[1],)]( + # q, k, v, ctx.sm_scale, + # o, do_scaled, + # dq, dk, dv, + # l, m, + # delta, + # q.stride(0), q.stride(1), q.stride(2), q.stride(3), + # k.stride(0), k.stride(1), k.stride(2), k.stride(3), + # v.stride(0), v.stride(1), v.stride(2), v.stride(3), + # q.shape[0], q.shape[1], q.shape[2], + # ctx.grid[0], + # BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK, + # BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8, + # num_stages=1, + # ) + # print(pgm.asm["ttgir"]) + # exit() return dq, dk, dv, None
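
Note on the Combine.cpp change above: the new ConvertTransConvert pattern rewrites a convert_layout -> tt.trans -> convert_layout chain so that the shared-memory layout of the transpose input is derived from the dot-operand encoding that ultimately consumes it. Below is a minimal before/after sketch of the IR shape it targets; the value names and the #blockedA/#sharedA/#sharedB/#sharedC/#sharedD aliases are illustrative placeholders (not compiler output), and the exact swizzling parameters are assumptions.

// Placeholder layouts for this sketch:
//   #blockedA : a distributed (blocked) layout
//   #sharedA  : a generic shared layout (e.g. vec = 1, maxPhase = 1, order = [1, 0])
//   #sharedB  : #sharedA transposed (order = [0, 1])
//   #sharedC  : a shared layout derived from the dot-operand encoding via
//               SharedEncodingAttr::get(ctx, ZEncoding, shape, order, elemTy)
//   #sharedD  : #sharedC transposed
//   #dotOp    : #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>

// Before the rewrite: the tile is staged through a generic shared layout,
// transposed, and then converted once more to reach the dot operand.
%x = triton_gpu.convert_layout %a : (tensor<128x64xf16, #blockedA>) -> tensor<128x64xf16, #sharedA>
%y = tt.trans %x : (tensor<128x64xf16, #sharedA>) -> tensor<64x128xf16, #sharedB>
%z = triton_gpu.convert_layout %y : (tensor<64x128xf16, #sharedB>) -> tensor<64x128xf16, #dotOp>

// After the rewrite: the same chain, but the first convert_layout now writes
// shared memory in a layout chosen for the dot operand, so the final
// conversion can read the transposed tile directly.
%x2 = triton_gpu.convert_layout %a : (tensor<128x64xf16, #blockedA>) -> tensor<128x64xf16, #sharedC>
%y2 = tt.trans %x2 : (tensor<128x64xf16, #sharedC>) -> tensor<64x128xf16, #sharedD>
%z2 = triton_gpu.convert_layout %y2 : (tensor<64x128xf16, #sharedD>) -> tensor<64x128xf16, #dotOp>

This is also why DecomposeConversionsToDotOperand now returns early when the source encoding is already shared: the shared-to-dot-operand conversion produced by the pattern above no longer needs an extra round trip through another shared buffer.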