[FRONTEND] Fix issues with atomics (#849)

This commit is contained in:
Philippe Tillet
2022-11-06 20:52:11 -08:00
committed by GitHub
parent b6dbe959f0
commit a4ff0c362c
4 changed files with 32 additions and 5 deletions

View File

@@ -278,6 +278,20 @@ struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
}
};
struct TritonAtomicRMWPattern
: public OpConversionPattern<triton::AtomicRMWOp> {
using OpConversionPattern<triton::AtomicRMWOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<triton::AtomicRMWOp>(
op, typeConverter->convertType(op.getType()), adaptor.atomic_rmw_op(),
adaptor.ptr(), adaptor.val(), adaptor.mask());
return success();
}
};
struct TritonExtElemwisePattern
: public OpConversionPattern<triton::ExtElemwiseOp> {
using OpConversionPattern<triton::ExtElemwiseOp>::OpConversionPattern;

View File

@@ -118,6 +118,10 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
builder.setInsertionPoint(curr);
if (auto load = dyn_cast<triton::LoadOp>(curr))
coalesceOp<triton::LoadOp>(axisInfo, curr, load.ptr(), builder);
if (auto op = dyn_cast<triton::AtomicRMWOp>(curr))
coalesceOp<triton::AtomicRMWOp>(axisInfo, curr, op.ptr(), builder);
if (auto op = dyn_cast<triton::AtomicCASOp>(curr))
coalesceOp<triton::AtomicCASOp>(axisInfo, curr, op.ptr(), builder);
if (auto load = dyn_cast<triton::gpu::InsertSliceAsyncOp>(curr))
coalesceOp<triton::gpu::InsertSliceAsyncOp>(axisInfo, curr, load.src(),
builder);

View File

@@ -200,7 +200,7 @@ inline bool expensive_to_remat(Operation *op) {
return true;
if (isa<triton::gpu::ExtractSliceOp, triton::gpu::AllocTensorOp,
triton::gpu::InsertSliceAsyncOp, triton::LoadOp, triton::StoreOp,
triton::DotOp>(op))
triton::AtomicRMWOp, triton::AtomicCASOp, triton::DotOp>(op))
return true;
if (isa<scf::YieldOp, scf::ForOp>(op))
return true;
@@ -478,7 +478,9 @@ public:
SetVector<Operation *> cvtSlices;
auto filter = [&](Operation *op) {
return isInLoop(op) && !isa<triton::LoadOp>(op) &&
return isInLoop(op) &&
!isa<triton::LoadOp, triton::StoreOp, triton::AtomicRMWOp,
triton::AtomicCASOp>(op) &&
!isa<triton::DotOp>(op) && !isa<scf::YieldOp>(op) &&
!isa<triton::gpu::ConvertLayoutOp>(op);
};