diff --git a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td
index cdcb69576..bb0f4c067 100644
--- a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td
+++ b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td
@@ -59,7 +59,8 @@ def TT_AtomicRMWAttr : I32EnumAttr<
         I32EnumAttrCase<"MAX", 6, "max">,
         I32EnumAttrCase<"MIN", 7, "min">,
         I32EnumAttrCase<"UMAX", 8, "umax">,
-        I32EnumAttrCase<"UMIN", 9, "umin">
+        I32EnumAttrCase<"UMIN", 9, "umin">,
+        I32EnumAttrCase<"XCHG", 10, "exch">
     ]> {
   let cppNamespace = "::mlir::triton";
 }
diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
index a155df8de..de0a6b52c 100644
--- a/lib/Analysis/Allocation.cpp
+++ b/lib/Analysis/Allocation.cpp
@@ -115,9 +115,9 @@ SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
 // TODO: extend beyond scalars
 SmallVector<unsigned> getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) {
   SmallVector<unsigned> smemShape;
-  auto ptrTy = op.ptr().getType();
-  if (auto tensorType = ptrTy.dyn_cast<RankedTensorType>()) {
-    // do nothing or just assert because shared memory is not used in tensor
+  if (op.ptr().getType().isa<RankedTensorType>()) {
+    // do nothing or just assert because shared memory is not used in tensor up
+    // to now
   } else {
     // need only bytes for scalar
     // always vec = 1 and elemsPerThread = 1 for scalar?
@@ -126,6 +126,10 @@ SmallVector<unsigned> getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) {
   return smemShape;
 }
 
+SmallVector<unsigned> getScratchConfigForAtomicCAS(triton::AtomicCASOp op) {
+  return SmallVector<unsigned>{1};
+}
+
 class AllocationAnalysis {
 public:
   AllocationAnalysis(Operation *operation, Allocation *allocation)
@@ -230,6 +234,17 @@ private:
                          : elems * elemTy.getIntOrFloatBitWidth() / 8;
         allocation->addBuffer(op, bytes);
       }
+    } else if (auto atomicCASOp = dyn_cast<triton::AtomicCASOp>(op)) {
+      auto value = op->getOperand(0);
+      auto smemShape = getScratchConfigForAtomicCAS(atomicCASOp);
+      unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
+                                       std::multiplies{});
+      auto elemTy =
+          value.getType().cast<triton::PointerType>().getPointeeType();
+      auto bytes = elemTy.isa<triton::PointerType>()
+                       ? elems * kPtrBitWidth / 8
+                       : elems * elemTy.getIntOrFloatBitWidth() / 8;
+      allocation->addBuffer(op, bytes);
     }
   }
 
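
The scratch buffer reserved above is a single element because a scalar atomic_cas is issued by one thread while its result (the old value) must become visible to every thread in the block; the lowering below stages it through shared memory. A minimal sketch of the user-level pattern this supports; the kernel name and host wrapper are hypothetical, and only tl.atomic_cas itself comes from this patch:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def cas_and_report(Lock, Old):
        # The CAS executes on a single thread; its return value (the old
        # contents of Lock) is broadcast to the block via the scratch smem.
        old = tl.atomic_cas(Lock, 0, 1)
        tl.store(Old, old)

    Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
    Old = torch.empty((1,), device='cuda', dtype=torch.int32)
    cas_and_report[(1,)](Lock, Old)
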
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 471ab5b72..dfd31240b 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -6036,6 +6036,82 @@ struct ExpOpConversionApprox
     return ptxBuilder.launch(rewriter, loc, f32_ty, false);
   }
 };
+/// ====================== atomic_cas codegen begin ==========================
+struct AtomicCASOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>,
+      public LoadStoreConversionBase {
+  using ConvertTritonGPUOpToLLVMPattern<
+      triton::AtomicCASOp>::ConvertTritonGPUOpToLLVMPattern;
+
+  AtomicCASOpConversion(LLVMTypeConverter &converter,
+                        const Allocation *allocation, Value smem,
+                        AxisInfoAnalysis &axisAnalysisPass,
+                        PatternBenefit benefit)
+      : ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>(
+            converter, allocation, smem, benefit),
+        LoadStoreConversionBase(axisAnalysisPass) {}
+
+  LogicalResult
+  matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    MLIRContext *ctx = rewriter.getContext();
+    Value ptr = op.ptr();
+
+    Value llPtr = adaptor.ptr();
+    Value llCmp = adaptor.cmp();
+    Value llVal = adaptor.val();
+
+    auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
+    auto cmpElements = getElementsFromStruct(loc, llCmp, rewriter);
+    auto valElements = getElementsFromStruct(loc, llVal, rewriter);
+
+    auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
+    Type valueElemTy =
+        valueTy ? getTypeConverter()->convertType(valueTy.getElementType())
+                : op.getResult().getType();
+    auto tid = tid_val();
+    Value pred = icmp_eq(tid, i32_val(0));
+    PTXBuilder ptxBuilderMemfence;
+    auto memfenc = ptxBuilderMemfence.create("membar")->o("gl");
+    memfenc();
+    auto ASMReturnTy = void_ty(ctx);
+    ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
+
+    Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
+    atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
+
+    Value casPtr = ptrElements[0];
+    Value casCmp = cmpElements[0];
+    Value casVal = valElements[0];
+
+    PTXBuilder ptxBuilderAtomicCAS;
+    auto *dstOpr = ptxBuilderAtomicCAS.newOperand("=r");
+    auto *ptrOpr = ptxBuilderAtomicCAS.newAddrOperand(casPtr, "l");
+    auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, "r");
+    auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, "r");
+    auto &atom = *ptxBuilderAtomicCAS.create("atom");
+    atom.global().o("cas").o("b32");
+    atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(pred);
+    auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
+    barrier();
+
+    PTXBuilder ptxBuilderStore;
+    auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "l");
+    auto *valOprStore = ptxBuilderStore.newOperand(old, "r");
+    auto &st = *ptxBuilderStore.create("st");
+    st.shared().o("b32");
+    st(dstOprStore, valOprStore).predicate(pred);
+    ptxBuilderStore.launch(rewriter, loc, ASMReturnTy);
+    ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
+    barrier();
+    Value ret = load(atomPtr);
+    barrier();
+    rewriter.replaceOp(op, {ret});
+    return success();
+  }
+};
+/// ====================== atomic_cas codegen end ==========================
 
 /// ====================== atomic_rmw codegen begin ==========================
 struct AtomicRMWOpConversion
@@ -6105,15 +6181,15 @@ struct AtomicRMWOpConversion
       Value rmwMask = maskElements[i];
       rmwMask = and_(rmwMask, mask);
       std::string sTy;
-      PTXBuilder ptxBuilder;
+      PTXBuilder ptxBuilderAtomicRMW;
       std::string tyId = valueElemNbits * vec == 64
                              ? "l"
                              : (valueElemNbits * vec == 32 ? "r" : "h");
-      auto *dstOpr = ptxBuilder.newOperand("=" + tyId);
-      auto *ptrOpr = ptxBuilder.newAddrOperand(rmwPtr, "l");
-      auto *valOpr = ptxBuilder.newOperand(rmwVal, tyId);
+      auto *dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId);
+      auto *ptrOpr = ptxBuilderAtomicRMW.newAddrOperand(rmwPtr, "l");
+      auto *valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId);
 
-      auto &atom = ptxBuilder.create<>("atom")->global().o("gpu");
+      auto &atom = ptxBuilderAtomicRMW.create<>("atom")->global().o("gpu");
       auto rmwOp = stringifyRMWOp(atomicRmwAttr).str();
       auto sBits = std::to_string(valueElemNbits);
       switch (atomicRmwAttr) {
@@ -6149,21 +6225,29 @@ struct AtomicRMWOpConversion
         rmwOp = "min";
         sTy = "u" + sBits;
         break;
+      case RMWOp::XCHG:
+        sTy = "b" + sBits;
+        break;
       default:
         return failure();
       }
       atom.o(rmwOp).o(sTy);
       if (valueTy) {
         atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
-        auto ret = ptxBuilder.launch(rewriter, loc, valueElemTy);
+        auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy);
         for (int ii = 0; ii < vec; ++ii) {
           resultVals[i * vec + ii] =
               vec == 1 ? ret : extract_element(valueElemTy, ret, idx_val(ii));
         }
       } else {
+        PTXBuilder ptxBuilderMemfence;
+        auto memfenc = ptxBuilderMemfence.create("membar")->o("gl");
+        memfenc();
+        auto ASMReturnTy = void_ty(ctx);
+        ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
         rmwMask = and_(rmwMask, icmp_eq(tid, i32_val(0)));
         atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
-        auto old = ptxBuilder.launch(rewriter, loc, valueElemTy);
+        auto old = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy);
         Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
         atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
         store(old, atomPtr);
@@ -6264,6 +6348,8 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
   patterns.add<AllocTensorOpConversion>(typeConverter, allocation, smem,
                                         benefit);
+  patterns.add<AtomicCASOpConversion>(typeConverter, allocation, smem,
+                                      axisInfoAnalysis, benefit);
   patterns.add<AtomicRMWOpConversion>(typeConverter, allocation, smem,
                                       axisInfoAnalysis, benefit);
   patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
                                           benefit);
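
The new RMWOp::XCHG case above keeps the "exch" mnemonic produced by stringifyRMWOp and only selects the b<nbits> type suffix, so a 32-bit exchange is emitted as atom.global.gpu.exch.b32. A minimal Triton-level sketch that exercises this path (a hypothetical kernel, not part of the patch):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def unlock(Lock):
        # scalar exchange; lowers through the RMWOp::XCHG case added above
        tl.atomic_xchg(Lock, 0)

    Lock = torch.ones((1,), device='cuda', dtype=torch.int32)
    unlock[(1,)](Lock)
    assert Lock[0] == 0
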
diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
index d4473f5b8..637acab64 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
@@ -278,6 +278,20 @@ struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
   }
 };
 
+struct TritonAtomicCASPattern
+    : public OpConversionPattern<triton::AtomicCASOp> {
+  using OpConversionPattern<triton::AtomicCASOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<triton::AtomicCASOp>(
+        op, typeConverter->convertType(op.getType()),
+        adaptor.ptr(), adaptor.cmp(), adaptor.val());
+    return success();
+  }
+};
+
 struct TritonAtomicRMWPattern
     : public OpConversionPattern<triton::AtomicRMWOp> {
   using OpConversionPattern<triton::AtomicRMWOp>::OpConversionPattern;
diff --git a/python/src/triton.cc b/python/src/triton.cc
index cfe092821..f06f7a476 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -105,7 +105,7 @@ void init_triton_ir(py::module &&m) {
       .value("AND", mlir::triton::RMWOp::AND)
       .value("OR", mlir::triton::RMWOp::OR)
       .value("XOR", mlir::triton::RMWOp::XOR)
-      // .value("XCHG", mlir::triton::RMWOp::Xchg)
+      .value("XCHG", mlir::triton::RMWOp::XCHG)
      .value("MAX", mlir::triton::RMWOp::MAX)
       .value("MIN", mlir::triton::RMWOp::MIN)
       .value("UMIN", mlir::triton::RMWOp::UMIN)
@@ -1095,9 +1095,18 @@
           [](mlir::OpBuilder &self, mlir::Value &ptr, mlir::Value &cmp,
              mlir::Value &val) -> mlir::Value {
             auto loc = self.getUnknownLoc();
-            auto ptrType = mlir::getElementTypeOrSelf(ptr)
-                               .cast<mlir::triton::PointerType>();
-            mlir::Type dstType = ptrType.getPointeeType();
+            mlir::Type dstType;
+            if (auto srcTensorType = ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
+              mlir::Type dstElemType = srcTensorType.getElementType()
+                                           .cast<mlir::triton::PointerType>()
+                                           .getPointeeType();
+              dstType = mlir::RankedTensorType::get(srcTensorType.getShape(),
+                                                    dstElemType);
+            } else {
+              auto ptrType = mlir::getElementTypeOrSelf(ptr)
+                                 .cast<mlir::triton::PointerType>();
+              dstType = ptrType.getPointeeType();
+            }
             return self.create<mlir::triton::AtomicCASOp>(loc, dstType, ptr,
                                                           cmp, val);
           })
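
The binding above makes create_atomic_cas produce a tensor result type when ptr is a tensor of pointers, and the plain pointee type otherwise. As a plain-Python reference for the intended compare-and-swap semantics (a sequential illustration only, not Triton code):

    def cas_reference(mem, idx, cmp, val):
        """Return the old value; write val only if the old value equals cmp."""
        old = mem[idx]
        if old == cmp:
            mem[idx] = val
        return old

    mem = [0]
    assert cas_reference(mem, 0, 0, 1) == 0 and mem[0] == 1  # swap succeeded
    assert cas_reference(mem, 0, 0, 1) == 1 and mem[0] == 1  # swap failed
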
diff --git a/python/tests/test_core.py b/python/tests/test_core.py
index 21f9df750..cbd6e046b 100644
--- a/python/tests/test_core.py
+++ b/python/tests/test_core.py
@@ -700,6 +700,16 @@ def test_tensor_atomic_rmw(axis, device="cuda"):
 #     serialized_add[(64,)](data, Lock)
 #     triton.testing.assert_almost_equal(data, ref)
 
+def test_simple_atomic_cas():
+    # 1. make sure that atomic_cas changes the original value (Lock)
+    @triton.jit
+    def change_value(Lock):
+        tl.atomic_cas(Lock, 0, 1)
+
+    Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
+    change_value[(1,)](Lock)
+
+    assert (Lock[0] == 1)
 
 # # ---------------
 # # test cast
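
Taken together, atomic_cas and atomic_xchg are enough to build a global spin lock, which is what the still-commented-out serialized_add test above is working toward. A hedged sketch of that pattern (the kernel mirrors the commented test; whether the spin loop makes progress depends on the backend's scheduling guarantees, which is presumably why the full test remains disabled):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def serialized_add(data, Lock):
        ptrs = data + tl.arange(0, 128)
        # acquire: spin until the lock is swapped from 0 (free) to 1 (held)
        while tl.atomic_cas(Lock, 0, 1) == 1:
            pass
        # critical section: only one block at a time executes this
        tl.store(ptrs, tl.load(ptrs) + 1.0)
        # release the lock
        tl.atomic_xchg(Lock, 0)

    data = torch.zeros((128,), device='cuda', dtype=torch.float32)
    Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
    serialized_add[(64,)](data, Lock)  # every block adds 1.0 exactly once
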