[TRITON-MLIR][BACKEND]support atomic_cas (#914)
1. support atomic_cas
2. add xchg support in atomic_rmw

Co-authored-by: dongdongl <dongdongl@nvidia.com>
@@ -59,7 +59,8 @@ def TT_AtomicRMWAttr : I32EnumAttr<
         I32EnumAttrCase<"MAX", 6, "max">,
         I32EnumAttrCase<"MIN", 7, "min">,
         I32EnumAttrCase<"UMAX", 8, "umax">,
-        I32EnumAttrCase<"UMIN", 9, "umin">
+        I32EnumAttrCase<"UMIN", 9, "umin">,
+        I32EnumAttrCase<"XCHG", 10, "exch">
     ]> {
   let cppNamespace = "::mlir::triton";
 }
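For orientation, the new case's string form "exch" matches the PTX atom.exch mnemonic: an unconditional swap that returns the previous value. A minimal CUDA sketch of those semantics (illustration only; the kernel and names are ours, not part of the change):

#include <cstdio>
#include <cuda_runtime.h>

// atomicExch is the CUDA counterpart of PTX atom.exch: store the new value
// and return whatever was there before, as a single indivisible step.
__global__ void xchg_demo(int *slot, int *prev) {
  *prev = atomicExch(slot, 42);
}

int main() {
  int *slot, *prev;
  cudaMallocManaged(&slot, sizeof(int));
  cudaMallocManaged(&prev, sizeof(int));
  *slot = 7;
  xchg_demo<<<1, 1>>>(slot, prev);
  cudaDeviceSynchronize();
  std::printf("prev=%d slot=%d\n", *prev, *slot); // prev=7 slot=42
  cudaFree(slot);
  cudaFree(prev);
  return 0;
}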
@@ -115,9 +115,9 @@ SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
 // TODO: extend beyond scalars
 SmallVector<unsigned> getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) {
   SmallVector<unsigned> smemShape;
-  auto ptrTy = op.ptr().getType();
-  if (auto tensorType = ptrTy.dyn_cast<RankedTensorType>()) {
+  if (op.ptr().getType().isa<RankedTensorType>()) {
     // do nothing or just assert because shared memory is not used in tensor up
     // to now
   } else {
     // need only bytes for scalar
     // always vec = 1 and elemsPerThread = 1 for scalar?
@@ -126,6 +126,10 @@ SmallVector<unsigned> getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) {
   return smemShape;
 }
 
+SmallVector<unsigned> getScratchConfigForAtomicCAS(triton::AtomicCASOp op) {
+  return SmallVector<unsigned>{1};
+}
+
 class AllocationAnalysis {
 public:
   AllocationAnalysis(Operation *operation, Allocation *allocation)
@@ -230,6 +234,17 @@ private:
                          : elems * elemTy.getIntOrFloatBitWidth() / 8;
         allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
       }
-    }
+    } else if (auto atomicCASOp = dyn_cast<triton::AtomicCASOp>(op)) {
+      auto value = op->getOperand(0);
+      auto smemShape = getScratchConfigForAtomicCAS(atomicCASOp);
+      unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
+                                       std::multiplies{});
+      auto elemTy =
+          value.getType().cast<triton::PointerType>().getPointeeType();
+      auto bytes = elemTy.isa<triton::PointerType>()
+                       ? elems * kPtrBitWidth / 8
+                       : elems * elemTy.getIntOrFloatBitWidth() / 8;
+      allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
+    }
   }
 
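The new branch sizes the atomic_cas scratch buffer exactly like the atomic_rmw one: element count times element width in bytes, with kPtrBitWidth (64 in this file) substituted when the pointee is itself a pointer. Since getScratchConfigForAtomicCAS returns {1}, a scalar CAS on i32 reserves 4 bytes. The same arithmetic as a standalone sketch, with names of our own choosing:

#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

// Mirror of the byte computation in the hunk above; kPtrBitWidth = 64 is an
// assumption carried over from the surrounding file.
constexpr unsigned kPtrBitWidth = 64;

unsigned scratchBytes(const std::vector<unsigned> &smemShape,
                      unsigned elemBitWidth, bool elemIsPointer) {
  unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1u,
                                   std::multiplies<unsigned>{});
  return elems * (elemIsPointer ? kPtrBitWidth : elemBitWidth) / 8;
}

int main() {
  // Scalar atomic_cas on i32: shape {1}, 32-bit element -> 4 bytes.
  std::printf("%u\n", scratchBytes({1}, 32, false));
  return 0;
}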
@@ -6036,6 +6036,82 @@ struct ExpOpConversionApprox
     return ptxBuilder.launch(rewriter, loc, f32_ty, false);
   }
 };
+
+/// ====================== atomic_cas codegen begin ==========================
+struct AtomicCASOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>,
+      public LoadStoreConversionBase {
+  using ConvertTritonGPUOpToLLVMPattern<
+      triton::AtomicCASOp>::ConvertTritonGPUOpToLLVMPattern;
+
+  AtomicCASOpConversion(LLVMTypeConverter &converter,
+                        const Allocation *allocation, Value smem,
+                        AxisInfoAnalysis &axisAnalysisPass,
+                        PatternBenefit benefit)
+      : ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>(
+            converter, allocation, smem, benefit),
+        LoadStoreConversionBase(axisAnalysisPass) {}
+
+  LogicalResult
+  matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    MLIRContext *ctx = rewriter.getContext();
+    Value ptr = op.ptr();
+
+    Value llPtr = adaptor.ptr();
+    Value llCmp = adaptor.cmp();
+    Value llVal = adaptor.val();
+
+    auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
+    auto cmpElements = getElementsFromStruct(loc, llCmp, rewriter);
+    auto valElements = getElementsFromStruct(loc, llVal, rewriter);
+
+    auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
+    Type valueElemTy =
+        valueTy ? getTypeConverter()->convertType(valueTy.getElementType())
+                : op.getResult().getType();
+    auto tid = tid_val();
+    Value pred = icmp_eq(tid, i32_val(0));
+    PTXBuilder ptxBuilderMemfence;
+    auto memfence = ptxBuilderMemfence.create<PTXInstr>("membar")->o("gl");
+    memfence();
+    auto ASMReturnTy = void_ty(ctx);
+    ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
+
+    Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
+    atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
+
+    Value casPtr = ptrElements[0];
+    Value casCmp = cmpElements[0];
+    Value casVal = valElements[0];
+
+    PTXBuilder ptxBuilderAtomicCAS;
+    auto *dstOpr = ptxBuilderAtomicCAS.newOperand("=r");
+    auto *ptrOpr = ptxBuilderAtomicCAS.newAddrOperand(casPtr, "l");
+    auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, "r");
+    auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, "r");
+    auto &atom = *ptxBuilderAtomicCAS.create<PTXInstr>("atom");
+    atom.global().o("cas").o("b32");
+    atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(pred);
+    auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
+    barrier();
+
+    PTXBuilder ptxBuilderStore;
+    auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "l");
+    auto *valOprStore = ptxBuilderStore.newOperand(old, "r");
+    auto &st = *ptxBuilderStore.create<PTXInstr>("st");
+    st.shared().o("b32");
+    st(dstOprStore, valOprStore).predicate(pred);
+    ptxBuilderStore.launch(rewriter, loc, ASMReturnTy);
+    ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
+    barrier();
+    Value ret = load(atomPtr);
+    barrier();
+    rewriter.replaceOp(op, {ret});
+    return success();
+  }
+};
+/// ====================== atomic_cas codegen end ==========================
+
 /// ====================== atomic_rmw codegen begin ==========================
 struct AtomicRMWOpConversion
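Taken together, the scalar path above does: global memory fence, a single thread (tid 0) issues the 32-bit global CAS, the old value is parked in the shared-memory scratch slot, and barriers let every thread read it back. A rough CUDA rendering of that shape, to make the control flow explicit (names are ours; this is a sketch of the emitted sequence, not the generated code):

// Approximate CUDA equivalent of the scalar atomic_cas lowering above.
// smem_slot stands in for the one-element scratch buffer reserved by the
// allocation pass; the real code predicates on tid 0 and broadcasts via it.
__device__ int scalar_atomic_cas(int *ptr, int cmp, int val, int *smem_slot) {
  __threadfence();                      // membar.gl
  if (threadIdx.x == 0) {               // pred = icmp_eq(tid, 0)
    int old = atomicCAS(ptr, cmp, val); // atom.global.cas.b32
    *smem_slot = old;                   // st.shared.b32, same predicate
  }
  __syncthreads();                      // barrier()
  int ret = *smem_slot;                 // every thread loads the old value
  __syncthreads();                      // barrier() before scratch is reused
  return ret;
}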
@@ -6105,15 +6181,15 @@ struct AtomicRMWOpConversion
       Value rmwMask = maskElements[i];
       rmwMask = and_(rmwMask, mask);
       std::string sTy;
-      PTXBuilder ptxBuilder;
+      PTXBuilder ptxBuilderAtomicRMW;
       std::string tyId = valueElemNbits * vec == 64
                              ? "l"
                              : (valueElemNbits * vec == 32 ? "r" : "h");
-      auto *dstOpr = ptxBuilder.newOperand("=" + tyId);
-      auto *ptrOpr = ptxBuilder.newAddrOperand(rmwPtr, "l");
-      auto *valOpr = ptxBuilder.newOperand(rmwVal, tyId);
+      auto *dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId);
+      auto *ptrOpr = ptxBuilderAtomicRMW.newAddrOperand(rmwPtr, "l");
+      auto *valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId);
 
-      auto &atom = ptxBuilder.create<>("atom")->global().o("gpu");
+      auto &atom = ptxBuilderAtomicRMW.create<>("atom")->global().o("gpu");
       auto rmwOp = stringifyRMWOp(atomicRmwAttr).str();
       auto sBits = std::to_string(valueElemNbits);
       switch (atomicRmwAttr) {
@@ -6149,21 +6225,29 @@ struct AtomicRMWOpConversion
       rmwOp = "min";
       sTy = "u" + sBits;
       break;
+    case RMWOp::XCHG:
+      sTy = "b" + sBits;
+      break;
     default:
       return failure();
     }
     atom.o(rmwOp).o(sTy);
     if (valueTy) {
       atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
-      auto ret = ptxBuilder.launch(rewriter, loc, valueElemTy);
+      auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy);
       for (int ii = 0; ii < vec; ++ii) {
         resultVals[i * vec + ii] =
             vec == 1 ? ret : extract_element(valueElemTy, ret, idx_val(ii));
       }
     } else {
+      PTXBuilder ptxBuilderMemfence;
+      auto memfence = ptxBuilderMemfence.create<PTXInstr>("membar")->o("gl");
+      memfence();
+      auto ASMReturnTy = void_ty(ctx);
+      ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
       rmwMask = and_(rmwMask, icmp_eq(tid, i32_val(0)));
       atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
-      auto old = ptxBuilder.launch(rewriter, loc, valueElemTy);
+      auto old = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy);
       Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
       atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
       store(old, atomPtr);
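For reference, the instruction the renamed builder assembles for a 32-bit XCHG looks roughly like the CUDA inline PTX below. The b32 suffix chosen by the new case is the bitwise type (an exchange does not interpret its payload), and "=r"/"l"/"r" are the same operand constraints the pattern installs; the .gpu scope qualifier appended by .o("gpu") is left out here, and the function name is ours:

// Sketch: hand-written counterpart of the generated atomic exchange.
__device__ int ptx_atomic_exch(int *ptr, int val) {
  int old;
  asm volatile("atom.global.exch.b32 %0, [%1], %2;"
               : "=r"(old)          // dstOpr, constraint "=r"
               : "l"(ptr), "r"(val) // ptrOpr "l", valOpr "r"
               : "memory");
  return old;
}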
@@ -6264,6 +6348,8 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
   patterns.add<ReduceOpConversion>(typeConverter, allocation, smem, benefit);
   patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
                                           benefit);
+  patterns.add<AtomicCASOpConversion>(typeConverter, allocation, smem,
+                                      axisInfoAnalysis, benefit);
   patterns.add<AtomicRMWOpConversion>(typeConverter, allocation, smem,
                                       axisInfoAnalysis, benefit);
   patterns.add<ExtractSliceOpConversion>(typeConverter, allocation, smem,
@@ -278,6 +278,20 @@ struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
   }
 };
 
+struct TritonAtomicCASPattern
+    : public OpConversionPattern<triton::AtomicCASOp> {
+  using OpConversionPattern<triton::AtomicCASOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<triton::AtomicCASOp>(
+        op, typeConverter->convertType(op.getType()),
+        adaptor.ptr(), adaptor.cmp(), adaptor.val());
+    return success();
+  }
+};
+
 struct TritonAtomicRMWPattern
     : public OpConversionPattern<triton::AtomicRMWOp> {
   using OpConversionPattern<triton::AtomicRMWOp>::OpConversionPattern;
@@ -105,7 +105,7 @@ void init_triton_ir(py::module &&m) {
       .value("AND", mlir::triton::RMWOp::AND)
       .value("OR", mlir::triton::RMWOp::OR)
       .value("XOR", mlir::triton::RMWOp::XOR)
-      // .value("XCHG", mlir::triton::RMWOp::Xchg)
+      .value("XCHG", mlir::triton::RMWOp::XCHG)
      .value("MAX", mlir::triton::RMWOp::MAX)
      .value("MIN", mlir::triton::RMWOp::MIN)
      .value("UMIN", mlir::triton::RMWOp::UMIN)
@@ -1095,9 +1095,18 @@ void init_triton_ir(py::module &&m) {
           [](mlir::OpBuilder &self, mlir::Value &ptr, mlir::Value &cmp,
              mlir::Value &val) -> mlir::Value {
             auto loc = self.getUnknownLoc();
-            auto ptrType = mlir::getElementTypeOrSelf(ptr)
-                               .cast<mlir::triton::PointerType>();
-            mlir::Type dstType = ptrType.getPointeeType();
+            mlir::Type dstType;
+            if (auto srcTensorType =
+                    ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
+              mlir::Type dstElemType = srcTensorType.getElementType()
+                                           .cast<mlir::triton::PointerType>()
+                                           .getPointeeType();
+              dstType = mlir::RankedTensorType::get(srcTensorType.getShape(),
+                                                    dstElemType);
+            } else {
+              auto ptrType = mlir::getElementTypeOrSelf(ptr)
+                                 .cast<mlir::triton::PointerType>();
+              dstType = ptrType.getPointeeType();
+            }
             return self.create<mlir::triton::AtomicCASOp>(loc, dstType, ptr,
                                                           cmp, val);
           })
@@ -700,6 +700,16 @@ def test_tensor_atomic_rmw(axis, device="cuda"):
     # serialized_add[(64,)](data, Lock)
     # triton.testing.assert_almost_equal(data, ref)
 
+
+def test_simple_atomic_cas():
+    # 1. make sure that atomic_cas changes the original value (Lock)
+    @triton.jit
+    def change_value(Lock):
+        tl.atomic_cas(Lock, 0, 1)
+
+    Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
+    change_value[(1,)](Lock)
+
+    assert (Lock[0] == 1)
+
 # # ---------------
 # # test cast
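The commented-out serialized_add lines above point at the intended use: atomic_cas acquires a lock and an exchange (or plain store) releases it. For orientation, the classic device-side lock in CUDA form (names are ours; real GPU spinlocks also have to worry about progress under warp divergence):

// CAS-based lock the test is building toward: acquire spins until the CAS
// transitions the lock from 0 to 1; release swaps 0 back in.
__device__ void acquire(int *lock) {
  while (atomicCAS(lock, 0, 1) != 0)
    ; // spin
}

__device__ void release(int *lock) {
  atomicExch(lock, 0);
}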