[TRITON-MLIR][BACKEND] Support atomic_cas (#914)

1. Support atomic_cas.
2. Add xchg support in atomic_rmw.

Author: donproc
Co-authored-by: dongdongl <dongdongl@nvidia.com>
Date: 2022-11-25 12:02:08 +08:00 (committed by GitHub)
parent 153aecb339
commit f63be0e9b5
6 changed files with 150 additions and 15 deletions
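At the Python level these two changes surface as tl.atomic_cas and
tl.atomic_xchg. A minimal usage sketch (hypothetical kernels written against
the public Triton API; only the test at the end of this commit is part of
the change itself):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def acquire(Lock):
        # Store 1 only if Lock currently holds 0; atomic_cas returns the
        # old value either way.
        tl.atomic_cas(Lock, 0, 1)

    @triton.jit
    def release(Lock):
        # Unconditional swap; lowers to the new XCHG case of atomic_rmw.
        tl.atomic_xchg(Lock, 0)

    Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
    acquire[(1,)](Lock)   # Lock becomes 1
    release[(1,)](Lock)   # Lock back to 0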


@@ -59,7 +59,8 @@ def TT_AtomicRMWAttr : I32EnumAttr<
     I32EnumAttrCase<"MAX", 6, "max">,
     I32EnumAttrCase<"MIN", 7, "min">,
     I32EnumAttrCase<"UMAX", 8, "umax">,
-    I32EnumAttrCase<"UMIN", 9, "umin">
+    I32EnumAttrCase<"UMIN", 9, "umin">,
+    I32EnumAttrCase<"XCHG", 10, "exch">
 ]> {
   let cppNamespace = "::mlir::triton";
 }


@@ -115,9 +115,9 @@ SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
 // TODO: extend beyond scalars
 SmallVector<unsigned> getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) {
   SmallVector<unsigned> smemShape;
-  auto ptrTy = op.ptr().getType();
-  if (auto tensorType = ptrTy.dyn_cast<RankedTensorType>()) {
-    // do nothing or just assert because shared memory is not used in tensor
+  if (op.ptr().getType().isa<RankedTensorType>()) {
+    // do nothing or just assert because shared memory is not used in tensor up
+    // to now
   } else {
     // need only bytes for scalar
     // always vec = 1 and elemsPerThread = 1 for scalar?

@@ -126,6 +126,10 @@ SmallVector<unsigned> getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) {
   return smemShape;
 }
+
+SmallVector<unsigned> getScratchConfigForAtomicCAS(triton::AtomicCASOp op) {
+  return SmallVector<unsigned>{1};
+}
 class AllocationAnalysis {
 public:
   AllocationAnalysis(Operation *operation, Allocation *allocation)

@@ -230,6 +234,17 @@ private:
                          : elems * elemTy.getIntOrFloatBitWidth() / 8;
         allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
       }
+    } else if (auto atomicCASOp = dyn_cast<triton::AtomicCASOp>(op)) {
+      auto value = op->getOperand(0);
+      auto smemShape = getScratchConfigForAtomicCAS(atomicCASOp);
+      unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
+                                       std::multiplies{});
+      auto elemTy =
+          value.getType().cast<triton::PointerType>().getPointeeType();
+      auto bytes = elemTy.isa<triton::PointerType>()
+                       ? elems * kPtrBitWidth / 8
+                       : elems * elemTy.getIntOrFloatBitWidth() / 8;
+      allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
     }
   }
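Since atomic_cas is lowered only for scalars for now, its scratch buffer is a
single element whose byte size follows the same formula as the scalar
atomic_rmw path. A back-of-the-envelope model in Python (assumption:
kPtrBitWidth is 64, i.e. 8-byte global pointers):

    KPTR_BIT_WIDTH = 64  # assumption: pointers are 8 bytes

    def scratch_bytes(smem_shape, elem_bit_width, elem_is_pointer=False):
        elems = 1
        for dim in smem_shape:
            elems *= dim
        bits = KPTR_BIT_WIDTH if elem_is_pointer else elem_bit_width
        return elems * bits // 8

    # getScratchConfigForAtomicCAS always returns shape {1}:
    assert scratch_bytes([1], 32) == 4                        # one i32 slot
    assert scratch_bytes([1], 0, elem_is_pointer=True) == 8   # one pointer slot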


@@ -6036,6 +6036,82 @@ struct ExpOpConversionApprox
     return ptxBuilder.launch(rewriter, loc, f32_ty, false);
   }
 };
+
+/// ====================== atomic_cas codegen begin ==========================
+struct AtomicCASOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>,
+      public LoadStoreConversionBase {
+  using ConvertTritonGPUOpToLLVMPattern<
+      triton::AtomicCASOp>::ConvertTritonGPUOpToLLVMPattern;
+
+  AtomicCASOpConversion(LLVMTypeConverter &converter,
+                        const Allocation *allocation, Value smem,
+                        AxisInfoAnalysis &axisAnalysisPass,
+                        PatternBenefit benefit)
+      : ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>(
+            converter, allocation, smem, benefit),
+        LoadStoreConversionBase(axisAnalysisPass) {}
+
+  LogicalResult
+  matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    MLIRContext *ctx = rewriter.getContext();
+    Value ptr = op.ptr();
+
+    Value llPtr = adaptor.ptr();
+    Value llCmp = adaptor.cmp();
+    Value llVal = adaptor.val();
+
+    auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
+    auto cmpElements = getElementsFromStruct(loc, llCmp, rewriter);
+    auto valElements = getElementsFromStruct(loc, llVal, rewriter);
+
+    auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
+    Type valueElemTy =
+        valueTy ? getTypeConverter()->convertType(valueTy.getElementType())
+                : op.getResult().getType();
+    auto tid = tid_val();
+    Value pred = icmp_eq(tid, i32_val(0));
+    PTXBuilder ptxBuilderMemfence;
+    auto memfenc = ptxBuilderMemfence.create<PTXInstr>("membar")->o("gl");
+    memfenc();
+    auto ASMReturnTy = void_ty(ctx);
+    ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
+
+    Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
+    atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
+
+    Value casPtr = ptrElements[0];
+    Value casCmp = cmpElements[0];
+    Value casVal = valElements[0];
+
+    PTXBuilder ptxBuilderAtomicCAS;
+    auto *dstOpr = ptxBuilderAtomicCAS.newOperand("=r");
+    auto *ptrOpr = ptxBuilderAtomicCAS.newAddrOperand(casPtr, "l");
+    auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, "r");
+    auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, "r");
+    auto &atom = *ptxBuilderAtomicCAS.create<PTXInstr>("atom");
+    atom.global().o("cas").o("b32");
+    atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(pred);
+    auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
+    barrier();
+
+    PTXBuilder ptxBuilderStore;
+    auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "l");
+    auto *valOprStore = ptxBuilderStore.newOperand(old, "r");
+    auto &st = *ptxBuilderStore.create<PTXInstr>("st");
+    st.shared().o("b32");
+    st(dstOprStore, valOprStore).predicate(pred);
+    ptxBuilderStore.launch(rewriter, loc, ASMReturnTy);
+    ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
+    barrier();
+    Value ret = load(atomPtr);
+    barrier();
+    rewriter.replaceOp(op, {ret});
+    return success();
+  }
+};
+/// ====================== atomic_cas codegen end ==========================
+
 /// ====================== atomic_rmw codegen begin ==========================
 struct AtomicRMWOpConversion
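The conversion above handles the scalar case only: it emits a membar.gl
fence, predicates atom.global.cas.b32 so that only thread 0 of the block
issues the CAS, publishes the returned old value through a one-element
shared-memory slot (st.shared.b32, also predicated on thread 0), and has
every thread read it back between barriers. The same protocol in
Python-flavoured pseudocode (illustrative model only; none of these names
are Triton APIs):

    def scalar_atomic_cas(tid, gmem, addr, cmp, val, smem, barrier):
        # membar.gl, then thread 0 performs the compare-and-swap
        if tid == 0:
            old = gmem[addr]
            if old == cmp:
                gmem[addr] = val
            smem[0] = old      # st.shared.b32: publish the old value
        barrier()              # wait for the store
        result = smem[0]       # every thread loads the same result
        barrier()
        return result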
@@ -6105,15 +6181,15 @@ struct AtomicRMWOpConversion
       Value rmwMask = maskElements[i];
       rmwMask = and_(rmwMask, mask);
       std::string sTy;
-      PTXBuilder ptxBuilder;
+      PTXBuilder ptxBuilderAtomicRMW;
       std::string tyId = valueElemNbits * vec == 64
                              ? "l"
                              : (valueElemNbits * vec == 32 ? "r" : "h");
-      auto *dstOpr = ptxBuilder.newOperand("=" + tyId);
-      auto *ptrOpr = ptxBuilder.newAddrOperand(rmwPtr, "l");
-      auto *valOpr = ptxBuilder.newOperand(rmwVal, tyId);
-      auto &atom = ptxBuilder.create<>("atom")->global().o("gpu");
+      auto *dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId);
+      auto *ptrOpr = ptxBuilderAtomicRMW.newAddrOperand(rmwPtr, "l");
+      auto *valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId);
+      auto &atom = ptxBuilderAtomicRMW.create<>("atom")->global().o("gpu");
       auto rmwOp = stringifyRMWOp(atomicRmwAttr).str();
       auto sBits = std::to_string(valueElemNbits);
       switch (atomicRmwAttr) {
@@ -6149,21 +6225,29 @@ struct AtomicRMWOpConversion
         rmwOp = "min";
         sTy = "u" + sBits;
         break;
+      case RMWOp::XCHG:
+        sTy = "b" + sBits;
+        break;
       default:
         return failure();
       }
       atom.o(rmwOp).o(sTy);
       if (valueTy) {
         atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
-        auto ret = ptxBuilder.launch(rewriter, loc, valueElemTy);
+        auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy);
         for (int ii = 0; ii < vec; ++ii) {
           resultVals[i * vec + ii] =
               vec == 1 ? ret : extract_element(valueElemTy, ret, idx_val(ii));
         }
       } else {
+        PTXBuilder ptxBuilderMemfence;
+        auto memfenc = ptxBuilderMemfence.create<PTXInstr>("membar")->o("gl");
+        memfenc();
+        auto ASMReturnTy = void_ty(ctx);
+        ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
         rmwMask = and_(rmwMask, icmp_eq(tid, i32_val(0)));
         atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
-        auto old = ptxBuilder.launch(rewriter, loc, valueElemTy);
+        auto old = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy);
         Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
         atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
         store(old, atomPtr);
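For XCHG the PTX opcode comes straight from the enum mnemonic ("exch", via
stringifyRMWOp), so the new case only has to pick the untyped operand width
bN; the emitted instruction is atom.global.gpu.exch.bN. Its semantics,
modelled in Python (illustrative only, not a Triton API):

    def atomic_xchg(mem, addr, val):
        # Unconditional swap: store val, return whatever was there before.
        old = mem[addr]
        mem[addr] = val
        return old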
@@ -6264,6 +6348,8 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
   patterns.add<ReduceOpConversion>(typeConverter, allocation, smem, benefit);
   patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
                                           benefit);
+  patterns.add<AtomicCASOpConversion>(typeConverter, allocation, smem,
+                                      axisInfoAnalysis, benefit);
   patterns.add<AtomicRMWOpConversion>(typeConverter, allocation, smem,
                                       axisInfoAnalysis, benefit);
   patterns.add<ExtractSliceOpConversion>(typeConverter, allocation, smem,


@@ -278,6 +278,20 @@ struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
   }
 };
+
+struct TritonAtomicCASPattern
+    : public OpConversionPattern<triton::AtomicCASOp> {
+  using OpConversionPattern<triton::AtomicCASOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<triton::AtomicCASOp>(
+        op, typeConverter->convertType(op.getType()),
+        adaptor.ptr(), adaptor.cmp(), adaptor.val());
+    return success();
+  }
+};
+
 struct TritonAtomicRMWPattern
     : public OpConversionPattern<triton::AtomicRMWOp> {
   using OpConversionPattern<triton::AtomicRMWOp>::OpConversionPattern;


@@ -105,7 +105,7 @@ void init_triton_ir(py::module &&m) {
       .value("AND", mlir::triton::RMWOp::AND)
       .value("OR", mlir::triton::RMWOp::OR)
       .value("XOR", mlir::triton::RMWOp::XOR)
-      // .value("XCHG", mlir::triton::RMWOp::Xchg)
+      .value("XCHG", mlir::triton::RMWOp::XCHG)
       .value("MAX", mlir::triton::RMWOp::MAX)
       .value("MIN", mlir::triton::RMWOp::MIN)
       .value("UMIN", mlir::triton::RMWOp::UMIN)
@@ -1095,9 +1095,18 @@ void init_triton_ir(py::module &&m) {
           [](mlir::OpBuilder &self, mlir::Value &ptr, mlir::Value &cmp,
              mlir::Value &val) -> mlir::Value {
             auto loc = self.getUnknownLoc();
-            auto ptrType = mlir::getElementTypeOrSelf(ptr)
-                               .cast<mlir::triton::PointerType>();
-            mlir::Type dstType = ptrType.getPointeeType();
+            mlir::Type dstType;
+            if (auto srcTensorType = ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
+              mlir::Type dstElemType = srcTensorType.getElementType()
+                                           .cast<mlir::triton::PointerType>()
+                                           .getPointeeType();
+              dstType = mlir::RankedTensorType::get(srcTensorType.getShape(),
+                                                    dstElemType);
+            } else {
+              auto ptrType = mlir::getElementTypeOrSelf(ptr)
+                                 .cast<mlir::triton::PointerType>();
+              dstType = ptrType.getPointeeType();
+            }
             return self.create<mlir::triton::AtomicCASOp>(loc, dstType, ptr,
                                                           cmp, val);
           })
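The builder now derives the result type from the pointer operand in both
shapes: a tensor of pointers yields a tensor of the pointee type with the
same shape, while a scalar pointer yields the bare pointee type. The same
rule as a plain-Python sketch (hypothetical encoding, not an MLIR API):

    def cas_result_type(ptr_type):
        # Encodings: ("ptr", "i32") for !tt.ptr<i32>,
        #            ("tensor", (128,), ("ptr", "i32")) for tensor<128x!tt.ptr<i32>>
        if ptr_type[0] == "tensor":
            _, shape, elem = ptr_type
            _, pointee = elem                   # element must be a pointer type
            return ("tensor", shape, pointee)   # -> tensor<128xi32>
        _, pointee = ptr_type
        return pointee                          # scalar: just the pointee, e.g. i32

    assert cas_result_type(("tensor", (128,), ("ptr", "i32"))) == ("tensor", (128,), "i32")
    assert cas_result_type(("ptr", "i32")) == "i32"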


@@ -700,6 +700,16 @@ def test_tensor_atomic_rmw(axis, device="cuda"):
 #     serialized_add[(64,)](data, Lock)
 #     triton.testing.assert_almost_equal(data, ref)
+
+def test_simple_atomic_cas():
+    # 1. make sure that atomic_cas changes the original value (Lock)
+    @triton.jit
+    def change_value(Lock):
+        tl.atomic_cas(Lock, 0, 1)
+
+    Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
+    change_value[(1,)](Lock)
+    assert (Lock[0] == 1)
 # # ---------------
 # # test cast
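A natural companion check, not part of this commit, would also cover the
failure path: when the compare value does not match, atomic_cas must leave
the memory untouched. A sketch in the same style as the test above:

    def test_simple_atomic_cas_no_swap():
        @triton.jit
        def try_change(Lock):
            tl.atomic_cas(Lock, 5, 1)   # Lock holds 0, so no swap happens

        Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
        try_change[(1,)](Lock)
        assert Lock[0] == 0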