diff --git a/include/triton/Conversion/Passes.td b/include/triton/Conversion/Passes.td index 504dc9cb9..2fa727b70 100644 --- a/include/triton/Conversion/Passes.td +++ b/include/triton/Conversion/Passes.td @@ -33,11 +33,13 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp" let constructor = "mlir::triton::createConvertTritonGPUToLLVMPass()"; let dependentDialects = ["mlir::arith::ArithmeticDialect", - "mlir::StandardOpsDialect", + "mlir::gpu::GPUDialect", "mlir::scf::SCFDialect", "mlir::LLVM::LLVMDialect", "mlir::triton::TritonDialect", - "mlir::triton::gpu::TritonGPUDialect"]; + "mlir::triton::gpu::TritonGPUDialect", + "mlir::NVVM::NVVMDialect", + "mlir::StandardOpsDialect"]; } #endif diff --git a/lib/Conversion/PassDetail.h b/lib/Conversion/PassDetail.h index ca6b9d9c4..10137921f 100644 --- a/lib/Conversion/PassDetail.h +++ b/lib/Conversion/PassDetail.h @@ -1,7 +1,9 @@ #ifndef TRITON_CONVERSION_PASSDETAIL_H #define TRITON_CONVERSION_PASSDETAIL_H +#include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Pass/Pass.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index 358604975..ec464b99c 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -13,6 +13,9 @@ add_mlir_conversion_library(TritonGPUToLLVM LINK_LIBS PUBLIC MLIRIR MLIRPass + MLIRGPUOps + MLIRGPUToNVVMTransforms + MLIRGPUTransforms TritonIR TritonGPUIR TritonGPUTransforms diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index a39f08885..9de3f8a1e 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -1,9 +1,14 @@ #include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h" + #include "../PassDetail.h" #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" +#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" +#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Matchers.h" #include "mlir/Transforms/DialectConversion.h" #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h" #include "triton/Dialect/Triton/IR/Dialect.h" @@ -12,6 +17,9 @@ using namespace mlir; using namespace mlir::triton; +using ::mlir::triton::gpu::TritonGPUBlockedEncodingAttr; +using ::mlir::triton::gpu::TritonGPUMmaEncodingAttr; +using ::mlir::triton::gpu::TritonGPUSharedEncodingAttr; namespace mlir { namespace LLVM { @@ -23,6 +31,8 @@ static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; } namespace { +class TritonGPUToLLVMTypeConverter; + // The following code are borrowed from mlir project including the following // functions or classes: // - filterFuncAttributes @@ -134,17 +144,7 @@ protected: static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface"; struct FuncOpConversion : public FuncOpConversionBase { FuncOpConversion(LLVMTypeConverter &converter, int numWarps) - : FuncOpConversionBase(converter), NumWarps(numWarps) { - mlir::ConvertToLLVMPattern::getTypeConverter()->addConversion( - [&](triton::PointerType type) { - return convertTritonPointerType(type); - }); - } - - Type convertTritonPointerType(triton::PointerType type) { - return LLVM::LLVMPointerType::get(type.getPointeeType(), - type.getAddressSpace()); - } + : FuncOpConversionBase(converter), NumWarps(numWarps) {} LogicalResult matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor, @@ -172,7 +172,7 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(ReturnOp op, OpAdaptor adapter, + matchAndRewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); unsigned numArguments = op.getNumOperands(); @@ -208,12 +208,631 @@ int extractNumWarps(mlir::ModuleOp module) { return numWarps; } -} // namespace +template +static SmallVector getMultiDimIndex(T linear_index, ArrayRef shape) { + // sizes {a, b, c, d} -> acc_mul {b*c*d, c*d, d, 1} + size_t rank = shape.size(); + T acc_mul = 1; + for (size_t i = 1; i < rank; ++i) { + acc_mul *= shape[i]; + } + T linear_remain = linear_index; + SmallVector multidim_index(rank); + for (size_t i = 0; i < rank; ++i) { + multidim_index[i] = linear_remain / acc_mul; + linear_remain = linear_remain % acc_mul; + if (i != (rank - 1)) { + acc_mul = acc_mul / shape[i + 1]; + } + } + return multidim_index; +} + +template +static T getLinearIndex(ArrayRef multidim_index, ArrayRef shape) { + assert(multidim_index.size() == shape.size()); + // sizes {a, b, c, d} -> acc_mul {b*c*d, c*d, d, 1} + size_t rank = shape.size(); + T acc_mul = 1; + for (size_t i = 1; i < rank; ++i) { + acc_mul *= shape[i]; + } + T linear_index = 0; + for (size_t i = 0; i < rank; ++i) { + linear_index += multidim_index[i] * acc_mul; + if (i != (rank - 1)) { + acc_mul = acc_mul / shape[i + 1]; + } + } + return linear_index; +} + +static unsigned getElemsPerThread(const TritonGPUBlockedEncodingAttr &layout, + ArrayRef shape) { + unsigned elems = 1; + size_t rank = shape.size(); + assert(rank == layout.getThreadsPerWarp().size()); + for (size_t d = 0; d < rank; ++d) { + elems *= + shape[d] / (layout.getThreadsPerWarp()[d] * layout.getWarpsPerCTA()[d]); + } + return elems; +} + +static Value createIndexAttrConstant(OpBuilder &builder, Location loc, + Type resultType, int64_t value) { + return builder.create( + loc, resultType, builder.getIntegerAttr(resultType, value)); +} + +template +class ConvertTritonGPUOpToLLVMPattern + : public ConvertOpToLLVMPattern { +public: + using OpAdaptor = typename SourceOp::Adaptor; + + explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + + SmallVector + getElementsFromStruct(Location loc, Value llvmStruct, unsigned elems, + ConversionPatternRewriter &rewriter) const { + SmallVector results(elems); + for (unsigned i = 0; i < elems; ++i) { + Type type = + llvmStruct.getType().cast().getBody()[i]; + results[i] = rewriter.create( + loc, type, llvmStruct, rewriter.getI64ArrayAttr(i)); + } + return results; + } + + Value getStructFromElements(Location loc, ValueRange resultVals, + ConversionPatternRewriter &rewriter, + Type structType) const { + Value llvmStruct = rewriter.create(loc, structType); + for (auto v : llvm::enumerate(resultVals)) { + llvmStruct = rewriter.create( + loc, structType, llvmStruct, v.value(), + rewriter.getI64ArrayAttr(v.index())); + } + return llvmStruct; + } + + SmallVector delinearize(ConversionPatternRewriter &rewriter, + Location loc, Value linear, + ArrayRef shape, + ArrayRef order) const { + unsigned rank = shape.size(); + assert(rank == order.size()); + SmallVector reordered(rank); + for (unsigned i = 0; i < rank; ++i) { + reordered[i] = shape[order[i]]; + } + return delinearize(rewriter, loc, linear, reordered); + } + + SmallVector delinearize(ConversionPatternRewriter &rewriter, + Location loc, Value linear, + ArrayRef shape) const { + unsigned rank = shape.size(); + assert(rank > 0); + SmallVector multiDim(rank); + if (rank == 1) { + multiDim[0] = linear; + } else { + Value remained = linear; + for (auto &&en : llvm::enumerate(llvm::reverse(shape.drop_front()))) { + Value dimSize = createIndexAttrConstant( + rewriter, loc, this->getTypeConverter()->getIndexType(), + en.value()); + multiDim[rank - 1 - en.index()] = + rewriter.create(loc, remained, dimSize); + remained = rewriter.create(loc, remained, dimSize); + } + multiDim[0] = remained; + } + return multiDim; + } + + // Emit indices calculation within each ConversionPattern + // TODO: [goostavz] Double confirm the redundant indices calculations will + // be eliminated in the consequent MLIR/LLVM optimization + SmallVector> emitIndicesForBlockedLayout( + Location loc, ConversionPatternRewriter &b, + const TritonGPUBlockedEncodingAttr &blocked_layout, + ArrayRef shape) const { + auto llvmIndexTy = this->getTypeConverter()->getIndexType(); + auto cast = b.create( + loc, TypeRange{llvmIndexTy}, + ValueRange{b.create<::mlir::gpu::ThreadIdOp>( + loc, b.getIndexType(), ::mlir::gpu::Dimension::x)}); + Value threadId = cast.getResult(0); + Value warpSize = createIndexAttrConstant(b, loc, llvmIndexTy, 32); + Value laneId = b.create(loc, threadId, warpSize); + Value warpId = b.create(loc, threadId, warpSize); + auto sizePerThread = blocked_layout.getSizePerThread(); + auto threadsPerWarp = blocked_layout.getThreadsPerWarp(); + auto warpsPerCTA = blocked_layout.getWarpsPerCTA(); + auto order = blocked_layout.getOrder(); + unsigned rank = shape.size(); + SmallVector threadIds(rank); + + // step 1, delinearize threadId to get the base index + SmallVector multiDimWarpId = + delinearize(b, loc, warpId, warpsPerCTA, order); + SmallVector multiDimThreadId = + delinearize(b, loc, laneId, threadsPerWarp, order); + SmallVector multiDimBase(rank); + for (unsigned k = 0; k < rank; ++k) { + // multiDimBase[k] = (multiDimThreadId[k] + multiDimWarpId[k] * + // threadsPerWarp[k]) * + // sizePerThread[k]; + Value threadsPerWarpK = + createIndexAttrConstant(b, loc, llvmIndexTy, threadsPerWarp[k]); + Value sizePerThreadK = + createIndexAttrConstant(b, loc, llvmIndexTy, sizePerThread[k]); + multiDimBase[k] = b.create( + loc, sizePerThreadK, + b.create( + loc, multiDimThreadId[k], + b.create(loc, multiDimWarpId[k], threadsPerWarpK))); + } + + // step 2, get offset of each element + unsigned elemsPerThread = 1; + SmallVector> offset(rank); + SmallVector multiDimElemsPerThread(rank); + for (unsigned k = 0; k < rank; ++k) { + multiDimElemsPerThread[k] = shape[k] / threadsPerWarp[k] / warpsPerCTA[k]; + elemsPerThread *= multiDimElemsPerThread[k]; + for (unsigned blockOffset = 0; + blockOffset < + shape[k] / (sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]); + ++blockOffset) { + for (unsigned warpOffset = 0; warpOffset < warpsPerCTA[k]; + ++warpOffset) { + for (unsigned threadOffset = 0; threadOffset < threadsPerWarp[k]; + ++threadOffset) { + for (unsigned elemOffset = 0; elemOffset < sizePerThread[k]; + ++elemOffset) { + offset[k].push_back(blockOffset * sizePerThread[k] * + threadsPerWarp[k] * warpsPerCTA[k] + + warpOffset * sizePerThread[k] * + threadsPerWarp[k] + + threadOffset * sizePerThread[k] + elemOffset); + } + } + } + } + } + + // step 3, add offset to base, and reorder the sequence of indices, + // to guarantee that elems in a same sizePerThread are adjacent in + // order + SmallVector> multiDimIdx(elemsPerThread); + unsigned accumSizePerThread = + std::accumulate(sizePerThread.begin(), sizePerThread.end(), 1, + std::multiplies()); + SmallVector threadsPerDim(rank); + for (unsigned k = 0; k < rank; ++k) { + threadsPerDim[k] = shape[k] / sizePerThread[k]; + } + for (unsigned n = 0; n < elemsPerThread; ++n) { + unsigned linearNanoTileId = n / accumSizePerThread; + unsigned linearElemsInNanoTileId = n % accumSizePerThread; + SmallVector multiDimNanoTileId = + getMultiDimIndex(linearNanoTileId, threadsPerDim); + SmallVector multiElemsInNanoTileId = + getMultiDimIndex(linearElemsInNanoTileId, sizePerThread); + multiDimIdx[n].resize(rank); + for (unsigned k = 0; k < rank; ++k) { + unsigned reorderedMultiDimId = + multiDimNanoTileId[k] * + (sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]) + + multiElemsInNanoTileId[k]; + multiDimIdx[n][k] = b.create( + loc, multiDimBase[k], + createIndexAttrConstant(b, loc, llvmIndexTy, + offset[k][reorderedMultiDimId])); + } + } + + return multiDimIdx; + } +}; + +struct BroadcastOpConversion + : public ConvertTritonGPUOpToLLVMPattern { + using ConvertTritonGPUOpToLLVMPattern< + triton::BroadcastOp>::ConvertTritonGPUOpToLLVMPattern; + + // Following the order of indices in the legacy code, a broadcast of: + // [s(0), s(1) ... s(k-1), 1, s(k+1), s(k+2) ... s(n-1)] + // => + // [s(0), s(1) ... s(k-1), s(k), s(k+1), s(k+2) ... s(n-1)] + // + // logically maps to a broadcast within a thread's scope: + // [cta(0)..cta(k-1), 1,cta(k+1)..cta(n-1),spt(0)..spt(k-1), + // 1,spt(k+1)..spt(n-1)] + // => + // [cta(0)..cta(k-1),cta(k),cta(k+1)..cta(n-1),spt(0)..spt(k-1),spt(k),spt(k+1)..spt(n-1)] + // + // regardless of the order of the layout + // + LogicalResult + matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + Value src = adaptor.src(); + Value result = op.result(); + auto srcTy = op.src().getType().cast(); + auto resultTy = result.getType().cast(); + auto srcLayout = + srcTy.getEncoding().dyn_cast(); + auto resultLayout = + resultTy.getEncoding().dyn_cast(); + assert(srcLayout && (srcLayout == resultLayout) && + "Unexpected layout of BroadcastOp"); + auto srcShape = srcTy.getShape(); + auto resultShape = resultTy.getShape(); + unsigned rank = srcTy.getRank(); + // TODO: [goostavz] double confirm the op semantics with Phil + assert(rank == resultTy.getRank()); + + SmallVector srcLogicalShape(2 * rank); + SmallVector resultLogicalShape(2 * rank); + SmallVector broadcastDims; + SmallVector broadcastSizes; + int64_t duplicates = 1; + for (unsigned d = 0; d < rank; ++d) { + int64_t numCtas = resultShape[d] / (resultLayout.getSizePerThread()[d] * + resultLayout.getThreadsPerWarp()[d] * + resultLayout.getWarpsPerCTA()[d]); + if (srcShape[d] != resultShape[d]) { + assert(srcShape[d] == 1); + broadcastDims.push_back(d); + broadcastSizes.push_back(resultShape[d]); + srcLogicalShape[d] = 1; + srcLogicalShape[d + rank] = 1; + duplicates *= resultShape[d]; + } else { + srcLogicalShape[d] = numCtas; + srcLogicalShape[d + rank] = resultLayout.getSizePerThread()[d]; + } + resultLogicalShape[d] = numCtas; + resultLogicalShape[d + rank] = resultLayout.getSizePerThread()[d]; + } + unsigned srcElems = getElemsPerThread(srcLayout, srcShape); + auto elemTy = resultTy.getElementType(); + auto srcVals = getElementsFromStruct(loc, src, srcElems, rewriter); + unsigned resultElems = getElemsPerThread(resultLayout, resultShape); + SmallVector resultVals(resultElems); + for (unsigned i = 0; i < srcElems; ++i) { + auto srcMultiDim = getMultiDimIndex(i, srcLogicalShape); + auto resultMultiDim = srcMultiDim; + for (int64_t j = 0; j < duplicates; ++j) { + auto bcastMultiDim = getMultiDimIndex(j, broadcastSizes); + for (auto bcastDim : llvm::enumerate(broadcastDims)) { + resultMultiDim[bcastDim.value()] = bcastMultiDim[bcastDim.index()]; + } + auto resultLinearIndex = + getLinearIndex(resultMultiDim, resultLogicalShape); + resultVals[resultLinearIndex] = srcVals[i]; + } + } + auto llvmStructTy = getTypeConverter()->convertType(resultTy); + Value resultStruct = + getStructFromElements(loc, resultVals, rewriter, llvmStructTy); + rewriter.replaceOp(op, {resultStruct}); + return success(); + } +}; + +struct ViewOpConversion + : public ConvertTritonGPUOpToLLVMPattern { + using ConvertTritonGPUOpToLLVMPattern< + triton::ViewOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::ViewOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // We cannot directly + // rewriter.replaceOp(op, adaptor.src()); + // due to MLIR's restrictions + Location loc = op->getLoc(); + auto resultTy = op.getType().cast(); + auto resultLayout = + resultTy.getEncoding().dyn_cast(); + auto resultShape = resultTy.getShape(); + unsigned elems = getElemsPerThread(resultLayout, resultShape); + Type elemTy = + this->getTypeConverter()->convertType(resultTy.getElementType()); + SmallVector types(elems, elemTy); + Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types); + auto vals = getElementsFromStruct(loc, adaptor.src(), elems, rewriter); + Value view = getStructFromElements(loc, vals, rewriter, structTy); + rewriter.replaceOp(op, view); + return success(); + } +}; + +struct MakeRangeOpConversion + : public ConvertTritonGPUOpToLLVMPattern { + using ConvertTritonGPUOpToLLVMPattern< + triton::MakeRangeOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto rankedTy = op.result().getType().dyn_cast(); + auto shape = rankedTy.getShape(); + auto blocked_layout = + rankedTy.getEncoding().dyn_cast(); + auto elemTy = rankedTy.getElementType(); + assert(elemTy.isInteger(32)); + Value start = createIndexAttrConstant(rewriter, loc, elemTy, op.start()); + auto idxs = + emitIndicesForBlockedLayout(loc, rewriter, blocked_layout, shape); + unsigned elems = idxs.size(); + SmallVector retVals(elems); + for (auto multiDim : llvm::enumerate(idxs)) { + assert(multiDim.value().size() == 1); + retVals[multiDim.index()] = + rewriter.create(loc, multiDim.value()[0], start); + } + SmallVector types(elems, elemTy); + Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types); + Value result = getStructFromElements(loc, retVals, rewriter, structTy); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct LoadOpConversion + : public ConvertTritonGPUOpToLLVMPattern { + using ConvertTritonGPUOpToLLVMPattern< + triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + Value ptr = adaptor.ptr(); + Value mask = adaptor.mask(); + Value other = adaptor.other(); + auto resultTy = op.result().getType().cast(); + auto blockedLayout = + resultTy.getEncoding().dyn_cast(); + auto shape = resultTy.getShape(); + + // TODO: Handle AxisInfo + // vecWidth = std::min(nts, aln) + // TODO: special processing for mma_first_row in legacy codes + assert(blockedLayout && "LoadOp only accepts blocked_layout"); + unsigned vecWidth = + blockedLayout.getSizePerThread()[blockedLayout.getOrder()[0]]; + + auto elemTy = resultTy.getElementType(); + unsigned numElems = getElemsPerThread(blockedLayout, shape); + auto ptrVals = getElementsFromStruct(loc, ptr, numElems, rewriter); + auto maskVals = getElementsFromStruct(loc, mask, numElems, rewriter); + auto otherVals = getElementsFromStruct(loc, other, numElems, rewriter); + unsigned nbits = elemTy.isa() + ? elemTy.cast().getWidth() + : elemTy.cast().getWidth(); + // unsigned dtsize = nbits / 8; + int max_word_width = std::max(32, nbits); + int tot_width = nbits * vecWidth; + int width = std::min(tot_width, max_word_width); + int n_words = std::max(1, tot_width / width); + // TODO: currently disable until supported in `store` + bool has_l2_evict_policy = false; + + // TODO: (goostavz) handle when other is const but not splat, which + // should be rarely seen + bool otherIsSplatConstInt = false; + DenseElementsAttr constAttr; + int64_t splatVal = 0; + if (elemTy.isa() && + matchPattern(op.other(), m_Constant(&constAttr)) && + constAttr.isSplat()) { + otherIsSplatConstInt = true; + splatVal = constAttr.getSplatValue().getSExtValue(); + } + + SmallVector loadedVals; + for (size_t i = 0; i < numElems; i += vecWidth) { + Value ptr = ptrVals[i]; + // TODO: Handle the optimization if ptr is from GEP and the idx is + // constant + // This should be a canonicalization pattern in LLVM Dialect + unsigned in_off = 0; + Value pred = maskVals[i]; + + // --- + // create inline asm string + // --- + // TODO: (Superjomn) refactor with AsmInstr abstraction + std::ostringstream asmOss; + asmOss << "@$" << n_words; // predicate + asmOss << " ld"; + if (op.isVolatile()) { + asmOss << ".volatile"; + } + asmOss << ".global"; + if (op.cache() == triton::CacheModifier::CA) + asmOss << ".ca"; + if (op.cache() == triton::CacheModifier::CG) + asmOss << ".cg"; + if (op.evict() == triton::EvictionPolicy::EVICT_FIRST) + asmOss << ".L1::evict_first"; + if (op.evict() == triton::EvictionPolicy::EVICT_LAST) + asmOss << ".L1::evict_last"; + if (has_l2_evict_policy) + asmOss << ".L2::cache_hint"; + if (n_words > 1) + asmOss << ".v" << n_words; // vector width + asmOss << ".b" << width; // word size + asmOss << " {"; + for (int i = 0; i < n_words; i++) { // return values + if (i > 0) + asmOss << ","; + asmOss << "$" << i; + } + asmOss << "}"; + asmOss << ", [ $" << n_words + 1; // load + asmOss << " + " << in_off << "]"; // constant offset + if (has_l2_evict_policy) + asmOss << ", $" << n_words + 2; + asmOss << ";"; + SmallVector others; + for (size_t ii = 0; ii < n_words; ii++) { + size_t size = width / nbits; + auto vecTy = LLVM::getFixedVectorType(elemTy, size); + Value v = rewriter.create(loc, vecTy); + for (size_t s = 0; s < size; s++) { + Value falseVal = otherVals[i + ii * size + s]; + Value sVal = createIndexAttrConstant( + rewriter, loc, this->getTypeConverter()->getIndexType(), s); + v = rewriter.create(loc, vecTy, v, falseVal, + sVal); + } + v = rewriter.create( + loc, IntegerType::get(getContext(), width), v); + asmOss << "\n "; + asmOss << "@!$" << n_words << " mov.u" << width; + asmOss << " $" << ii << ", "; + std::ios_base::fmtflags flags(asmOss.flags()); + if (otherIsSplatConstInt) + asmOss << "0x" << std::hex << splatVal; + else { + asmOss << "$" << n_words + has_l2_evict_policy + 2 + ii; + others.push_back(v); + } + asmOss.flags(flags); + asmOss << ";"; + } + // --- + // create inline ASM signature + // --- + SmallVector retTys(n_words, IntegerType::get(getContext(), width)); + Type retTy = retTys.size() > 1 + ? LLVM::LLVMStructType::getLiteral(getContext(), retTys) + : retTys[0]; + // --- + // create inline ASM constraints + // --- + std::string asmCstrt; + for (int ii = 0; ii < n_words; ii++) { + if (ii > 0) + asmCstrt += ","; + asmCstrt += (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c"); + } + asmCstrt += ",b,l"; + for (size_t ii = 0; ii < others.size(); ii++) { + asmCstrt += ","; + asmCstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c"); + } + if (has_l2_evict_policy) { + asmCstrt += ",l"; + } + // --- + // finally call inline ASM + // --- + SmallVector args = {pred, ptr}; + auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(), + LLVM::AsmDialect::AD_ATT); + auto inlineAsmOp = rewriter.create( + loc, retTy, /*operands=*/args, /*asm_string=*/asmOss.str(), + /*constraints=*/asmCstrt, /*has_side_effects=*/true, + /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr, + /*operand_attrs=*/ArrayAttr()); + Value ret = inlineAsmOp.getResult(0); + // --- + // extract and store return values + // --- + SmallVector rets; + for (unsigned int ii = 0; ii < n_words; ii++) { + Value curr = nullptr; + if (retTy.isa()) { + curr = rewriter.create( + loc, IntegerType::get(getContext(), width), ret, + rewriter.getI64ArrayAttr(ii)); + } else { + curr = ret; + } + curr = rewriter.create( + loc, LLVM::getFixedVectorType(elemTy, width / nbits), curr); + rets.push_back(curr); + } + int tmp = (width / nbits); + for (size_t ii = 0; ii < vecWidth; ii++) { + Value vecIdx = createIndexAttrConstant( + rewriter, loc, this->getTypeConverter()->getIndexType(), ii % tmp); + Value loaded = rewriter.create( + loc, elemTy, rets[ii / tmp], vecIdx); + loadedVals.push_back(loaded); + } + } + Type llvmResultStructTy = getTypeConverter()->convertType(resultTy); + Value resultStruct = + getStructFromElements(loc, loadedVals, rewriter, llvmResultStructTy); + rewriter.replaceOp(op, {resultStruct}); + return success(); + } +}; + +class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter { +public: + using TypeConverter::convertType; + + TritonGPUToLLVMTypeConverter(MLIRContext *ctx, LowerToLLVMOptions &option, + const DataLayoutAnalysis *analysis = nullptr) + : LLVMTypeConverter(ctx, option, analysis) { + addConversion([&](triton::PointerType type) -> llvm::Optional { + return convertTritonPointerType(type); + }); + addConversion([&](RankedTensorType type) -> llvm::Optional { + return convertTritonTensorType(type); + }); + } + + Type convertTritonPointerType(triton::PointerType type) { + return LLVM::LLVMPointerType::get(type.getPointeeType(), + type.getAddressSpace()); + } + + llvm::Optional convertTritonTensorType(RankedTensorType type) { + Attribute layout = type.getEncoding(); + if (auto blocked_layout = layout.dyn_cast()) { + unsigned numElementsPerThread = + getElemsPerThread(blocked_layout, type.getShape()); + SmallVector types(numElementsPerThread, + convertType(type.getElementType())); + return LLVM::LLVMStructType::getLiteral(&getContext(), types); + } else if (auto mma_layout = layout.dyn_cast()) { + // TODO: Not implemented + return llvm::None; + } else if (auto shared_layout = + layout.dyn_cast()) { + // TODO: Not implemented + return llvm::None; + } + return llvm::None; + } +}; void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps) { - patterns.add<::FuncOpConversion>(typeConverter, numWarps); - patterns.add<::ReturnOpConversion>(typeConverter); + patterns.add(typeConverter); + patterns.add(typeConverter, numWarps); + patterns.add(typeConverter); + patterns.add(typeConverter); + patterns.add(typeConverter); + patterns.add(typeConverter); } class ConvertTritonGPUToLLVM @@ -225,29 +844,41 @@ public: MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); - LLVMTypeConverter typeConverter(context); + mlir::LowerToLLVMOptions option(context); + // TODO: need confirm + option.overrideIndexBitwidth(32); + TritonGPUToLLVMTypeConverter typeConverter(context, option); TritonLLVMConversionTarget target(*context, typeConverter); RewritePatternSet patterns(context); + // TODO: (goostavz) Temporarily disable this, since the lowering of + // arithmetic ops in tensor format is not complete yet. // Add arith's patterns to help convert scalar expression to LLVM. - mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, - patterns); + // mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, + // patterns); int numWarps = extractNumWarps(mod); populateTritonToLLVMPatterns(typeConverter, patterns, numWarps); + mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns); if (failed(applyPartialConversion(mod, target, std::move(patterns)))) return signalPassFailure(); } }; +} // namespace + namespace mlir { TritonLLVMConversionTarget::TritonLLVMConversionTarget( MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter) : ConversionTarget(ctx), typeConverter(typeConverter) { addLegalDialect(); + addLegalDialect(); + // addIllegalDialect(); + addIllegalDialect(); + addLegalOp(); } namespace triton { diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 8f006097b..51a0e4a11 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -3,7 +3,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { -// CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr) +// CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr) // Here the 128 comes from the 4 in module attribute multiples 32 // CHECK: attributes {nvvm.maxntid = 128 : i32} {{.*}} func @test_empty_kernel(%lb : index, %A : !tt.ptr) { @@ -13,3 +13,128 @@ func @test_empty_kernel(%lb : index, %A : !tt.ptr) { } } // end module + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_load + func @basic_load(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { + // CHECK: llvm.inline_asm + // CHECK: llvm.inline_asm + %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> + return + } +} + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: vectorized_load + func @vectorized_load(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { + // CHECK: llvm.inline_asm + // CHECK-SAME: ld.global.v4.b32 + // CHECK: llvm.inline_asm + // CHECK-SAME: ld.global.v4.b32 + %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> + return + } +} + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: vectorized_load_f16 + func @vectorized_load_f16(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) { + // CHECK: llvm.inline_asm + // CHECK-SAME: ld.global.v2.b32 + // CHECK: llvm.inline_asm + // CHECK-SAME: ld.global.v2.b32 + %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf16, #blocked0> + return + } +} + +// ----- + +// TODO: Pending on the support of isSplat constant +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: masked_load_const_other + func @masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>) { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0> + %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> + return + } +} + +// TODO: Add a testcase to verify the optimization when ptr of the LoadOp +// is from a GEP with const idx + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_view_broadcast + func @basic_view_broadcast(%arg : tensor<256xf32,#blocked0>) { + // CHECK: llvm.mlir.undef + // CHECK: %[[T0:.*]] = llvm.extractvalue + // CHECK: %[[T1:.*]] = llvm.extractvalue + %0 = tt.view %arg : (tensor<256xf32, #blocked0>) -> tensor<256x1xf32,#blocked2> + // CHECK: llvm.mlir.undef + // CHECK: llvm.insertvalue %[[T0]] + // CHECK: llvm.insertvalue %[[T0]] + // CHECK: llvm.insertvalue %[[T0]] + // CHECK: llvm.insertvalue %[[T0]] + // CHECK: llvm.insertvalue %[[T1]] + // CHECK: llvm.insertvalue %[[T1]] + // CHECK: llvm.insertvalue %[[T1]] + // CHECK: llvm.insertvalue %[[T1]] + %1 = tt.broadcast %0 : (tensor<256x1xf32,#blocked2>) -> tensor<256x4xf32, #blocked2> + return + } +} + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_make_range + func @basic_make_range() { + // CHECK: nvvm.read.ptx.sreg.tid.x + // CHECK: llvm.mlir.undef + // CHECK: llvm.insertvalue + // CHECK: llvm.insertvalue + %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + return + } +} + +// #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +// #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +// #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +// module attributes {"triton_gpu.num-warps" = 4 : i32} { +// func @debut_kernel(%lb : index, %A : !tt.ptr, %B : !tt.ptr, %C : !tt.ptr) { +// %cst = arith.constant dense : tensor<256xi1, #blocked0> +// %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0> +// %cst_1 = arith.constant dense : tensor<1024x256xi1, #blocked1> +// %cst_2 = arith.constant dense : tensor<256x2048xi1, #blocked2> +// %a_ptr_init = tt.splat %A : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked0> +// %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> +// %4 = tt.view %1 : (tensor<256xf32, #blocked0>) -> tensor<1x256xf32,#blocked1> +// %5 = tt.broadcast %4 : (tensor<1x256xf32,#blocked1>) -> tensor<1024x256xf32, #blocked1> +// %6 = tt.view %1 : (tensor<256xf32, #blocked0>) -> tensor<256x1xf32,#blocked2> +// %7 = tt.broadcast %6 : (tensor<256x1xf32,#blocked2>) -> tensor<256x2048xf32, #blocked2> +// %b_ptr_init = tt.splat %A : (!tt.ptr) -> tensor<1024x256x!tt.ptr, #blocked1> +// %c_ptr_init = tt.splat %A : (!tt.ptr) -> tensor<256x2048x!tt.ptr, #blocked2> +// tt.store %b_ptr_init, %5, %cst_1, : tensor<1024x256xf32, #blocked1> +// tt.store %c_ptr_init, %7, %cst_2, : tensor<256x2048xf32, #blocked2> +// return +// } +// } + + + diff --git a/test/Target/tritongpu_to_llvmir.mlir b/test/Target/tritongpu_to_llvmir.mlir index 96712c112..0f03323e9 100644 --- a/test/Target/tritongpu_to_llvmir.mlir +++ b/test/Target/tritongpu_to_llvmir.mlir @@ -4,7 +4,7 @@ // CHECK-LABEL: ; ModuleID = 'LLVMDialectModule' // CHECK: define void @test_empty_kernel // CHECK: !nvvm.annotations -// CHECK: !{void (i64, half addrspace(1)*)* @test_empty_kernel, !"maxntidx", i32 128} +// CHECK: !{void (i32, half addrspace(1)*)* @test_empty_kernel, !"maxntidx", i32 128} module attributes {"triton_gpu.num-warps" = 4 : i32} {