From a4ff0c362ce2eaa76007198b0e8af82f57dde8a8 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 6 Nov 2022 20:52:11 -0800 Subject: [PATCH] [FRONTEND] Fix issues with atomics (#849) --- include/triton/Dialect/Triton/IR/TritonOps.td | 13 ++++++++++--- .../TritonToTritonGPU/TritonToTritonGPU.cpp | 14 ++++++++++++++ lib/Dialect/TritonGPU/Transforms/Coalesce.cpp | 4 ++++ lib/Dialect/TritonGPU/Transforms/Combine.cpp | 6 ++++-- 4 files changed, 32 insertions(+), 5 deletions(-) diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index e089f82c4..f22e8b742 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -186,7 +186,14 @@ def TT_StoreOp : TT_Op<"store", // Atomic Op // def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape, - SameOperandsAndResultEncoding]> { + SameOperandsAndResultEncoding, + MemoryEffects<[MemWrite]>, + TypesMatchWith<"infer ptr type from value type", + "val", "ptr", + "getPointerTypeSameShape($_self)">, + TypesMatchWith<"infer mask type from value type", + "val", "mask", "getI1SameShape($_self)", + "($_op.getOperands().size() <= 2) || std::equal_to<>()">]> { let summary = "atomic rmw"; let description = [{ @@ -195,8 +202,8 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape, return old value at $ptr }]; - let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrTensor:$ptr, - TT_Type:$val, I1Tensor:$mask); + let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrLike:$ptr, + TT_Type:$val, Optional:$mask); let results = (outs TT_Type:$result); } diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index 706d10fed..a8e7ffeeb 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -278,6 +278,20 @@ struct TritonStorePattern : public OpConversionPattern { } }; +struct TritonAtomicRMWPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, typeConverter->convertType(op.getType()), adaptor.atomic_rmw_op(), + adaptor.ptr(), adaptor.val(), adaptor.mask()); + return success(); + } +}; + struct TritonExtElemwisePattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; diff --git a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp index a0c2d23e8..85c082c24 100644 --- a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @@ -118,6 +118,10 @@ struct CoalescePass : public TritonGPUCoalesceBase { builder.setInsertionPoint(curr); if (auto load = dyn_cast(curr)) coalesceOp(axisInfo, curr, load.ptr(), builder); + if (auto op = dyn_cast(curr)) + coalesceOp(axisInfo, curr, op.ptr(), builder); + if (auto op = dyn_cast(curr)) + coalesceOp(axisInfo, curr, op.ptr(), builder); if (auto load = dyn_cast(curr)) coalesceOp(axisInfo, curr, load.src(), builder); diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index 456ce1200..88988be3a 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -200,7 +200,7 @@ inline bool expensive_to_remat(Operation *op) { return true; if (isa(op)) + triton::AtomicRMWOp, triton::AtomicCASOp, triton::DotOp>(op)) return true; if (isa(op)) return true; @@ -478,7 +478,9 @@ public: SetVector cvtSlices; auto filter = [&](Operation *op) { - return isInLoop(op) && !isa(op) && + return isInLoop(op) && + !isa(op) && !isa(op) && !isa(op) && !isa(op); };