#include "mlir/IR/Matchers.h" #include "mlir/IR/TypeUtilities.h" #include "ConvertLayoutOpToLLVM.h" #include "LoadStoreOpToLLVM.h" using namespace mlir; using namespace mlir::triton; using ::mlir::LLVM::getElementsFromStruct; using ::mlir::LLVM::getSharedMemoryObjectFromStruct; using ::mlir::LLVM::getStructFromElements; using ::mlir::triton::gpu::getElemsPerThread; using ::mlir::triton::gpu::SharedEncodingAttr; // Contains some helper functions for both Load and Store conversions. struct LoadStoreConversionBase { explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass) : axisAnalysisPass(axisAnalysisPass) {} // Get corresponding LLVM element values of \param value. static SmallVector getLLVMElems(Value value, Value llValue, ConversionPatternRewriter &rewriter, Location loc) { if (!value) return {}; if (!llValue.getType().isa()) return {llValue}; // Here, we assume that all inputs should have a blockedLayout auto valueVals = getElementsFromStruct(loc, llValue, rewriter); return valueVals; } unsigned getVectorSize(Value ptr) const { return axisAnalysisPass.getPtrVectorSize(ptr); } unsigned getMaskAlignment(Value mask) const { return axisAnalysisPass.getMaskAlignment(mask); } protected: AxisInfoAnalysis &axisAnalysisPass; }; struct LoadOpConversion : public ConvertTritonGPUOpToLLVMPattern, public LoadStoreConversionBase { using ConvertTritonGPUOpToLLVMPattern< triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern; LoadOpConversion(LLVMTypeConverter &converter, AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) : ConvertTritonGPUOpToLLVMPattern(converter, benefit), LoadStoreConversionBase(axisAnalysisPass) {} LogicalResult matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); // original values Value ptr = op.ptr(); Value mask = op.mask(); Value other = op.other(); // adaptor values Value llPtr = adaptor.ptr(); Value llMask = adaptor.mask(); Value llOther = adaptor.other(); // Determine the vectorization size Type valueTy = op.getResult().getType(); Type valueElemTy = typeConverter->convertType(getElementTypeOrSelf(valueTy)); unsigned vec = getVectorSize(ptr); unsigned numElems = getElemsPerThread(ptr.getType()); if (llMask) vec = std::min(vec, getMaskAlignment(mask)); // Get the LLVM values for pointers auto ptrElems = getLLVMElems(ptr, llPtr, rewriter, loc); assert(ptrElems.size() == numElems); // Get the LLVM values for mask SmallVector maskElems; if (llMask) { maskElems = getLLVMElems(mask, llMask, rewriter, loc); assert(maskElems.size() == numElems); } // Get the LLVM values for `other` // TODO: (goostavz) handle when other is const but not splat, which // should be rarely seen bool otherIsSplatConstInt = false; DenseElementsAttr constAttr; int64_t splatVal = 0; if (other && valueElemTy.isa() && matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat()) { otherIsSplatConstInt = true; splatVal = constAttr.getSplatValue().getSExtValue(); } auto otherElems = getLLVMElems(other, llOther, rewriter, loc); // vectorized iteration through all the pointer/mask/other elements const int valueElemNbits = std::max(8u, valueElemTy.getIntOrFloatBitWidth()); const int numVecs = numElems / vec; SmallVector loadedVals; for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { // TODO: optimization when ptr is GEP with constant offset size_t in_off = 0; const size_t maxWordWidth = std::max(32, valueElemNbits); const size_t totalWidth = valueElemNbits * vec; const size_t width = 
      const size_t nWords = std::max<size_t>(1, totalWidth / width);
      const size_t wordNElems = width / valueElemNbits;
      assert(wordNElems * nWords * numVecs == numElems);

      // TODO(Superjomn) Add cache policy fields to StoreOp.
      // TODO(Superjomn) Deal with cache policy here.
      const bool hasL2EvictPolicy = false;

      PTXBuilder ptxBuilder;

      Value pred = mask ? maskElems[vecStart] : int_val(1, 1);

      const std::string readConstraint =
          (width == 64) ? "l" : ((width == 32) ? "r" : "c");
      const std::string writeConstraint =
          (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");

      // prepare asm operands
      auto *dstsOpr = ptxBuilder.newListOperand();
      for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
        auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations
        dstsOpr->listAppend(opr);
      }

      auto *addrOpr =
          ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);

      // Define the instruction opcode
      auto &ld = ptxBuilder.create<>("ld")
                     ->o("volatile", op.isVolatile())
                     .global()
                     .o("ca", op.cache() == triton::CacheModifier::CA)
                     .o("cg", op.cache() == triton::CacheModifier::CG)
                     .o("L1::evict_first",
                        op.evict() == triton::EvictionPolicy::EVICT_FIRST)
                     .o("L1::evict_last",
                        op.evict() == triton::EvictionPolicy::EVICT_LAST)
                     .o("L1::cache_hint", hasL2EvictPolicy)
                     .v(nWords)
                     .b(width);

      PTXBuilder::Operand *evictOpr{};
      // No mlir::Value to bind to this operand yet, so it is disabled.
      // if (has_l2_evict_policy)
      //   evictOpr = ptxBuilder.newOperand(l2Evict, "l");

      if (!evictOpr)
        ld(dstsOpr, addrOpr).predicate(pred, "b");
      else
        ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b");

      if (other) {
        for (size_t ii = 0; ii < nWords; ++ii) {
          // PTX doesn't support mov.u8, so we need to use mov.u16
          auto movWidth = width < 16 ? 16 : width;
          PTXInstr &mov =
              ptxBuilder.create<>("mov")->o("u" + std::to_string(movWidth));

          size_t size = width / valueElemNbits;

          auto vecTy = LLVM::getFixedVectorType(valueElemTy, size);
          Value v = undef(vecTy);
          for (size_t s = 0; s < size; ++s) {
            Value falseVal = otherElems[vecStart + ii * size + s];
            Value sVal = createIndexAttrConstant(
                rewriter, loc, this->getTypeConverter()->getIndexType(), s);
            v = insert_element(vecTy, v, falseVal, sVal);
          }
          v = bitcast(v, IntegerType::get(getContext(), width));

          PTXInstr::Operand *opr{};
          if (otherIsSplatConstInt)
            opr = ptxBuilder.newConstantOperand(splatVal);
          else
            opr = ptxBuilder.newOperand(v, readConstraint);

          mov(dstsOpr->listGet(ii), opr).predicateNot(pred, "b");
        }
      }

      // Create inline ASM signature
      SmallVector<Type> retTys(nWords, IntegerType::get(getContext(), width));
      Type retTy = retTys.size() > 1
                       ? LLVM::LLVMStructType::getLiteral(getContext(), retTys)
                       : retTys[0];
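      // The inline asm returns one integer register per word. When nWords > 1
      // those registers come back packed in a literal LLVM struct, which is
      // unpacked below with extract_val before being bitcast to vectors of
      // the element type.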
      // TODO: if (has_l2_evict_policy)
      // auto asmDialectAttr =
      //     LLVM::AsmDialectAttr::get(rewriter.getContext(),
      //                               LLVM::AsmDialect::AD_ATT);
      Value ret = ptxBuilder.launch(rewriter, loc, retTy);

      // Extract and store return values
      SmallVector<Value> rets;
      for (unsigned int ii = 0; ii < nWords; ++ii) {
        Value curr;
        if (retTy.isa<LLVM::LLVMStructType>()) {
          curr = extract_val(IntegerType::get(getContext(), width), ret,
                             rewriter.getI64ArrayAttr(ii));
        } else {
          curr = ret;
        }
        curr = bitcast(curr, LLVM::getFixedVectorType(valueElemTy,
                                                      width / valueElemNbits));
        rets.push_back(curr);
      }
      int tmp = width / valueElemNbits;
      for (size_t ii = 0; ii < vec; ++ii) {
        Value vecIdx = createIndexAttrConstant(
            rewriter, loc, this->getTypeConverter()->getIndexType(), ii % tmp);
        Value loaded = extract_element(valueElemTy, rets[ii / tmp], vecIdx);
        loadedVals.push_back(loaded);
      }
    } // end vec

    Type llvmResultStructTy = getTypeConverter()->convertType(valueTy);
    Value resultStruct =
        getStructFromElements(loc, loadedVals, rewriter, llvmResultStructTy);
    rewriter.replaceOp(op, {resultStruct});
    return success();
  }
};

struct StoreOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>,
      public LoadStoreConversionBase {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern;

  StoreOpConversion(LLVMTypeConverter &converter,
                    AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
      : ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>(converter, benefit),
        LoadStoreConversionBase(axisAnalysisPass) {}

  LogicalResult
  matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value ptr = op.ptr();
    Value mask = op.mask();
    Value value = op.value();

    Value llPtr = adaptor.ptr();
    Value llMask = adaptor.mask();
    Value llValue = adaptor.value();

    auto loc = op->getLoc();
    MLIRContext *ctx = rewriter.getContext();

    auto valueTy = value.getType();
    Type valueElemTy =
        typeConverter->convertType(getElementTypeOrSelf(valueTy));

    unsigned vec = getVectorSize(ptr);
    unsigned numElems = getElemsPerThread(ptr.getType());

    auto ptrElems = getLLVMElems(ptr, llPtr, rewriter, loc);
    auto valueElems = getLLVMElems(value, llValue, rewriter, loc);
    assert(ptrElems.size() == valueElems.size());

    // Determine the vectorization size
    SmallVector<Value> maskElems;
    if (llMask) {
      maskElems = getLLVMElems(mask, llMask, rewriter, loc);
      assert(valueElems.size() == maskElems.size());

      unsigned maskAlign = getMaskAlignment(mask);
      vec = std::min(vec, maskAlign);
    }

    const size_t dtsize =
        std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
    const size_t valueElemNbits = dtsize * 8;

    const int numVecs = numElems / vec;
    for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
      // TODO: optimization when ptr is AddPtr with constant offset
      size_t in_off = 0;

      const size_t maxWordWidth = std::max<size_t>(32, valueElemNbits);
      const size_t totalWidth = valueElemNbits * vec;
      const size_t width = std::min(totalWidth, maxWordWidth);
      const size_t nWords = std::max<size_t>(1, totalWidth / width);
      const size_t wordNElems = width / valueElemNbits;
      assert(wordNElems * nWords * numVecs == numElems);

      // TODO(Superjomn) Add cache policy fields to StoreOp.
      // TODO(Superjomn) Deal with cache policy here.
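      // Pack the per-thread elements of this vector slice into nWords integer
      // registers (valArgTy), mirroring the word layout used by the load
      // lowering above; the predicated st.global below then writes all the
      // words with a single instruction.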
      Type valArgTy = IntegerType::get(ctx, width);
      auto wordTy = vec_ty(valueElemTy, wordNElems);

      SmallVector<std::pair<Value, std::string>> asmArgs;
      for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
        // llWord is a width-len composition
        Value llWord = undef(wordTy);
        // Insert each value element to the composition
        for (size_t elemIdx = 0; elemIdx < wordNElems; ++elemIdx) {
          const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx;
          assert(elemOffset < valueElems.size());
          Value elem = valueElems[elemOffset];
          if (elem.getType().isInteger(1))
            elem = rewriter.create<LLVM::SExtOp>(loc, type::i8Ty(ctx), elem);
          elem = bitcast(elem, valueElemTy);

          Type u32Ty = typeConverter->convertType(type::u32Ty(ctx));
          llWord = insert_element(wordTy, llWord, elem, i32_val(elemIdx));
        }
        llWord = bitcast(llWord, valArgTy);
        std::string constraint =
            (width == 64) ? "l" : ((width == 32) ? "r" : "c");
        asmArgs.emplace_back(llWord, constraint);
      }

      // Prepare the PTX inline asm.
      PTXBuilder ptxBuilder;
      auto *asmArgList = ptxBuilder.newListOperand(asmArgs);

      Value maskVal = llMask ? maskElems[vecStart] : int_val(1, 1);

      auto *asmAddr =
          ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);

      auto &ptxStoreInstr =
          ptxBuilder.create<>("st")->global().v(nWords).b(width);
      ptxStoreInstr(asmAddr, asmArgList).predicate(maskVal, "b");

      Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
      llvm::SmallVector<Type> argTys({boolTy, ptr.getType()});
      argTys.insert(argTys.end(), nWords, valArgTy);

      auto asmReturnTy = void_ty(ctx);

      ptxBuilder.launch(rewriter, loc, asmReturnTy);
    }
    rewriter.eraseOp(op);
    return success();
  }
};

struct AtomicCASOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>,
      public LoadStoreConversionBase {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::AtomicCASOp>::ConvertTritonGPUOpToLLVMPattern;

  AtomicCASOpConversion(LLVMTypeConverter &converter,
                        const Allocation *allocation, Value smem,
                        AxisInfoAnalysis &axisAnalysisPass,
                        PatternBenefit benefit)
      : ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>(
            converter, allocation, smem, benefit),
        LoadStoreConversionBase(axisAnalysisPass) {}

  LogicalResult
  matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MLIRContext *ctx = rewriter.getContext();
    Value ptr = op.ptr();

    Value llPtr = adaptor.ptr();
    Value llCmp = adaptor.cmp();
    Value llVal = adaptor.val();

    auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
    auto cmpElements = getElementsFromStruct(loc, llCmp, rewriter);
    auto valElements = getElementsFromStruct(loc, llVal, rewriter);

    auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
    Type valueElemTy =
        valueTy ? getTypeConverter()->convertType(valueTy.getElementType())
                : op.getResult().getType();
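    // atomic_cas is lowered for the scalar case: after a membar.gl, a single
    // thread (tid == 0) performs atom.global.cas.b32 and stores the old value
    // into shared memory; after a barrier every thread loads that value back,
    // so the op still yields a result visible to all threads.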
    auto tid = tid_val();
    Value pred = icmp_eq(tid, i32_val(0));
    PTXBuilder ptxBuilderMemfence;
    auto memfence = ptxBuilderMemfence.create<>("membar")->o("gl");
    memfence();
    auto ASMReturnTy = void_ty(ctx);
    ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);

    Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
    atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));

    Value casPtr = ptrElements[0];
    Value casCmp = cmpElements[0];
    Value casVal = valElements[0];

    PTXBuilder ptxBuilderAtomicCAS;
    auto *dstOpr = ptxBuilderAtomicCAS.newOperand("=r");
    auto *ptrOpr = ptxBuilderAtomicCAS.newAddrOperand(casPtr, "l");
    auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, "r");
    auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, "r");
    auto &atom = *ptxBuilderAtomicCAS.create<>("atom");
    atom.global().o("cas").o("b32");
    atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(pred);
    auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
    barrier();

    PTXBuilder ptxBuilderStore;
    auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "l");
    auto *valOprStore = ptxBuilderStore.newOperand(old, "r");
    auto &st = *ptxBuilderStore.create<>("st");
    st.shared().o("b32");
    st(dstOprStore, valOprStore).predicate(pred);
    ptxBuilderStore.launch(rewriter, loc, ASMReturnTy);
    ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
    barrier();
    Value ret = load(atomPtr);
    barrier();
    rewriter.replaceOp(op, {ret});
    return success();
  }
};

struct AtomicRMWOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>,
      public LoadStoreConversionBase {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::AtomicRMWOp>::ConvertTritonGPUOpToLLVMPattern;

  AtomicRMWOpConversion(LLVMTypeConverter &converter,
                        const Allocation *allocation, Value smem,
                        AxisInfoAnalysis &axisAnalysisPass,
                        PatternBenefit benefit)
      : ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>(
            converter, allocation, smem, benefit),
        LoadStoreConversionBase(axisAnalysisPass) {}

  LogicalResult
  matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MLIRContext *ctx = rewriter.getContext();

    auto atomicRmwAttr = op.atomic_rmw_op();
    Value ptr = op.ptr();
    Value val = op.val();

    Value llPtr = adaptor.ptr();
    Value llVal = adaptor.val();
    Value llMask = adaptor.mask();

    auto valElements = getElementsFromStruct(loc, llVal, rewriter);
    auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
    auto maskElements = getElementsFromStruct(loc, llMask, rewriter);

    auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
    Type valueElemTy =
        valueTy ? getTypeConverter()->convertType(valueTy.getElementType())
                : op.getResult().getType();
    const size_t valueElemNbits = valueElemTy.getIntOrFloatBitWidth();
    auto elemsPerThread = getElemsPerThread(val.getType());
    // vec = 1 for scalar
    auto vec = getVectorSize(ptr);
    Value mask = int_val(1, 1);
    auto tid = tid_val();
    // tensor
    if (valueTy) {
      auto valTy = val.getType().cast<RankedTensorType>();
      vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
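      // Vectorization of atomics is capped at 2, and only for f16 elements,
      // matching PTX's packed atom.add.noftz.f16x2 form; every other case
      // issues one atom instruction per element.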
      // mask
      auto shape = valueTy.getShape();
      auto numElements = product(shape);
      mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)),
                                 i32_val(numElements)));
    }

    auto vecTy = vec_ty(valueElemTy, vec);
    SmallVector<Value> resultVals(elemsPerThread);
    for (size_t i = 0; i < elemsPerThread; i += vec) {
      Value rmwVal = undef(vecTy);
      for (int ii = 0; ii < vec; ++ii) {
        Value iiVal = createIndexAttrConstant(
            rewriter, loc, getTypeConverter()->getIndexType(), ii);
        rmwVal = insert_element(vecTy, rmwVal, valElements[i + ii], iiVal);
      }

      Value rmwPtr = ptrElements[i];
      Value rmwMask = maskElements[i];
      rmwMask = and_(rmwMask, mask);
      std::string sTy;
      PTXBuilder ptxBuilderAtomicRMW;
      std::string tyId = valueElemNbits * vec == 64
                             ? "l"
                             : (valueElemNbits * vec == 32 ? "r" : "h");
      auto *dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId);
      auto *ptrOpr = ptxBuilderAtomicRMW.newAddrOperand(rmwPtr, "l");
      auto *valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId);

      auto &atom = ptxBuilderAtomicRMW.create<>("atom")->global().o("gpu");
      auto rmwOp = stringifyRMWOp(atomicRmwAttr).str();
      auto sBits = std::to_string(valueElemNbits);
      switch (atomicRmwAttr) {
      case RMWOp::AND:
        sTy = "b" + sBits;
        break;
      case RMWOp::OR:
        sTy = "b" + sBits;
        break;
      case RMWOp::XOR:
        sTy = "b" + sBits;
        break;
      case RMWOp::ADD:
        sTy = "s" + sBits;
        break;
      case RMWOp::FADD:
        rmwOp = "add";
        rmwOp += (valueElemNbits == 16 ? ".noftz" : "");
        sTy = "f" + sBits;
        sTy += (vec == 2 && valueElemNbits == 16) ? "x2" : "";
        break;
      case RMWOp::MAX:
        sTy = "s" + sBits;
        break;
      case RMWOp::MIN:
        sTy = "s" + sBits;
        break;
      case RMWOp::UMAX:
        rmwOp = "max";
        sTy = "u" + sBits;
        break;
      case RMWOp::UMIN:
        rmwOp = "min";
        sTy = "u" + sBits;
        break;
      case RMWOp::XCHG:
        sTy = "b" + sBits;
        break;
      default:
        return failure();
      }
      atom.o(rmwOp).o(sTy);
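      // Two result paths: for tensors every thread issues its own predicated
      // atom and keeps the returned value(s); for a scalar the operation is
      // restricted to a single thread and the result is broadcast to all
      // threads through shared memory, as with atomic_cas above.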
      if (valueTy) {
        atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
        auto retType = vec == 1 ? valueElemTy : vecTy;
        auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, retType);
        for (int ii = 0; ii < vec; ++ii) {
          resultVals[i + ii] =
              vec == 1 ? ret : extract_element(valueElemTy, ret, idx_val(ii));
        }
      } else {
        PTXBuilder ptxBuilderMemfence;
        auto memfenc = ptxBuilderMemfence.create<>("membar")->o("gl");
        memfenc();
        auto ASMReturnTy = void_ty(ctx);
        ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
        rmwMask = and_(rmwMask, icmp_eq(tid, i32_val(0)));
        atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
        auto old = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy);
        Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
        atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
        store(old, atomPtr);
        barrier();
        Value ret = load(atomPtr);
        barrier();
        rewriter.replaceOp(op, {ret});
      }
    }
    if (valueTy) {
      Type structTy = getTypeConverter()->convertType(valueTy);
      Value resultStruct =
          getStructFromElements(loc, resultVals, rewriter, structTy);
      rewriter.replaceOp(op, {resultStruct});
    }
    return success();
  }
};

struct InsertSliceOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<tensor::InsertSliceOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      tensor::InsertSliceOp>::ConvertTritonGPUOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(tensor::InsertSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // %dst = insert_slice %src into %dst[%offsets]
    Location loc = op->getLoc();
    Value dst = op.dest();
    Value src = op.source();
    Value res = op.result();
    assert(allocation->getBufferId(res) == Allocation::InvalidBufferId &&
           "Only support in-place insert_slice for now");

    auto srcTy = src.getType().dyn_cast<RankedTensorType>();
    auto srcLayout =
        srcTy.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
    auto srcShape = srcTy.getShape();
    assert(srcLayout && "Unexpected srcLayout in InsertSliceOpConversion");

    auto dstTy = dst.getType().dyn_cast<RankedTensorType>();
    auto dstLayout = dstTy.getEncoding().dyn_cast<SharedEncodingAttr>();
    auto llDst = adaptor.dest();
    assert(dstLayout && "Unexpected dstLayout in InsertSliceOpConversion");
    assert(op.hasUnitStride() &&
           "Only unit stride supported by InsertSliceOpConversion");

    // newBase = base + offset
    // Triton supports either static or dynamic offsets
    auto smemObj = getSharedMemoryObjectFromStruct(loc, llDst, rewriter);
    SmallVector<Value> offsets;
    SmallVector<Value> srcStrides;
    auto mixedOffsets = op.getMixedOffsets();
    for (auto i = 0; i < mixedOffsets.size(); ++i) {
      if (op.isDynamicOffset(i)) {
        offsets.emplace_back(adaptor.offsets()[i]);
      } else {
        offsets.emplace_back(i32_val(op.getStaticOffset(i)));
      }
      // Like insert_slice_async, we only support slice from one dimension,
      // which has a slice size of 1
      if (op.getStaticSize(i) != 1) {
        srcStrides.emplace_back(smemObj.strides[i]);
      }
    }

    // Compute the offset based on the original strides of the shared memory
    // object
    auto offset = dot(rewriter, loc, offsets, smemObj.strides);
    auto elemTy = getTypeConverter()->convertType(dstTy.getElementType());
    auto elemPtrTy = ptr_ty(elemTy, 3);
    auto smemBase = gep(elemPtrTy, smemObj.base, offset);

    auto llSrc = adaptor.source();
    auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
    storeDistributedToShared(src, llSrc, srcStrides, srcIndices, dst, smemBase,
                             elemTy, loc, rewriter);
    // Barrier is not necessary.
    // The membar pass knows that it writes to shared memory and will handle it
    // properly.
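    // insert_slice is lowered in place: the source tensor is written into the
    // destination's shared-memory buffer and the original destination struct
    // (llDst) is forwarded as the op's result.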
    rewriter.replaceOp(op, llDst);
    return success();
  }
};

struct InsertSliceAsyncOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>,
      public LoadStoreConversionBase {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::gpu::InsertSliceAsyncOp>::ConvertTritonGPUOpToLLVMPattern;

  InsertSliceAsyncOpConversion(
      LLVMTypeConverter &converter, const Allocation *allocation, Value smem,
      ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
      AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
      : ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>(
            converter, allocation, smem, indexCacheInfo, benefit),
        LoadStoreConversionBase(axisAnalysisPass) {}

  LogicalResult
  matchAndRewrite(triton::gpu::InsertSliceAsyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // insert_slice_async %src, %dst, %index, %mask, %other
    auto loc = op.getLoc();
    Value src = op.src();
    Value dst = op.dst();
    Value res = op.result();
    Value mask = op.mask();
    Value other = op.other();
    assert(allocation->getBufferId(res) == Allocation::InvalidBufferId &&
           "Only support in-place insert_slice_async for now");

    auto srcTy = src.getType().cast<RankedTensorType>();
    auto resTy = dst.getType().cast<RankedTensorType>();
    auto resElemTy = getTypeConverter()->convertType(resTy.getElementType());
    auto srcBlockedLayout =
        srcTy.getEncoding().cast<triton::gpu::BlockedEncodingAttr>();
    auto resSharedLayout = resTy.getEncoding().cast<SharedEncodingAttr>();
    auto srcShape = srcTy.getShape();
    assert(srcShape.size() == 2 &&
           "insert_slice_async: Unexpected rank of %src");

    Value llDst = adaptor.dst();
    Value llSrc = adaptor.src();
    Value llMask = adaptor.mask();
    Value llOther = adaptor.other();
    Value llIndex = adaptor.index();

    // %src
    auto srcElems = getLLVMElems(src, llSrc, rewriter, loc);

    // %dst
    auto dstTy = dst.getType().cast<RankedTensorType>();
    auto dstShape = dstTy.getShape();
    auto smemObj = getSharedMemoryObjectFromStruct(loc, llDst, rewriter);
    auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
    SmallVector<Value> offsetVals;
    SmallVector<Value> srcStrides;
    for (auto i = 0; i < dstShape.size(); ++i) {
      if (i == axis) {
        offsetVals.emplace_back(llIndex);
      } else {
        offsetVals.emplace_back(i32_val(0));
        srcStrides.emplace_back(smemObj.strides[i]);
      }
    }
    // Compute the offset based on the original dimensions of the shared
    // memory object
    auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides);
    auto dstPtrTy = ptr_ty(resElemTy, 3);
    Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset);

    // %mask
    SmallVector<Value> maskElems;
    if (llMask) {
      maskElems = getLLVMElems(mask, llMask, rewriter, loc);
      assert(srcElems.size() == maskElems.size());
    }

    // %other
    SmallVector<Value> otherElems;
    if (llOther) {
      // FIXME(Keren): always assume other is 0 for now
      // It's not necessary for now because the pipeline pass will skip
      // generating insert_slice_async if the load op has any "other" tensor.
      // assert(false && "insert_slice_async: Other value not supported yet");
      otherElems = getLLVMElems(other, llOther, rewriter, loc);
      assert(srcElems.size() == otherElems.size());
    }

    unsigned inVec = getVectorSize(src);
    unsigned outVec = resSharedLayout.getVec();
    unsigned minVec = std::min(outVec, inVec);
    unsigned numElems = getElemsPerThread(srcTy);
    unsigned perPhase = resSharedLayout.getPerPhase();
    unsigned maxPhase = resSharedLayout.getMaxPhase();
    auto sizePerThread = srcBlockedLayout.getSizePerThread();
    auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
    auto inOrder = srcBlockedLayout.getOrder();

    // If perPhase * maxPhase > threadsPerCTA, we will have elements
    // that share the same tile indices. The index calculation will
    // be cached.
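    // For instance (hypothetical values), perPhase = 2, maxPhase = 8 and
    // threadsPerCTA[inOrder[1]] = 4 give numSwizzleRows = 4: four distinct
    // row tiles before the swizzle pattern repeats across the threads of the
    // CTA, so only that many tile base pointers need to be computed and
    // cached in tileOffsetMap below.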
    auto numSwizzleRows = std::max<unsigned>(
        (perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1);
    // A sharedLayout encoding has a "vec" parameter.
    // On the column dimension, if inVec > outVec, it means we have to divide
    // single vector read into multiple ones
    auto numVecCols = std::max<unsigned>(inVec / outVec, 1);

    auto srcIndices = emitIndices(loc, rewriter, srcBlockedLayout, srcShape);

    // <<tileVecIdxRow, tileVecIdxCol>, TileOffset>
    DenseMap<std::pair<unsigned, unsigned>, Value> tileOffsetMap;
    for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) {
      // minVec = 2, inVec = 4, outVec = 2
      //   baseOffsetCol = 0   baseOffsetCol = 0
      //   tileVecIdxCol = 0   tileVecIdxCol = 1
      //                -/\-   -/\-
      //               [|x x| |x x| x x x x x]
      //               [|x x| |x x| x x x x x]
      // baseOffsetRow [|x x| |x x| x x x x x]
      //               [|x x| |x x| x x x x x]
      auto vecIdx = elemIdx / minVec;
      auto vecIdxCol = vecIdx % (sizePerThread[inOrder[0]] / minVec);
      auto vecIdxRow = vecIdx / (sizePerThread[inOrder[0]] / minVec);
      auto baseOffsetCol =
          vecIdxCol / numVecCols * numVecCols * threadsPerCTA[inOrder[0]];
      auto baseOffsetRow = vecIdxRow / numSwizzleRows * numSwizzleRows *
                           threadsPerCTA[inOrder[1]];
      auto tileVecIdxCol = vecIdxCol % numVecCols;
      auto tileVecIdxRow = vecIdxRow % numSwizzleRows;

      if (!tileOffsetMap.count({tileVecIdxRow, tileVecIdxCol})) {
        // Swizzling
        // Since the swizzling index is related to outVec, and we know minVec
        // already, inVec doesn't matter
        //
        // (Numbers represent row indices)
        // Example1:
        // outVec = 2, inVec = 2, minVec = 2
        // outVec = 2, inVec = 4, minVec = 2
        //     | [1 2] [3 4] [5 6] ... |
        //     | [3 4] [1 2] [7 8] ... |
        //     | [5 6] [7 8] [1 2] ... |
        // Example2:
        // outVec = 4, inVec = 2, minVec = 2
        //     | [1 2 3 4] [5 6 7 8] [9 10 11 12] ... |
        //     | [5 6 7 8] [1 2 3 4] [13 14 15 16] ... |
        //     | [9 10 11 12] [13 14 15 16] [1 2 3 4] ... |
        auto srcIdx = srcIndices[tileVecIdxRow * sizePerThread[inOrder[0]]];
        Value phase = urem(udiv(srcIdx[inOrder[1]], i32_val(perPhase)),
                           i32_val(maxPhase));
        // srcShape and smemObj.shape may be different if smemObj is a
        // slice of the original shared memory object.
        // So we need to use the original shape to compute the offset
        Value rowOffset = mul(srcIdx[inOrder[1]], srcStrides[inOrder[1]]);
        Value colOffset =
            add(srcIdx[inOrder[0]], i32_val(tileVecIdxCol * minVec));
        Value swizzleIdx = udiv(colOffset, i32_val(outVec));
        Value swizzleColOffset =
            add(mul(xor_(swizzleIdx, phase), i32_val(outVec)),
                urem(colOffset, i32_val(outVec)));
        Value tileOffset = add(rowOffset, swizzleColOffset);
        tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}] =
            gep(dstPtrTy, dstPtrBase, tileOffset);
      }

      // 16 * 8 = 128bits
      auto maxBitWidth =
          std::max<unsigned>(128, resElemTy.getIntOrFloatBitWidth());
      auto vecBitWidth = resElemTy.getIntOrFloatBitWidth() * minVec;
      auto bitWidth = std::min<unsigned>(maxBitWidth, vecBitWidth);
      auto numWords = vecBitWidth / bitWidth;
      auto numWordElems = bitWidth / resElemTy.getIntOrFloatBitWidth();

      // Tune CG and CA here.
      auto byteWidth = bitWidth / 8;
      CacheModifier srcCacheModifier =
          byteWidth == 16 ? CacheModifier::CG : CacheModifier::CA;
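      // cp.async only supports the .cg qualifier for 16-byte transfers;
      // narrower 4- or 8-byte copies fall back to .ca, which also caches
      // the line in L1.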
      assert(byteWidth == 16 || byteWidth == 8 || byteWidth == 4);
      auto resByteWidth = resElemTy.getIntOrFloatBitWidth() / 8;

      Value tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}];
      Value baseOffset =
          add(mul(i32_val(baseOffsetRow), srcStrides[inOrder[1]]),
              i32_val(baseOffsetCol));
      Value basePtr = gep(dstPtrTy, tileOffset, baseOffset);
      for (size_t wordIdx = 0; wordIdx < numWords; ++wordIdx) {
        PTXBuilder ptxBuilder;
        auto wordElemIdx = wordIdx * numWordElems;
        auto &copyAsyncOp =
            *ptxBuilder.create<PTXCpAsyncLoadInstr>(srcCacheModifier);
        auto *dstOperand = ptxBuilder.newAddrOperand(
            basePtr, "r", wordElemIdx * resByteWidth);
        auto *srcOperand =
            ptxBuilder.newAddrOperand(srcElems[elemIdx + wordElemIdx], "l");
        auto *copySize = ptxBuilder.newConstantOperand(byteWidth);
        auto *srcSize = copySize;
        if (op.mask()) {
          // We don't use predicate in this case, setting src-size to 0
          // if there's any mask. cp.async will automatically fill the
          // remaining slots with 0 if cp-size > src-size.
          // XXX(Keren): Always assume other = 0 for now.
          auto selectOp = select(maskElems[elemIdx + wordElemIdx],
                                 i32_val(byteWidth), i32_val(0));
          srcSize = ptxBuilder.newOperand(selectOp, "r");
        }
        copyAsyncOp(dstOperand, srcOperand, copySize, srcSize);
        ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
      }
    }

    PTXBuilder ptxBuilder;
    ptxBuilder.create<>("cp.async.commit_group")->operator()();
    ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
    rewriter.replaceOp(op, llDst);
    return success();
  }
};

void populateLoadStoreOpToLLVMPatterns(
    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
    const Allocation *allocation, Value smem,
    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
    PatternBenefit benefit) {
  patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
  patterns.add<StoreOpConversion>(typeConverter, axisInfoAnalysis, benefit);
  patterns.add<AtomicCASOpConversion>(typeConverter, allocation, smem,
                                      axisInfoAnalysis, benefit);
  patterns.add<AtomicRMWOpConversion>(typeConverter, allocation, smem,
                                      axisInfoAnalysis, benefit);
  patterns.add<InsertSliceOpConversion>(typeConverter, allocation, smem,
                                        indexCacheInfo, benefit);
  patterns.add<InsertSliceAsyncOpConversion>(typeConverter, allocation, smem,
                                             indexCacheInfo, axisInfoAnalysis,
                                             benefit);
}