[FRONTEND] Fix issues with atomics (#849)
This commit is contained in:
@@ -186,7 +186,14 @@ def TT_StoreOp : TT_Op<"store",
|
|||||||
// Atomic Op
|
// Atomic Op
|
||||||
//
|
//
|
||||||
def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape,
|
def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape,
|
||||||
SameOperandsAndResultEncoding]> {
|
SameOperandsAndResultEncoding,
|
||||||
|
MemoryEffects<[MemWrite]>,
|
||||||
|
TypesMatchWith<"infer ptr type from value type",
|
||||||
|
"val", "ptr",
|
||||||
|
"getPointerTypeSameShape($_self)">,
|
||||||
|
TypesMatchWith<"infer mask type from value type",
|
||||||
|
"val", "mask", "getI1SameShape($_self)",
|
||||||
|
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
|
||||||
let summary = "atomic rmw";
|
let summary = "atomic rmw";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
@@ -195,8 +202,8 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape,
|
|||||||
return old value at $ptr
|
return old value at $ptr
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrTensor:$ptr,
|
let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrLike:$ptr,
|
||||||
TT_Type:$val, I1Tensor:$mask);
|
TT_Type:$val, Optional<TT_BoolLike>:$mask);
|
||||||
|
|
||||||
let results = (outs TT_Type:$result);
|
let results = (outs TT_Type:$result);
|
||||||
}
|
}
|
||||||
|
@@ -278,6 +278,20 @@ struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct TritonAtomicRMWPattern
|
||||||
|
: public OpConversionPattern<triton::AtomicRMWOp> {
|
||||||
|
using OpConversionPattern<triton::AtomicRMWOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
rewriter.replaceOpWithNewOp<triton::AtomicRMWOp>(
|
||||||
|
op, typeConverter->convertType(op.getType()), adaptor.atomic_rmw_op(),
|
||||||
|
adaptor.ptr(), adaptor.val(), adaptor.mask());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct TritonExtElemwisePattern
|
struct TritonExtElemwisePattern
|
||||||
: public OpConversionPattern<triton::ExtElemwiseOp> {
|
: public OpConversionPattern<triton::ExtElemwiseOp> {
|
||||||
using OpConversionPattern<triton::ExtElemwiseOp>::OpConversionPattern;
|
using OpConversionPattern<triton::ExtElemwiseOp>::OpConversionPattern;
|
||||||
|
@@ -118,6 +118,10 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
|||||||
builder.setInsertionPoint(curr);
|
builder.setInsertionPoint(curr);
|
||||||
if (auto load = dyn_cast<triton::LoadOp>(curr))
|
if (auto load = dyn_cast<triton::LoadOp>(curr))
|
||||||
coalesceOp<triton::LoadOp>(axisInfo, curr, load.ptr(), builder);
|
coalesceOp<triton::LoadOp>(axisInfo, curr, load.ptr(), builder);
|
||||||
|
if (auto op = dyn_cast<triton::AtomicRMWOp>(curr))
|
||||||
|
coalesceOp<triton::AtomicRMWOp>(axisInfo, curr, op.ptr(), builder);
|
||||||
|
if (auto op = dyn_cast<triton::AtomicCASOp>(curr))
|
||||||
|
coalesceOp<triton::AtomicCASOp>(axisInfo, curr, op.ptr(), builder);
|
||||||
if (auto load = dyn_cast<triton::gpu::InsertSliceAsyncOp>(curr))
|
if (auto load = dyn_cast<triton::gpu::InsertSliceAsyncOp>(curr))
|
||||||
coalesceOp<triton::gpu::InsertSliceAsyncOp>(axisInfo, curr, load.src(),
|
coalesceOp<triton::gpu::InsertSliceAsyncOp>(axisInfo, curr, load.src(),
|
||||||
builder);
|
builder);
|
||||||
|
@@ -200,7 +200,7 @@ inline bool expensive_to_remat(Operation *op) {
|
|||||||
return true;
|
return true;
|
||||||
if (isa<triton::gpu::ExtractSliceOp, triton::gpu::AllocTensorOp,
|
if (isa<triton::gpu::ExtractSliceOp, triton::gpu::AllocTensorOp,
|
||||||
triton::gpu::InsertSliceAsyncOp, triton::LoadOp, triton::StoreOp,
|
triton::gpu::InsertSliceAsyncOp, triton::LoadOp, triton::StoreOp,
|
||||||
triton::DotOp>(op))
|
triton::AtomicRMWOp, triton::AtomicCASOp, triton::DotOp>(op))
|
||||||
return true;
|
return true;
|
||||||
if (isa<scf::YieldOp, scf::ForOp>(op))
|
if (isa<scf::YieldOp, scf::ForOp>(op))
|
||||||
return true;
|
return true;
|
||||||
@@ -478,7 +478,9 @@ public:
|
|||||||
|
|
||||||
SetVector<Operation *> cvtSlices;
|
SetVector<Operation *> cvtSlices;
|
||||||
auto filter = [&](Operation *op) {
|
auto filter = [&](Operation *op) {
|
||||||
return isInLoop(op) && !isa<triton::LoadOp>(op) &&
|
return isInLoop(op) &&
|
||||||
|
!isa<triton::LoadOp, triton::StoreOp, triton::AtomicRMWOp,
|
||||||
|
triton::AtomicCASOp>(op) &&
|
||||||
!isa<triton::DotOp>(op) && !isa<scf::YieldOp>(op) &&
|
!isa<triton::DotOp>(op) && !isa<scf::YieldOp>(op) &&
|
||||||
!isa<triton::gpu::ConvertLayoutOp>(op);
|
!isa<triton::gpu::ConvertLayoutOp>(op);
|
||||||
};
|
};
|
||||||
|
Reference in New Issue
Block a user