diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp
index 0d0706eab..7c205fe0c 100644
--- a/lib/Analysis/AxisInfo.cpp
+++ b/lib/Analysis/AxisInfo.cpp
@@ -209,6 +209,33 @@ ChangeResult AxisInfoAnalysis::visitOperation(
     }
     curr = AxisInfo(contiguity, divisibility, constancy);
   }
+
+  // CmpI
+  if ((llvm::dyn_cast<arith::CmpIOp>(op) ||
+       llvm::dyn_cast<triton::gpu::CmpIOp>(op)) &&
+      op->getResult(0).getType().dyn_cast<TensorType>()) {
+    auto resTy = op->getResult(0).getType().cast<TensorType>();
+    short rank = resTy.getRank();
+    auto lhsInfo = operands[0]->getValue();
+    auto rhsInfo = operands[1]->getValue();
+    auto shape = resTy.getShape();
+
+    AxisInfo::DimVectorT contiguity, divisibility, constancy;
+    for (short d = 0; d < rank; ++d) {
+      if (rhsInfo.getConstancy(d) % lhsInfo.getContiguity(d) == 0 ||
+          rhsInfo.getConstancy(d) % lhsInfo.getConstancy(d) == 0)
+        constancy.push_back(
+            gcd(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d)));
+      else
+        constancy.push_back(1);
+
+      divisibility.push_back(shape[d]);
+      contiguity.push_back(1);
+    }
+
+    curr = AxisInfo(contiguity, divisibility, constancy);
+  }
+
   // UnrealizedConversionCast
   // This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is
   // in the process of a PartialConversion, where UnrealizedConversionCast
@@ -219,7 +246,8 @@ ChangeResult AxisInfoAnalysis::visitOperation(
   if (curr.getRank() == 0) {
     return markAllPessimisticFixpoint(op->getResults());
   }
-  // join all latice elements
+
+  // join all lattice elements
   ChangeResult result = ChangeResult::NoChange;
   for (Value value : op->getResults()) {
     result |= getLatticeElement(value).join(curr);
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 2820e03c0..56499f3f8 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -759,6 +759,17 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
     return vec;
   }
 
+  unsigned getMaskAlignment(Value mask) const {
+    auto maskOrder = mask.getType()
+                         .cast<RankedTensorType>()
+                         .getEncoding()
+                         .cast<BlockedEncodingAttr>()
+                         .getOrder();
+
+    auto maskAxis = getAxisInfo(mask);
+    return std::max(maskAxis->getConstancy(maskOrder[0]), 1);
+  }
+
   llvm::Optional<AxisInfo> getAxisInfo(Value val) const {
     if (auto it = AxisAnalysisPass.lookupLatticeElement(val)) {
       return it->getValue();
@@ -771,6 +782,208 @@ protected:
   AxisInfoAnalysis &AxisAnalysisPass;
 };
 
+struct LoadOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>,
+      public LoadStoreConversionBase {
+  using ConvertTritonGPUOpToLLVMPattern<
+      triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern;
+
+  LoadOpConversion(LLVMTypeConverter &converter,
+                   AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
+      : ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
+        LoadStoreConversionBase(axisAnalysisPass) {}
+
+  LogicalResult
+  matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value ptr = op.ptr();
+    Value mask = op.mask();
+    Value other = op.other();
+
+    Value llPtr = adaptor.ptr();
+    Value llMask = adaptor.mask();
+    Value llOther = adaptor.other();
+
+    auto loc = op->getLoc();
+    MLIRContext *ctx = rewriter.getContext();
+
+    auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
+    if (!valueTy)
+      return failure();
+    Type valueElemTy =
+        getTypeConverter()->convertType(valueTy.getElementType());
+
+    auto [layout, numElems] = getLayout(ptr);
+
+    auto ptrElems = getLLVMElems(ptr, llPtr, layout, rewriter, loc);
+    assert(ptrElems.size() == numElems);
+    // Determine the
vectorization size + size_t vec = getVectorizeSize(ptr, layout); + + SmallVector maskElems; + if (llMask) { + unsigned maskAlignment = getMaskAlignment(mask); + maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc); + assert(ptrElems.size() == maskElems.size()); + + size_t maskAlign = getMaskAlignment(mask); + vec = std::min(vec, maskAlign); + } + + const size_t dtsize = + std::max(1, valueElemTy.getIntOrFloatBitWidth() / 8); + const size_t valueElemNbits = dtsize * 8; + + const int numVecs = numElems / vec; + + // TODO: (goostavz) handle when other is const but not splat, which + // should be rarely seen + bool otherIsSplatConstInt = false; + DenseElementsAttr constAttr; + int64_t splatVal = 0; + if (valueElemTy.isa() && + matchPattern(op.other(), m_Constant(&constAttr)) && + constAttr.isSplat()) { + otherIsSplatConstInt = true; + splatVal = constAttr.getSplatValue().getSExtValue(); + } + + auto otherElems = getLLVMElems(other, llOther, layout, rewriter, loc); + + SmallVector loadedVals; + for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { + // TODO: optimization when ptr is GEP with constant offset + size_t in_off = 0; + + const int maxWordWidth = std::max(32, valueElemNbits); + const int totalWidth = valueElemNbits * vec; + const int width = std::min(totalWidth, maxWordWidth); + const int nWords = std::max(1, totalWidth / width); + const int wordNElems = width / valueElemNbits; + const int vecNElems = totalWidth / valueElemNbits; + assert(wordNElems * nWords * numVecs == numElems); + + // TODO(Superjomn) Add cache policy fields to StoreOp. + // TODO(Superjomn) Deal with cache policy here. + const bool hasL2EvictPolicy = false; + + PTXBuilder ptxBuilder; + auto &ld = *ptxBuilder.create("ld"); + + Value pred = mask ? maskElems[vecStart] : int_val(1, 1); + + const std::string readConstraint = + (width == 64) ? "l" : ((width == 32) ? "r" : "c"); + const std::string writeConstraint = + (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c"); + + // prepare asm operands + auto *dstsOpr = ptxBuilder.newListOperand(); + for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) { + auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations + dstsOpr->listAppend(opr); + } + + auto *addrOpr = + ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off); + + // Define the instruction opcode + ld.o("volatile", op.isVolatile()) + .global() + .o("ca", op.cache() == triton::CacheModifier::CA) + .o("cg", op.cache() == triton::CacheModifier::CG) + .o("L1::evict_first", + op.evict() == triton::EvictionPolicy::EVICT_FIRST) + .o("L1::evict_last", op.evict() == triton::EvictionPolicy::EVICT_LAST) + .o("L1::cache_hint", hasL2EvictPolicy) + .v(nWords) + .b(width); + + PTXBuilder::Operand *evictOpr{}; + + // Here lack a mlir::Value to bind to this operation, so disabled. 
+ // if (has_l2_evict_policy) + // evictOpr = ptxBuilder.newOperand(l2Evict, "l"); + + if (!evictOpr) + ld(dstsOpr, addrOpr).predicate(pred, "b"); + else + ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b"); + + if (other) { + for (size_t ii = 0; ii < nWords; ++ii) { + PTXInstr &mov = *ptxBuilder.create<>("mov"); + mov.o("u", width); + + size_t size = width / valueElemNbits; + + auto vecTy = LLVM::getFixedVectorType(valueElemTy, size); + Value v = rewriter.create(loc, vecTy); + for (size_t s = 0; s < size; ++s) { + Value falseVal = otherElems[vecStart + ii * size + s]; + Value sVal = createIndexAttrConstant( + rewriter, loc, this->getTypeConverter()->getIndexType(), s); + v = insert_element(vecTy, v, falseVal, sVal); + } + v = bitcast(IntegerType::get(getContext(), width), v); + + PTXInstr::Operand *opr{}; + if (otherIsSplatConstInt) + opr = ptxBuilder.newConstantOperand(splatVal); + else + opr = ptxBuilder.newOperand(v, readConstraint); + + mov(dstsOpr->listGet(ii), opr).predicateNot(pred, "b"); + } + } + + // --- + // create inline ASM signature + // --- + SmallVector retTys(nWords, IntegerType::get(getContext(), width)); + Type retTy = retTys.size() > 1 + ? LLVM::LLVMStructType::getLiteral(getContext(), retTys) + : retTys[0]; + + // TODO: if (has_l2_evict_policy) + auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(), + LLVM::AsmDialect::AD_ATT); + Value ret = ptxBuilder.launch(rewriter, loc, retTy); + + // --- + // extract and store return values + // --- + SmallVector rets; + for (unsigned int ii = 0; ii < nWords; ++ii) { + Value curr; + if (retTy.isa()) { + curr = extract_val(IntegerType::get(getContext(), width), ret, + rewriter.getI64ArrayAttr(ii)); + } else { + curr = ret; + } + curr = bitcast( + LLVM::getFixedVectorType(valueElemTy, width / valueElemNbits), + curr); + rets.push_back(curr); + } + int tmp = width / valueElemNbits; + for (size_t ii = 0; ii < vec; ++ii) { + Value vecIdx = createIndexAttrConstant( + rewriter, loc, this->getTypeConverter()->getIndexType(), ii % tmp); + Value loaded = extract_element(valueElemTy, rets[ii / tmp], vecIdx); + loadedVals.push_back(loaded); + } + } // end vec + + Type llvmResultStructTy = getTypeConverter()->convertType(valueTy); + Value resultStruct = + getStructFromElements(loc, loadedVals, rewriter, llvmResultStructTy); + rewriter.replaceOp(op, {resultStruct}); + return success(); + } +}; + struct StoreOpConversion : public ConvertTritonGPUOpToLLVMPattern, public LoadStoreConversionBase { @@ -814,14 +1027,8 @@ struct StoreOpConversion if (llMask) { maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc); assert(valueElems.size() == maskElems.size()); - auto maskOrder = mask.getType() - .cast() - .getEncoding() - .cast() - .getOrder(); - auto maskAxis = getAxisInfo(mask); - size_t maskAlign = std::max(maskAxis->getConstancy(maskOrder[0]), 1); + size_t maskAlign = getMaskAlignment(mask); vec = std::min(vec, maskAlign); } @@ -846,15 +1053,10 @@ struct StoreOpConversion // TODO(Superjomn) Deal with cache policy here. 
const bool hasL2EvictPolicy = false; - PTXBuilder ptxBuilder; - auto &ptxStoreInstr = *ptxBuilder.create("st"); - - llvm::SmallVector asmArgs; - Type valArgTy = IntegerType::get(ctx, width); auto wordTy = vec_ty(valueElemTy, wordNElems); - auto *asmArgList = ptxBuilder.newListOperand(); + SmallVector> asmArgs; for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) { // llWord is a width-len composition Value llWord = rewriter.create(loc, wordTy); @@ -876,23 +1078,25 @@ struct StoreOpConversion llWord = bitcast(valArgTy, llWord); std::string constraint = (width == 64) ? "l" : ((width == 32) ? "r" : "c"); - asmArgList->listAppend(ptxBuilder.newOperand(llWord, constraint)); + asmArgs.emplace_back(llWord, constraint); } - // TODO(Superjomn) Need to check masks before vectorize the load for - // the values share one predicate? Here assume all the mask values are - // the same. + // Prepare the PTX inline asm. + PTXBuilder ptxBuilder; + auto *asmArgList = ptxBuilder.newListOperand(asmArgs); + Value maskVal = llMask ? maskElems[vecStart] : int_val(1, 1); - ptxStoreInstr.global().b(width).v(nWords); auto *asmAddr = ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off); + auto &ptxStoreInstr = + ptxBuilder.create("st")->global().b(width).v(nWords); ptxStoreInstr(asmAddr, asmArgList).predicate(maskVal, "b"); + Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1)); llvm::SmallVector argTys({boolTy, ptr.getType()}); - for (int i = 0; i < nWords; ++i) - argTys.push_back(valArgTy); + argTys.insert(argTys.end(), nWords, valArgTy); auto ASMReturnTy = LLVM::LLVMVoidType::get(ctx); @@ -1065,209 +1269,6 @@ struct MakeRangeOpConversion } }; -struct LoadOpConversion - : public ConvertTritonGPUOpToLLVMPattern, - public LoadStoreConversionBase { - using ConvertTritonGPUOpToLLVMPattern< - triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern; - - LoadOpConversion(LLVMTypeConverter &converter, - AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) - : ConvertTritonGPUOpToLLVMPattern(converter, benefit), - LoadStoreConversionBase(axisAnalysisPass) {} - - LogicalResult - matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value ptr = op.ptr(); - Value mask = op.mask(); - Value other = op.other(); - - Value llPtr = adaptor.ptr(); - Value llMask = adaptor.mask(); - Value llOther = adaptor.other(); - - auto loc = op->getLoc(); - MLIRContext *ctx = rewriter.getContext(); - - auto valueTy = op.getResult().getType().dyn_cast(); - if (!valueTy) - return failure(); - Type valueElemTy = - getTypeConverter()->convertType(valueTy.getElementType()); - - auto [layout, numElems] = getLayout(ptr); - - auto ptrElems = getLLVMElems(ptr, llPtr, layout, rewriter, loc); - assert(ptrElems.size() == numElems); - - SmallVector maskElems; - if (llMask) { - maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc); - assert(ptrElems.size() == maskElems.size()); - } - - // Determine the vectorization size - size_t vec = getVectorizeSize(ptr, layout); - - const size_t dtsize = - std::max(1, valueElemTy.getIntOrFloatBitWidth() / 8); - const size_t valueElemNbits = dtsize * 8; - - const int numVecs = numElems / vec; - - // TODO: (goostavz) handle when other is const but not splat, which - // should be rarely seen - bool otherIsSplatConstInt = false; - DenseElementsAttr constAttr; - int64_t splatVal = 0; - if (valueElemTy.isa() && - matchPattern(op.other(), m_Constant(&constAttr)) && - constAttr.isSplat()) { - otherIsSplatConstInt = true; - splatVal = 
constAttr.getSplatValue().getSExtValue(); - } - - auto otherElems = getLLVMElems(other, llOther, layout, rewriter, loc); - - SmallVector loadedVals; - for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { - // TODO: optimization when ptr is GEP with constant offset - size_t in_off = 0; - - const int maxWordWidth = std::max(32, valueElemNbits); - const int totalWidth = valueElemNbits * vec; - const int width = std::min(totalWidth, maxWordWidth); - const int nWords = std::max(1, totalWidth / width); - const int wordNElems = width / valueElemNbits; - const int vecNElems = totalWidth / valueElemNbits; - assert(wordNElems * nWords * numVecs == numElems); - - // TODO(Superjomn) Add cache policy fields to StoreOp. - // TODO(Superjomn) Deal with cache policy here. - const bool hasL2EvictPolicy = false; - - PTXBuilder ptxBuilder; - auto &ld = *ptxBuilder.create("ld"); - - // TODO(Superjomn) Need to check masks before vectorize the load for all - // the values share one predicate? Here assume all the mask values are - // the same. - Value pred = mask ? maskElems[vecStart] : int_val(1, 1); - - const std::string readConstraint = - (width == 64) ? "l" : ((width == 32) ? "r" : "c"); - const std::string writeConstraint = - (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c"); - - // prepare asm operands - auto *dstsOpr = ptxBuilder.newListOperand(); - for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) { - auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations - dstsOpr->listAppend(opr); - } - - auto *addrOpr = - ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off); - - // Define the instruction opcode - ld.o("volatile", op.isVolatile()) - .global() - .o("ca", op.cache() == triton::CacheModifier::CA) - .o("cg", op.cache() == triton::CacheModifier::CG) - .o("L1::evict_first", - op.evict() == triton::EvictionPolicy::EVICT_FIRST) - .o("L1::evict_last", op.evict() == triton::EvictionPolicy::EVICT_LAST) - .o("L1::cache_hint", hasL2EvictPolicy) - .v(nWords) - .b(width); - - PTXBuilder::Operand *evictOpr{}; - - // Here lack a mlir::Value to bind to this operation, so disabled. - // if (has_l2_evict_policy) - // evictOpr = ptxBuilder.newOperand(l2Evict, "l"); - - if (!evictOpr) - ld(dstsOpr, addrOpr).predicate(pred, "b"); - else - ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b"); - - if (other) { - for (size_t ii = 0; ii < nWords; ++ii) { - PTXInstr &mov = *ptxBuilder.create<>("mov"); - mov.o("u", width); - - size_t size = width / valueElemNbits; - - auto vecTy = LLVM::getFixedVectorType(valueElemTy, size); - Value v = rewriter.create(loc, vecTy); - for (size_t s = 0; s < size; ++s) { - Value falseVal = otherElems[vecStart + ii * size + s]; - Value sVal = createIndexAttrConstant( - rewriter, loc, this->getTypeConverter()->getIndexType(), s); - v = insert_element(vecTy, v, falseVal, sVal); - } - v = bitcast(IntegerType::get(getContext(), width), v); - - PTXInstr::Operand *opr{}; - if (otherIsSplatConstInt) { - opr = ptxBuilder.newConstantOperand(splatVal); - } else { - opr = ptxBuilder.newOperand(v, readConstraint); - } - - mov(dstsOpr->listGet(ii), opr).predicateNot(pred, "b"); - } - } - - // --- - // create inline ASM signature - // --- - SmallVector retTys(nWords, IntegerType::get(getContext(), width)); - Type retTy = retTys.size() > 1 - ? 
LLVM::LLVMStructType::getLiteral(getContext(), retTys) - : retTys[0]; - - // TODO: if (has_l2_evict_policy) - auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(), - LLVM::AsmDialect::AD_ATT); - Value ret = ptxBuilder.launch(rewriter, loc, retTy); - - // --- - // extract and store return values - // --- - SmallVector rets; - for (unsigned int ii = 0; ii < nWords; ++ii) { - Value curr; - if (retTy.isa()) { - curr = extract_val(IntegerType::get(getContext(), width), ret, - rewriter.getI64ArrayAttr(ii)); - } else { - curr = ret; - } - curr = bitcast( - LLVM::getFixedVectorType(valueElemTy, width / valueElemNbits), - curr); - rets.push_back(curr); - } - int tmp = width / valueElemNbits; - for (size_t ii = 0; ii < vec; ++ii) { - Value vecIdx = createIndexAttrConstant( - rewriter, loc, this->getTypeConverter()->getIndexType(), ii % tmp); - Value loaded = extract_element(valueElemTy, rets[ii / tmp], vecIdx); - loadedVals.push_back(loaded); - } - } // end vec - - Type llvmResultStructTy = getTypeConverter()->convertType(valueTy); - Value resultStruct = - getStructFromElements(loc, loadedVals, rewriter, llvmResultStructTy); - rewriter.replaceOp(op, {resultStruct}); - return success(); - } -}; - struct GetProgramIdOpConversion : public ConvertTritonGPUOpToLLVMPattern { using ConvertTritonGPUOpToLLVMPattern< diff --git a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp index 019a91d01..65c7b7c3f 100644 --- a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @@ -1,4 +1,5 @@ #include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include @@ -23,6 +24,11 @@ struct CoalescePass : public TritonGPUCoalesceBase { std::sort(order.begin(), order.end(), [&](unsigned x, unsigned y) { return contiguity[x] > contiguity[y]; }); + + int numElems = product(origType.getShape()); + int numThreads = numWarps * 32; + int numElemsPerThread = std::max(numElems / numThreads, 1); + // Thread tile size depends on memory alignment SmallVector sizePerThread(rank, 1); PointerType ptrType = origType.getElementType().cast(); @@ -31,7 +37,8 @@ struct CoalescePass : public TritonGPUCoalesceBase { unsigned maxContig = info.getContiguity(order[0]); unsigned alignment = std::min(maxMultiple, maxContig); unsigned perThread = std::min(alignment, 128 / numBits); - sizePerThread[order[0]] = perThread; + sizePerThread[order[0]] = std::min(perThread, numElemsPerThread); + SmallVector dims(rank); std::iota(dims.begin(), dims.end(), 0); // create encoding diff --git a/python/tests/test_vecadd.py b/python/tests/test_vecadd.py index 187dc115f..fb8fd9569 100644 --- a/python/tests/test_vecadd.py +++ b/python/tests/test_vecadd.py @@ -188,7 +188,7 @@ def test_vecadd_no_scf(num_warps, block_size, shape): [2, 256, (3, 256 + 7)], [4, 256, (3, 256 + 7)], ]) -def test_vecadd__no_scf_masked(num_warps, block_size, shape): +def test_vecadd_no_scf_masked(num_warps, block_size, shape): vecadd_no_scf_tester(num_warps, block_size, shape) diff --git a/test/Analysis/test-alignment.mlir b/test/Analysis/test-alignment.mlir index 351866f7d..0f2a45f03 100644 --- a/test/Analysis/test-alignment.mlir +++ b/test/Analysis/test-alignment.mlir @@ -50,3 +50,92 @@ func @permute_2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {t tt.store %19, %20, %cst : tensor<128x128xf32> return } + +// ----- + +module { + +// This is a tiny test 
for verifying StoreOp-related alignment; it simply stores a constant to a buffer.
+func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) {
+  // CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1]
+  %pid = tt.get_program_id {axis = 0 : i32} : i32
+  // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1]
+  %c128_i32 = arith.constant 128 : i32
+  // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1]
+  %1 = arith.muli %pid, %c128_i32 : i32
+  // CHECK-NEXT: Contiguity: [128] ; Divisibility: [65536] ; Constancy: [1]
+  %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
+  // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128]
+  %3 = tt.splat %1 : (i32) -> tensor<128xi32>
+  // CHECK-NEXT: Contiguity: [128] ; Divisibility: [128] ; Constancy: [1]
+  %4 = arith.addi %3, %2 : tensor<128xi32>
+  // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128]
+  %5 = tt.splat %addr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
+  // CHECK-NEXT: Contiguity: [128] ; Divisibility: [16] ; Constancy: [1]
+  %6 = tt.addptr %5, %4 : tensor<128x!tt.ptr<f32>>
+  // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128]
+  %9 = tt.splat %n : (i32) -> tensor<128xi32>
+  // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [16]
+  %mask = arith.cmpi slt, %4, %9 : tensor<128xi32>
+  // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1]
+  %cst = arith.constant dense<0.0> : tensor<128xf32>
+  tt.store %5, %cst, %mask : tensor<128xf32>
+  return
+}
+
+}
+
+// -----
+
+// This IR is dumped from the vecadd test.
+// Note: the hint {tt.divisibility = 16 : i32} for %n_elements affects the alignment of the mask.
+func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) {
+  %c64_i32 = arith.constant 64 : i32
+  %0 = tt.get_program_id {axis = 0 : i32} : i32
+  %1 = arith.muli %0, %c64_i32 : i32
+  %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
+  %3 = tt.splat %1 : (i32) -> tensor<64xi32>
+  %4 = arith.addi %3, %2 : tensor<64xi32>
+  %5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
+  %6 = tt.addptr %5, %4 : tensor<64x!tt.ptr<f32>>
+  %7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
+  %8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>
+  %9 = tt.splat %n_elements : (i32) -> tensor<64xi32>
+  // CHECK: Contiguity: [1] ; Divisibility: [64] ; Constancy: [16] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> )
+  %mask = arith.cmpi slt, %4, %9 : tensor<64xi32>
+  %11 = tt.load %6, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
+  %12 = tt.load %8, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
+  %13 = arith.addf %11, %12 : tensor<64xf32>
+  %14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
+  // CHECK: Contiguity: [64] ; Divisibility: [16] ; Constancy: [1] ( %{{.*}} = tt.addptr %{{.*}}, %{{.*}} : tensor<64x!tt.ptr<f32>> )
+  %15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>>
+  tt.store %15, %13, %mask : tensor<64xf32>
+  return
+}
+
+// -----
+
+// This IR is dumped from the vecadd test.
+// Note: there is no divisibility hint for %n_elements, so Triton assumes its divisibility is 1 by default.
+func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
+  %c64_i32 = arith.constant 64 : i32
+  %0 = tt.get_program_id {axis = 0 : i32} : i32
+  %1 = arith.muli %0, %c64_i32 : i32
+  %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
+  %3 = tt.splat %1 : (i32) -> tensor<64xi32>
+  %4 = arith.addi %3, %2 : tensor<64xi32>
+  %5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
+  %6 = tt.addptr %5, %4 : tensor<64x!tt.ptr<f32>>
+  %7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
+  %8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>
+  %9 = tt.splat %n_elements : (i32) -> tensor<64xi32>
+  // CHECK: Contiguity: [1] ; Divisibility: [64] ; Constancy: [1] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> )
+  %10 = arith.cmpi slt, %4, %9 : tensor<64xi32>
+  %11 = tt.load %6, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
+  %12 = tt.load %8, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
+  %13 = arith.addf %11, %12 : tensor<64xf32>
+  %14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
+  %15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>>
+  tt.store %15, %13, %10 : tensor<64xf32>
+  return
+}
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index 1c182e79a..050c95845 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -161,6 +161,37 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
 
 // -----
 
+// This test verifies the vectorization of Load and Store Ops.
+#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
+// Note: %n_elements does not have a "tt.divisibility" hint, so Triton assumes its divisibility is 1. This affects the mask's alignment and restricts the load/store ops' vector width to 1.
+module attributes {"triton_gpu.num-warps" = 2 : i32} {
+  func @vecadd_masked_vec1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
+    %c64_i32 = arith.constant 64 : i32
+    %0 = tt.get_program_id {axis = 0 : i32} : i32
+    %1 = arith.muli %0, %c64_i32 : i32
+    %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked>
+    %3 = tt.splat %1 : (i32) -> tensor<64xi32, #blocked>
+    %4 = arith.addi %3, %2 : tensor<64xi32, #blocked>
+    %5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>, #blocked>
+    %6 = tt.addptr %5, %4 : tensor<64x!tt.ptr<f32>, #blocked>
+    %7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>, #blocked>
+    %8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>, #blocked>
+    %9 = tt.splat %n_elements : (i32) -> tensor<64xi32, #blocked>
+    %10 = "triton_gpu.cmpi"(%4, %9) {predicate = 2 : i64} : (tensor<64xi32, #blocked>, tensor<64xi32, #blocked>) -> tensor<64xi1, #blocked>
+    // load op has a vector width = 1 due to the %mask's alignment
+    // CHECK: ld.global.b32
+    %11 = tt.load %6, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32, #blocked>
+    %12 = tt.load %8, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32, #blocked>
+    %13 = arith.addf %11, %12 : tensor<64xf32, #blocked>
+    %14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>, #blocked>
+    %15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>, #blocked>
+    tt.store %15, %13, %10 : tensor<64xf32, #blocked>
+    return
+  }
+}
+
+// -----
+
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
 module attributes {"triton_gpu.num-warps" = 1 : i32} {
   // CHECK-LABEL: global_load_store_vec8
@@ -682,4 +713,4 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
     %0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
     return
   }
-}
\ No newline at end of file
+}
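
For context, the lit tests above are the IR-level counterpart of the masked vector-add kernel exercised by test_vecadd_no_scf_masked. The sketch below is illustrative only and assumes the standard Triton Python frontend; the kernel and variable names are not part of this change. The comparison "offsets < n_elements" is what lowers to the cmpi analyzed by the new AxisInfo rule, and getMaskAlignment turns the mask's constancy into the load/store vector width.

import torch
import triton
import triton.language as tl

@triton.jit
def vecadd_masked(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Becomes arith.cmpi / triton_gpu.cmpi in the IR; its constancy bounds
    # how many consecutive lanes can share one predicate.
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
    tl.store(out_ptr + offsets, x + y, mask=mask)

# Example launch: the size is deliberately not a multiple of the block size,
# so the tail block exercises the masked load/store path.
x = torch.rand(775, device="cuda")
y = torch.rand_like(x)
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)
vecadd_masked[grid](x, y, out, x.numel(), BLOCK_SIZE=256)

When n_elements carries a tt.divisibility = 16 hint (as in vecadd_mask_align_16), the mask is constant over aligned groups of lanes and wider vectorized ld.global/st.global can be emitted; without the hint the width falls back to 1, which is what the vecadd_masked_vec1 test checks with ld.global.b32.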