diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index ab9fabe2d..1a601bdd0 100644
--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -72,24 +72,24 @@ void storeDistributedToShared(Value src, Value llSrc,
   Value staIdx1 = i32_val(0);
   Value stride0 = dstStrides[outOrd[0]];
   Value stride1 = dstStrides[outOrd[1]];
-  // if (auto addOp = dyn_cast<LLVM::AddOp>(dynIdx0.getDefiningOp()))
-  //   if (auto cstRhs =
-  //           dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
-  //     unsigned rhsVal =
-  //         cstRhs.getValue().cast<IntegerAttr>().getValue().getSExtValue();
-  //     unsigned key = (rhsVal / outVec) % maxPhase;
-  //     if (cache.find(key) == cache.end())
-  //       cache[key] = dynIdx0;
-  //     dynIdx0 = cache[key];
-  //     staIdx0 =
-  //         i32_val((rhsVal) / (outVec * maxPhase) * (outVec * maxPhase));
-  //   }
-  // if (auto addOp = dyn_cast<LLVM::AddOp>(dynIdx1.getDefiningOp()))
-  //   if (auto cstRhs =
-  //           dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
-  //     dynIdx1 = addOp.getLhs();
-  //     staIdx1 = addOp.getRhs();
-  //   }
+  if (auto addOp = dyn_cast<LLVM::AddOp>(dynIdx0.getDefiningOp()))
+    if (auto cstRhs =
+            dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
+      unsigned rhsVal =
+          cstRhs.getValue().cast<IntegerAttr>().getValue().getSExtValue();
+      unsigned key = (rhsVal / outVec) % maxPhase;
+      if (cache.find(key) == cache.end())
+        cache[key] = dynIdx0;
+      dynIdx0 = cache[key];
+      staIdx0 =
+          i32_val((rhsVal) / (outVec * maxPhase) * (outVec * maxPhase));
+    }
+  if (auto addOp = dyn_cast<LLVM::AddOp>(dynIdx1.getDefiningOp()))
+    if (auto cstRhs =
+            dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
+      dynIdx1 = addOp.getLhs();
+      staIdx1 = addOp.getRhs();
+    }
   // offset along non-contiguous dimension
   Value off1 = mul(dynIdx1, stride1);
diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
index 5d165c09e..3a15808b7 100644
--- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
@@ -1266,14 +1266,16 @@ public:
       : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                              1, context) {}
 
-  LogicalResult matchAndRewrite(mlir::Operation* op,
-                                mlir::PatternRewriter &rewriter) const override {
+  LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
     auto dstOp = cast<triton::gpu::ConvertLayoutOp>(op);
     auto tmpOp = dyn_cast_or_null<triton::TransOp>(dstOp.src().getDefiningOp());
-    if(!tmpOp)
+    if (!tmpOp)
       return mlir::failure();
-    auto srcOp = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(tmpOp.src().getDefiningOp());
-    if(!srcOp)
+    auto srcOp = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(
+        tmpOp.src().getDefiningOp());
+    if (!srcOp)
       return mlir::failure();
     auto arg = srcOp.src();
     auto X = tmpOp.src();
@@ -1285,25 +1287,74 @@ public:
     auto ZType = dstOp.getResult().getType().cast<RankedTensorType>();
     // encodings
     auto argEncoding = argType.getEncoding();
-    auto XEncoding = XType.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
-    auto YEncoding = YType.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
-    auto ZEncoding = ZType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
-    if(!ZEncoding)
+    auto XEncoding =
+        XType.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
+    auto YEncoding =
+        YType.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
+    auto ZEncoding =
+        ZType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
+    if (!ZEncoding)
       return mlir::failure();
     // new X encoding
     auto newXOrder = triton::gpu::getOrder(argEncoding);
     auto newXEncoding = triton::gpu::SharedEncodingAttr::get(
         getContext(), ZEncoding, XType.getShape(), newXOrder,
         XType.getElementType());
-    auto newXType = RankedTensorType::get(XType.getShape(), XType.getElementType(),
-                                          newXEncoding);
-    if(XEncoding == newXEncoding)
+    auto newXType = RankedTensorType::get(XType.getShape(),
+                                          XType.getElementType(), newXEncoding);
+    if (XEncoding == newXEncoding)
       return mlir::failure();
-
-    auto newX = rewriter.create<triton::gpu::ConvertLayoutOp>(srcOp.getLoc(), newXType, arg);
+    auto newX = rewriter.create<triton::gpu::ConvertLayoutOp>(srcOp.getLoc(),
+                                                              newXType, arg);
     auto newY = rewriter.create<triton::TransOp>(tmpOp.getLoc(), newX);
-    rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(dstOp, ZType, newY);
+    rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(dstOp, ZType,
+                                                              newY);
+    return mlir::success();
+  }
+};
+
+// cvt(dot(a, b, cvt(load(ptr)))) -> add(cvt(dot(a, b, 0)), load(ptr))
+class ConvertDotConvert : public mlir::RewritePattern {
+public:
+  ConvertDotConvert(mlir::MLIRContext *context)
+      : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
+                             1, context) {}
+
+  LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto dstOp = cast<triton::gpu::ConvertLayoutOp>(op);
+    auto dotOp = dyn_cast_or_null<triton::DotOp>(dstOp.src().getDefiningOp());
+    if (!dotOp)
+      return mlir::failure();
+    if (std::distance(dstOp->user_begin(), dstOp->user_end()) != 1 ||
+        std::distance(dotOp->user_begin(), dotOp->user_end()) != 1)
+      return mlir::failure();
+    auto cvtOp = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(
+        dotOp.getOperand(2).getDefiningOp());
+    if (!cvtOp)
+      return mlir::failure();
+    auto loadOp = dyn_cast_or_null<triton::LoadOp>(cvtOp.src().getDefiningOp());
+    if (!loadOp)
+      return mlir::failure();
+    auto dstTy = dstOp.getResult().getType().cast<RankedTensorType>();
+    auto srcTy = cvtOp.getOperand().getType().cast<RankedTensorType>();
+    if (dstTy != srcTy)
+      return mlir::failure();
+
+    // TODO: int tensor cores
+    auto _0f = rewriter.create<arith::ConstantFloatOp>(
+        op->getLoc(), APFloat(0.0f), dstTy.getElementType().cast<FloatType>());
+    auto _0 = rewriter.create<triton::SplatOp>(
+        op->getLoc(), dotOp.getResult().getType(), _0f);
+    auto newDot = rewriter.create<triton::DotOp>(
+        op->getLoc(), dotOp.getResult().getType(), dotOp.getOperand(0),
+        dotOp.getOperand(1), _0, dotOp.allowTF32());
+    auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
+        op->getLoc(), dstTy, newDot.getResult());
+    auto newAdd = rewriter.replaceOpWithNewOp<arith::AddFOp>(
+        op, newCvt, cvtOp.getOperand());
     return mlir::success();
   }
 };
@@ -1477,6 +1528,7 @@ public:
     patterns.add<RematerializeForward>(context);
     patterns.add<BlockedToMMA>(context, computeCapability);
     patterns.add<ConvertTransConvert>(context);
+    patterns.add<ConvertDotConvert>(context);
 
     if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
       signalPassFailure();
diff --git a/python/slow.ttgir b/python/slow.ttgir
index f8ae009ca..3a26af9a6 100644
--- a/python/slow.ttgir
+++ b/python/slow.ttgir
@@ -148,7 +148,7 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
     %136 = triton_gpu.convert_layout %60 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
     %137 = tt.dot %135, %136, %134 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
     %138 = triton_gpu.convert_layout %137 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2>
-    tt.store %arg29, %138 : tensor<128x64xf32, #blocked2>
+    tt.store %arg29, %133 : tensor<128x64xf32, #blocked2>
     %139 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
     %140 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
     %141 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py
index 8892b5529..c961c9a62 100644
--- a/python/tutorials/06-fused-attention.py
+++ b/python/tutorials/06-fused-attention.py
@@ -191,6 +191,7 @@ def _bwd_kernel(
     tl.store(dv_ptrs, dv)
     tl.store(dk_ptrs, dk)
 
+
 # _bwd_kernel = triton.compile("./slow.ttgir", num_warps=8)
 # _bwd_kernel = triton.compile("./unoptimized.ttgir", num_warps=8)
 # _bwd_kernel = triton.compile("./bwd.ttgir", num_warps=8)
@@ -260,7 +261,7 @@ class _attention(torch.autograd.Function):
             BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
         )
 
-        # _bwd_kernel[(ctx.grid[1],1,1)](
+        # _bwd_kernel[(ctx.grid[1], 1, 1)](
         #     q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale,
         #     o.data_ptr(), do_scaled.data_ptr(),
        #     dq.data_ptr(), dk.data_ptr(), dv.data_ptr(),
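
Note on the ConvertLayoutOpToLLVM.cpp hunk: the uncommented block splits each constant store offset into a dynamic index that is reused across all offsets sharing a swizzle phase (keyed by (rhsVal / outVec) % maxPhase) and a static remainder that is a multiple of outVec * maxPhase, so fewer distinct index computations stay live. A minimal sketch of that arithmetic only, with illustrative outVec/maxPhase values and offsets visited in increasing order (plain Python, not the C++ IR rewrite itself, and the dynamic base address is taken as zero):

    outVec, maxPhase = 4, 8
    span = outVec * maxPhase  # swizzle phases repeat with this period

    cache = {}  # phase key -> dynamic index of its first occurrence
    for rhsVal in range(0, 4 * span, outVec):  # constant offsets, in increasing order
        key = (rhsVal // outVec) % maxPhase
        cache.setdefault(key, rhsVal)  # first occurrence defines the dynamic part
        dyn = cache[key]               # reused dynamic index (an IR Value in the C++ code)
        sta = rhsVal // span * span    # static, compile-time-known part
        assert dyn + sta == rhsVal     # the decomposition reproduces the offset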
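
Note on the new ConvertDotConvert pattern: it relies on the identity dot(a, b, c) == dot(a, b, 0) + c, so the loaded accumulator no longer has to be converted into the MMA layout before the dot; the add is applied after converting the dot result back. A quick numerical check of that identity (plain PyTorch with tile shapes borrowed from the ttgir above; not the MLIR rewrite itself):

    import torch

    a = torch.randn(128, 128)
    b = torch.randn(128, 64)
    c = torch.randn(128, 64)

    fused = torch.addmm(c, a, b)                        # dot(a, b, c): accumulator inside the dot
    split = torch.addmm(torch.zeros_like(c), a, b) + c  # dot(a, b, 0), then a separate add
    assert torch.allclose(fused, split, atol=1e-4)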