diff --git a/CMakeLists.txt b/CMakeLists.txt index b85c43691..9bb7f720c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -132,14 +132,15 @@ endif() # Python module if(TRITON_BUILD_PYTHON_MODULE) message(STATUS "Adding Python module") + set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src) + include_directories("." ${PYTHON_SRC_PATH}) if (PYTHON_INCLUDE_DIRS) - set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src) - include_directories("." ${PYTHON_SRC_PATH} ${PYTHON_INCLUDE_DIRS}) + include_directories(${PYTHON_INCLUDE_DIRS}) link_directories(${PYTHON_LINK_DIRS}) + link_libraries(${PYTHON_LIBRARIES}) else() find_package(Python3 REQUIRED COMPONENTS Development) - set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src) - include_directories("." ${PYTHON_SRC_PATH} ${Python3_INCLUDE_DIRS}) + include_directories(${Python3_INCLUDE_DIRS}) link_directories(${Python3_LIBRARY_DIRS}) link_libraries(${Python3_LIBRARIES}) add_link_options(${Python3_LINK_OPTIONS}) @@ -169,7 +170,10 @@ list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") include(TableGen) # required by AddMLIR include(AddLLVM) include(AddMLIR) -# include(HandleLLVMOptions) # human-friendly error message +include(HandleLLVMOptions) # human-friendly error message + +# Disable warnings that show up in external code (gtest;pybind11) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-covered-switch-default") include_directories(${MLIR_INCLUDE_DIRS}) include_directories(${LLVM_INCLUDE_DIRS}) @@ -192,7 +196,6 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) target_link_libraries(triton - ${PYTHON_LIBRARIES} TritonAnalysis TritonTransforms TritonGPUTransforms diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h index f4d6d102f..b33ac5bf9 100644 --- a/include/triton/Analysis/Allocation.h +++ b/include/triton/Analysis/Allocation.h @@ -142,7 +142,7 @@ private: BufferT(BufferKind kind) : BufferT(kind, 0, 0) {} BufferT(BufferKind kind, size_t size) : BufferT(kind, size, 0) {} BufferT(BufferKind kind, size_t size, size_t offset) - : kind(kind), size(size), offset(offset), id(nextId++) {} + : kind(kind), id(nextId++), size(size), offset(offset) {} bool intersects(const BufferT &other) const { return Interval(offset, offset + size) diff --git a/include/triton/Analysis/AxisInfo.h b/include/triton/Analysis/AxisInfo.h index 7dfc6a08f..6026e2648 100644 --- a/include/triton/Analysis/AxisInfo.h +++ b/include/triton/Analysis/AxisInfo.h @@ -28,8 +28,8 @@ public: DimVectorT knownConstancy) : contiguity(knownContiguity), divisibility(knownDivisibility), constancy(knownConstancy), rank(contiguity.size()) { - assert(knownDivisibility.size() == rank); - assert(knownConstancy.size() == rank); + assert(knownDivisibility.size() == (size_t)rank); + assert(knownConstancy.size() == (size_t)rank); } // Accessors diff --git a/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h b/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h index 3ad8316ba..133c7eaf3 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h +++ b/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h @@ -15,9 +15,9 @@ class Location; namespace triton { using llvm::StringRef; -class PTXInstr; -class PTXInstrCommon; -class PTXInstrExecution; +struct PTXInstr; +struct PTXInstrCommon; +struct PTXInstrExecution; // PTXBuilder helps to manage a PTX asm program consists of one or multiple // instructions. @@ -83,7 +83,7 @@ struct PTXBuilder { Operand() = default; Operand(const Operation &) = delete; Operand(Value value, StringRef constraint) - : value(value), constraint(constraint) {} + : constraint(constraint), value(value) {} bool isList() const { return !value && constraint.empty(); } @@ -120,7 +120,7 @@ struct PTXBuilder { Operand *newListOperand(unsigned count, mlir::Value val, const std::string &constraint) { auto *list = newOperand(); - for (int i = 0; i < count; ++i) { + for (unsigned i = 0; i < count; ++i) { list->listAppend(newOperand(val, constraint)); } return list; @@ -128,7 +128,7 @@ struct PTXBuilder { Operand *newListOperand(unsigned count, const std::string &constraint) { auto *list = newOperand(); - for (int i = 0; i < count; ++i) { + for (unsigned i = 0; i < count; ++i) { list->listAppend(newOperand(constraint)); } return list; @@ -172,8 +172,8 @@ private: return argArchive.back().get(); } - friend class PTXInstr; - friend class PTXInstrCommon; + friend struct PTXInstr; + friend struct PTXInstrCommon; protected: llvm::SmallVector, 6> argArchive; @@ -209,7 +209,7 @@ protected: PTXBuilder *builder{}; llvm::SmallVector instrParts; - friend class PTXInstrExecution; + friend struct PTXInstrExecution; }; template struct PTXInstrBase : public PTXInstrCommon { @@ -309,7 +309,7 @@ struct PTXInstrExecution { PTXInstrExecution() = default; explicit PTXInstrExecution(PTXInstrCommon *instr, llvm::ArrayRef oprs) - : instr(instr), argsInOrder(oprs.begin(), oprs.end()) {} + : argsInOrder(oprs.begin(), oprs.end()), instr(instr) {} // Prefix a predicate to the instruction. PTXInstrExecution &predicate(mlir::Value value, StringRef constraint = "b") { diff --git a/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h b/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h index ef81d82c3..7c4143c11 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h +++ b/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h @@ -11,16 +11,12 @@ class ModuleOp; template class OperationPass; class TritonLLVMConversionTarget : public ConversionTarget { - mlir::LLVMTypeConverter &typeConverter; - public: explicit TritonLLVMConversionTarget(MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter); }; class TritonLLVMFunctionConversionTarget : public ConversionTarget { - mlir::LLVMTypeConverter &typeConverter; - public: explicit TritonLLVMFunctionConversionTarget( MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter); diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h index c211d287a..d25a4d4e6 100644 --- a/include/triton/Dialect/Triton/IR/Dialect.h +++ b/include/triton/Dialect/Triton/IR/Dialect.h @@ -26,11 +26,11 @@ public: DialectInferLayoutInterface(Dialect *dialect) : Base(dialect) {} virtual LogicalResult - inferReduceOpEncoding(Attribute operandEncoding, int axis, + inferReduceOpEncoding(Attribute operandEncoding, unsigned axis, Attribute &resultEncoding) const = 0; virtual LogicalResult - inferExpandDimsOpEncoding(Attribute operandEncoding, int axis, + inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis, Attribute &resultEncoding) const = 0; }; diff --git a/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h b/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h index 6669d2d9a..ea83e6f94 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h +++ b/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h @@ -21,7 +21,6 @@ private: }; class TritonGPUConversionTarget : public ConversionTarget { - TritonGPUTypeConverter &typeConverter; public: explicit TritonGPUConversionTarget(MLIRContext &ctx, diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index f204a6ade..cafe3b777 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -82,7 +82,6 @@ SmallVector getScratchConfigForReduce(triton::ReduceOp op) { auto srcTy = op.operand().getType().cast(); auto srcLayout = srcTy.getEncoding().cast(); auto srcShape = srcTy.getShape(); - auto rank = srcShape.size(); auto axis = op.axis(); bool fast_reduce = axis == 1; // FIXME(Qingyi): The fastest-changing dimension diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index d0296d5ab..3c2dca2e6 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -66,7 +66,7 @@ AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) { DimVectorT retContiguity; DimVectorT retDivisibility; DimVectorT retConstancy; - for (size_t d = 0; d < lhs.getRank(); ++d) { + for (int d = 0; d < lhs.getRank(); ++d) { retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d))); retDivisibility.push_back( gcd(lhs.getDivisibility(d), rhs.getDivisibility(d))); @@ -88,7 +88,7 @@ AxisInfo AxisInfoAnalysis::visitBinaryOp( AxisInfo::DimVectorT newContiguity; AxisInfo::DimVectorT newDivisibility; AxisInfo::DimVectorT newConstancy; - for (size_t d = 0; d < rank; ++d) { + for (int d = 0; d < rank; ++d) { newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d)); newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d)); newConstancy.push_back(getConstancy(lhsInfo, rhsInfo, d)); @@ -167,7 +167,7 @@ ChangeResult AxisInfoAnalysis::visitOperation( AxisInfo::DimVectorT contiguity; AxisInfo::DimVectorT divisibility; AxisInfo::DimVectorT constancy; - for (size_t d = 0; d < retTy.getRank(); ++d) { + for (int d = 0; d < retTy.getRank(); ++d) { contiguity.push_back(1); divisibility.push_back(opInfo.getDivisibility(0)); constancy.push_back(retTy.getShape()[d]); @@ -176,12 +176,6 @@ ChangeResult AxisInfoAnalysis::visitOperation( } // expandDims if (auto expandDims = llvm::dyn_cast(op)) { - Type _retTy = *op->result_type_begin(); - Type _opTy = *op->operand_type_begin(); - TensorType retTy = _retTy.cast(); - TensorType opTy = _opTy.cast(); - ArrayRef retShape = retTy.getShape(); - ArrayRef opShape = opTy.getShape(); AxisInfo opInfo = operands[0]->getValue(); AxisInfo::DimVectorT contiguity = opInfo.getContiguity(); AxisInfo::DimVectorT divisibility = opInfo.getDivisibility(); @@ -203,7 +197,7 @@ ChangeResult AxisInfoAnalysis::visitOperation( AxisInfo::DimVectorT contiguity; AxisInfo::DimVectorT divisibility; AxisInfo::DimVectorT constancy; - for (size_t d = 0; d < retTy.getRank(); ++d) { + for (int d = 0; d < retTy.getRank(); ++d) { contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d)); divisibility.push_back(opInfo.getDivisibility(d)); constancy.push_back(opShape[d] == 1 ? retShape[d] : 1); diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 0dc827857..dcd3def89 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -28,11 +28,10 @@ bool maybeSharedAllocationOp(Operation *op) { } std::string getValueOperandName(Value value, AsmState &state) { - auto *op = value.getDefiningOp(); std::string opName; llvm::raw_string_ostream ss(opName); value.printAsOperand(ss, state); - return std::move(opName); + return opName; } } // namespace mlir diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 6d464f2a0..22c5ab720 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -254,7 +254,6 @@ protected: /// FuncOp legalization pattern that converts MemRef arguments to pointers to /// MemRef descriptors (LLVM struct data types) containing all the MemRef type /// information. -static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface"; struct FuncOpConversion : public FuncOpConversionBase { FuncOpConversion(LLVMTypeConverter &converter, int numWarps, PatternBenefit benefit) @@ -292,7 +291,6 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> { LogicalResult matchAndRewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Location loc = op->getLoc(); unsigned numArguments = op.getNumOperands(); // Currently, Triton kernel function always return nothing. @@ -482,7 +480,6 @@ public: ConversionPatternRewriter &rewriter, const BlockedEncodingAttr &blocked_layout, ArrayRef shape) const { - auto llvmIndexTy = this->getTypeConverter()->getIndexType(); Value threadId = getThreadId(rewriter, loc); Value warpSize = idx_val(32); Value laneId = urem(threadId, warpSize); @@ -654,7 +651,6 @@ public: auto bufferId = allocation->getBufferId(value); assert(bufferId != Allocation::InvalidBufferId && "BufferId not found"); size_t offset = allocation->getOffset(bufferId); - auto llvmIndexTy = this->getTypeConverter()->getIndexType(); Value offVal = idx_val(offset); Value base = gep(ptrTy, smem, offVal); return base; @@ -684,7 +680,6 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal, auto tensorTy = resType.cast(); if (tensorTy.getEncoding().isa()) { auto tensorTy = resType.cast(); - auto layout = tensorTy.getEncoding(); auto srcType = typeConverter->convertType(elemType); auto llSrc = bitcast(srcType, constVal); size_t elemsPerThread = getElemsPerThread(tensorTy); @@ -842,7 +837,6 @@ struct LoadOpConversion LogicalResult matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - MLIRContext *ctx = rewriter.getContext(); auto loc = op->getLoc(); // original values @@ -897,12 +891,11 @@ struct LoadOpConversion // TODO: optimization when ptr is GEP with constant offset size_t in_off = 0; - const int maxWordWidth = std::max(32, valueElemNbits); - const int totalWidth = valueElemNbits * vec; - const int width = std::min(totalWidth, maxWordWidth); - const int nWords = std::max(1, totalWidth / width); - const int wordNElems = width / valueElemNbits; - const int vecNElems = totalWidth / valueElemNbits; + const size_t maxWordWidth = std::max(32, valueElemNbits); + const size_t totalWidth = valueElemNbits * vec; + const size_t width = std::min(totalWidth, maxWordWidth); + const size_t nWords = std::max(1, totalWidth / width); + const size_t wordNElems = width / valueElemNbits; assert(wordNElems * nWords * numVecs == numElems); // TODO(Superjomn) Add cache policy fields to StoreOp. @@ -921,7 +914,7 @@ struct LoadOpConversion // prepare asm operands auto *dstsOpr = ptxBuilder.newListOperand(); - for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) { + for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) { auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations dstsOpr->listAppend(opr); } @@ -988,8 +981,8 @@ struct LoadOpConversion : retTys[0]; // TODO: if (has_l2_evict_policy) - auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(), - LLVM::AsmDialect::AD_ATT); + // auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(), + // LLVM::AsmDialect::AD_ATT); Value ret = ptxBuilder.launch(rewriter, loc, retTy); // --- @@ -1080,27 +1073,25 @@ struct StoreOpConversion // TODO: optimization when ptr is AddPtr with constant offset size_t in_off = 0; - const int maxWordWidth = std::max(32, valueElemNbits); - const int totalWidth = valueElemNbits * vec; - const int width = std::min(totalWidth, maxWordWidth); - const int nWords = std::max(1, totalWidth / width); - const int wordNElems = width / valueElemNbits; - const int vecNElems = totalWidth / valueElemNbits; + const size_t maxWordWidth = std::max(32, valueElemNbits); + const size_t totalWidth = valueElemNbits * vec; + const size_t width = std::min(totalWidth, maxWordWidth); + const size_t nWords = std::max(1, totalWidth / width); + const size_t wordNElems = width / valueElemNbits; assert(wordNElems * nWords * numVecs == numElems); // TODO(Superjomn) Add cache policy fields to StoreOp. // TODO(Superjomn) Deal with cache policy here. - const bool hasL2EvictPolicy = false; Type valArgTy = IntegerType::get(ctx, width); auto wordTy = vec_ty(valueElemTy, wordNElems); SmallVector> asmArgs; - for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) { + for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) { // llWord is a width-len composition Value llWord = rewriter.create(loc, wordTy); // Insert each value element to the composition - for (int elemIdx = 0; elemIdx < wordNElems; ++elemIdx) { + for (size_t elemIdx = 0; elemIdx < wordNElems; ++elemIdx) { const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx; assert(elemOffset < valueElems.size()); Value elem = valueElems[elemOffset]; @@ -1220,7 +1211,6 @@ struct BroadcastOpConversion } unsigned srcElems = getElemsPerThread(srcTy); - auto elemTy = resultTy.getElementType(); auto srcVals = getElementsFromStruct(loc, src, rewriter); unsigned resultElems = getElemsPerThread(resultTy); SmallVector resultVals(resultElems); @@ -1282,8 +1272,6 @@ private: LogicalResult ReduceOpConversion::matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto srcTy = op.operand().getType().cast(); - auto rank = srcTy.getShape().size(); if (op.axis() == 1) // FIXME(Qingyi): The fastest-changing dimension return matchAndRewriteFast(op, adaptor, rewriter); return matchAndRewriteBasic(op, adaptor, rewriter); @@ -1332,7 +1320,6 @@ void ReduceOpConversion::accumulate(ConversionPatternRewriter &rewriter, Value ReduceOpConversion::shflSync(ConversionPatternRewriter &rewriter, Location loc, Value val, int i) const { - MLIRContext *ctx = rewriter.getContext(); unsigned bits = val.getType().getIntOrFloatBitWidth(); if (bits == 64) { @@ -1439,7 +1426,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic( barrier(); SmallVector resultVals(resultElems); - for (int i = 0; i < resultElems; i++) { + for (size_t i = 0; i < resultElems; i++) { SmallVector readIdx = resultIndices[i]; readIdx.insert(readIdx.begin() + axis, ints[0]); Value readOffset = linearize(rewriter, loc, readIdx, smemShape); @@ -1471,7 +1458,6 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast( auto srcTy = op.operand().getType().cast(); auto srcLayout = srcTy.getEncoding().cast(); auto srcShape = srcTy.getShape(); - auto srcOrder = srcLayout.getOrder(); auto threadsPerWarp = srcLayout.getThreadsPerWarp(); auto warpsPerCTA = srcLayout.getWarpsPerCTA(); @@ -1579,7 +1565,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast( barrier(); SmallVector resultVals(resultElems); - for (int i = 0; i < resultElems; i++) { + for (size_t i = 0; i < resultElems; i++) { SmallVector readIdx = resultIndices[i]; readIdx.insert(readIdx.begin() + axis, i32_val(0)); Value readOffset = linearize(rewriter, loc, readIdx, smemShape); @@ -1619,7 +1605,6 @@ struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern { // due to MLIR's restrictions Location loc = op->getLoc(); auto resultTy = op.getType().template cast(); - auto resultShape = resultTy.getShape(); unsigned elems = getElemsPerThread(resultTy); Type elemTy = this->getTypeConverter()->convertType(resultTy.getElementType()); @@ -1698,7 +1683,6 @@ struct AddPtrOpConversion auto resultLayout = resultTensorTy.getEncoding().dyn_cast(); assert(resultLayout && "Unexpected resultLayout in AddPtrOpConversion"); - auto resultShape = resultTensorTy.getShape(); unsigned elems = getElemsPerThread(resultTy); Type elemTy = getTypeConverter()->convertType(resultTensorTy.getElementType()); @@ -1821,7 +1805,7 @@ protected: SmallVector> operands(elems); for (auto operand : adaptor.getOperands()) { auto sub_operands = this->getElementsFromStruct(loc, operand, rewriter); - for (int i = 0; i < elems; ++i) { + for (size_t i = 0; i < elems; ++i) { operands[i].push_back(sub_operands[i]); } } @@ -1931,6 +1915,7 @@ struct CmpFOpConversion __PRED_ENUM(ORD, ord); __PRED_ENUM(UEQ, ueq); __PRED_ENUM(UGT, ugt); + __PRED_ENUM(UGE, uge); __PRED_ENUM(ULT, ult); __PRED_ENUM(ULE, ule); __PRED_ENUM(UNE, une); @@ -2034,7 +2019,6 @@ void ConvertLayoutOpConversion::processReplica( auto rank = type.getRank(); auto sizePerThread = getSizePerThread(layout); auto accumSizePerThread = product(sizePerThread); - auto llvmIndexTy = getTypeConverter()->getIndexType(); SmallVector numCTAs(rank); auto shapePerCTA = getShapePerCTA(layout); for (unsigned d = 0; d < rank; ++d) { @@ -2048,7 +2032,6 @@ void ConvertLayoutOpConversion::processReplica( multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout( loc, rewriter, blockedLayout, type.getShape()); } else if (sliceLayout) { - unsigned dim = sliceLayout.getDim(); auto parent = sliceLayout.getParent(); if (auto blockedParent = parent.dyn_cast()) { SmallVector paddedShape = @@ -2200,7 +2183,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed( } // Potentially we need to store for multiple CTAs in this replication unsigned accumNumReplicates = product(numReplicates); - unsigned elems = getElemsPerThread(srcTy); + // unsigned elems = getElemsPerThread(srcTy); auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter); unsigned inVec = 0; unsigned outVec = 0; @@ -2367,17 +2350,17 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared( // Data loader for mma.16816 instruction. class MMA16816SmemLoader { public: - MMA16816SmemLoader(int wpt, ArrayRef order, int kOrder, + MMA16816SmemLoader(int wpt, ArrayRef order, uint32_t kOrder, ArrayRef tileShape, ArrayRef instrShape, ArrayRef matShape, int perPhase, int maxPhase, int elemBytes, ConversionPatternRewriter &rewriter, TypeConverter *typeConverter, const Location &loc) - : wpt(wpt), order(order.begin(), order.end()), kOrder(kOrder), + : order(order.begin(), order.end()), kOrder(kOrder), tileShape(tileShape.begin(), tileShape.end()), instrShape(instrShape.begin(), instrShape.end()), matShape(matShape.begin(), matShape.end()), perPhase(perPhase), - maxPhase(maxPhase), elemBytes(elemBytes), rewriter(rewriter), - typeConverter(typeConverter), loc(loc), ctx(rewriter.getContext()) { + maxPhase(maxPhase), elemBytes(elemBytes), rewriter(rewriter), loc(loc), + ctx(rewriter.getContext()) { cMatShape = matShape[order[0]]; sMatShape = matShape[order[1]]; @@ -2576,7 +2559,6 @@ public: assert(mat0 % 2 == 0 && mat1 % 2 == 0 && "smem matrix load must be aligned"); int matIdx[2] = {mat0, mat1}; - int k = matIdx[kOrder]; int ptrIdx{-1}; @@ -2596,7 +2578,6 @@ public: Value ptr = getPtr(ptrIdx); - Value resV4; if (canUseLdmatrix) { int sOffset = matIdx[order[1]] * sMatStride * sMatShape * sTileStride * elemBytes; @@ -2727,7 +2708,6 @@ public: } private: - int wpt; SmallVector order; int kOrder; SmallVector tileShape; @@ -2737,7 +2717,6 @@ private: int maxPhase; int elemBytes; ConversionPatternRewriter &rewriter; - TypeConverter *typeConverter{}; const Location &loc; MLIRContext *ctx{}; @@ -2786,14 +2765,9 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { LogicalResult matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Location loc = op->getLoc(); // D = A * B + C Value A = op.a(); - Value B = op.b(); - Value C = op.c(); Value D = op.getResult(); - MLIRContext *ctx = op->getContext(); - bool allowTF32 = op.allowTF32(); // Here we assume the DotOp's operands always comes from shared memory. auto AShape = A.getType().cast().getShape(); @@ -2951,8 +2925,6 @@ struct DotOpConversionHelper { Type i8x4Ty = vec_ty(type::i8Ty(ctx), 4); Type i8x4Pack4Ty = LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, i8x4Ty)); - Type i32Pack4Ty = LLVM::LLVMStructType::getLiteral( - ctx, SmallVector(4, type::i32Ty(ctx))); switch (mmaType) { case TensorCoreType::FP32_FP16_FP16_FP32: @@ -3062,7 +3034,6 @@ struct DotOpConversionHelper { auto bTy = B.getType().cast(); // d = a*b + c auto dTy = op.d().getType().cast(); - auto mmaLayout = dTy.getEncoding().cast(); if (dTy.getElementType().isF32()) { if (aTy.getElementType().isF16() && bTy.getElementType().isF16()) @@ -3168,9 +3139,9 @@ struct MMA16816ConversionHelper { MMA16816ConversionHelper(MmaEncodingAttr mmaLayout, Value thread, ConversionPatternRewriter &rewriter, TypeConverter *typeConverter, Location loc) - : mmaLayout(mmaLayout), helper(mmaLayout), rewriter(rewriter), - typeConverter(typeConverter), loc(loc), ctx(mmaLayout.getContext()), - thread(thread) { + : mmaLayout(mmaLayout), thread(thread), helper(mmaLayout), + rewriter(rewriter), typeConverter(typeConverter), loc(loc), + ctx(mmaLayout.getContext()) { wpt = mmaLayout.getWarpsPerCTA(); Value _32 = i32_val(32); @@ -3281,8 +3252,8 @@ struct MMA16816ConversionHelper { } // step1. Perform loading. - for (unsigned m = 0; m < numRepM; ++m) - for (unsigned k = 0; k < numRepK; ++k) + for (int m = 0; m < numRepM; ++m) + for (int k = 0; k < numRepK; ++k) loadFn(2 * m, 2 * k); // step2. Format the values to LLVM::Struct to passing to mma codegen. @@ -3305,8 +3276,8 @@ struct MMA16816ConversionHelper { 0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/, {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/); - for (unsigned n = 0; n < std::max(numRepN / 2, 1); ++n) { - for (unsigned k = 0; k < numRepK; ++k) + for (int n = 0; n < std::max(numRepN / 2, 1); ++n) { + for (int k = 0; k < numRepK; ++k) loadFn(2 * n, 2 * k); } @@ -3342,17 +3313,12 @@ struct MMA16816ConversionHelper { helper.deduceMmaType(op); auto aTensorTy = a.getType().cast(); - auto bTensorTy = b.getType().cast(); - auto cTensorTy = c.getType().cast(); auto dTensorTy = d.getType().cast(); auto aShape = aTensorTy.getShape(); auto dShape = dTensorTy.getShape(); - int NK = aShape[1]; // shape / shape_per_cta - auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(aTensorTy); - auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(aTensorTy); int numRepM = getNumRepM(aTensorTy, dShape[0]); int numRepN = getNumRepN(aTensorTy, dShape[1]); int numRepK = getNumRepK(aTensorTy, aShape[1]); @@ -3395,9 +3361,9 @@ struct MMA16816ConversionHelper { extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(i)); }; - for (unsigned k = 0; k < numRepK; ++k) - for (unsigned m = 0; m < numRepM; ++m) - for (unsigned n = 0; n < numRepN; ++n) + for (int k = 0; k < numRepK; ++k) + for (int m = 0; m < numRepM; ++m) + for (int n = 0; n < numRepN; ++n) callMma(2 * m, n, 2 * k); // replace with new packed result @@ -3412,7 +3378,7 @@ struct MMA16816ConversionHelper { private: std::function getLoadMatrixFn(Value tensor, Value llTensor, MmaEncodingAttr mmaLayout, - int wpt, int kOrder, ArrayRef instrShape, + int wpt, uint32_t kOrder, ArrayRef instrShape, ArrayRef matShape, Value warpId, ValueTable &vals) const { auto tensorTy = tensor.getType().cast(); @@ -3486,8 +3452,8 @@ private: Value composeValuesToDotOperandLayoutStruct(const ValueTable &vals, int n0, int n1) const { std::vector elems; - for (unsigned m = 0; m < n0; ++m) - for (unsigned k = 0; k < n1; ++k) { + for (int m = 0; m < n0; ++m) + for (int k = 0; k < n1; ++k) { elems.push_back(vals.at({2 * m, 2 * k})); elems.push_back(vals.at({2 * m, 2 * k + 1})); elems.push_back(vals.at({2 * m + 1, 2 * k})); @@ -3529,10 +3495,8 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand( auto loc = op.getLoc(); Value src = op.src(); Value dst = op.result(); - auto srcTensorTy = src.getType().cast(); auto dstTensorTy = dst.getType().cast(); - auto sharedLayout = srcTensorTy.getEncoding().cast(); auto dotOperandLayout = dstTensorTy.getEncoding().cast(); MmaEncodingAttr mmaLayout = @@ -3711,7 +3675,7 @@ struct AsyncWaitOpConversion auto ctx = op.getContext(); auto loc = op.getLoc(); auto voidTy = void_ty(ctx); - auto ret = ptxBuilder.launch(rewriter, loc, voidTy); + ptxBuilder.launch(rewriter, loc, voidTy); // Safe to remove the op since it doesn't have any return value. rewriter.eraseOp(op); @@ -3800,12 +3764,10 @@ struct InsertSliceAsyncOpConversion unsigned perPhase = resSharedLayout.getPerPhase(); unsigned maxPhase = resSharedLayout.getMaxPhase(); auto sizePerThread = srcBlockedLayout.getSizePerThread(); - auto threadsPerWarp = srcBlockedLayout.getThreadsPerWarp(); - auto warpsPerCTA = srcBlockedLayout.getWarpsPerCTA(); auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout); auto inOrder = srcBlockedLayout.getOrder(); - auto outOrder = resSharedLayout.getOrder(); + // If perPhase * maxPhase > threadsPerCTA, we need to swizzle over // elements across phases. If perPhase * maxPhase == threadsPerCTA, // swizzle is not allowd @@ -3886,7 +3848,7 @@ struct InsertSliceAsyncOpConversion auto resByteWidth = resElemTy.getIntOrFloatBitWidth() / 8; auto tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}]; - for (unsigned wordIdx = 0; wordIdx < numWords; ++wordIdx) { + for (size_t wordIdx = 0; wordIdx < numWords; ++wordIdx) { PTXBuilder ptxBuilder; auto wordElemIdx = wordIdx * numWordElems; auto ©AsyncOp = @@ -4208,7 +4170,7 @@ namespace mlir { TritonLLVMConversionTarget::TritonLLVMConversionTarget( MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter) - : ConversionTarget(ctx), typeConverter(typeConverter) { + : ConversionTarget(ctx) { addLegalDialect(); addLegalDialect(); // addIllegalDialect(); @@ -4220,7 +4182,7 @@ TritonLLVMConversionTarget::TritonLLVMConversionTarget( TritonLLVMFunctionConversionTarget::TritonLLVMFunctionConversionTarget( MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter) - : ConversionTarget(ctx), typeConverter(typeConverter) { + : ConversionTarget(ctx) { addLegalDialect(); // addLegalDialect(); addIllegalOp(); diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index 576ef735a..9519ff998 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -21,9 +21,7 @@ public: matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type retType = this->getTypeConverter()->convertType(op.getType()); - Op res = - rewriter.replaceOpWithNewOp(op, retType, adaptor.getOperands()); - + rewriter.replaceOpWithNewOp(op, retType, adaptor.getOperands()); return success(); } }; @@ -37,9 +35,8 @@ public: matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type retType = this->getTypeConverter()->convertType(op.getType()); - DstOp res = - rewriter.replaceOpWithNewOp(op, retType, adaptor.getPredicate(), - adaptor.getLhs(), adaptor.getRhs()); + rewriter.replaceOpWithNewOp(op, retType, adaptor.getPredicate(), + adaptor.getLhs(), adaptor.getRhs()); return success(); } }; @@ -129,10 +126,9 @@ public: matchAndRewrite(SelectOp op, typename SelectOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type retType = this->getTypeConverter()->convertType(op.getType()); - triton::gpu::SelectOp res = - rewriter.replaceOpWithNewOp( - op, retType, adaptor.getCondition(), adaptor.getTrueValue(), - adaptor.getFalseValue()); + rewriter.replaceOpWithNewOp( + op, retType, adaptor.getCondition(), adaptor.getTrueValue(), + adaptor.getFalseValue()); return success(); } }; @@ -204,9 +200,6 @@ struct TritonExpandDimsPattern triton::gpu::BlockedEncodingAttr::get(getContext(), retSizePerThread, retThreadsPerWarp, retWarpsPerCTA, retOrder); - // return type - RankedTensorType retType = - RankedTensorType::get(retShape, argType.getElementType(), retEncoding); // convert operand to slice of return type Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get( getContext(), op.axis(), retEncoding); @@ -252,7 +245,7 @@ struct TritonDotPattern : public OpConversionPattern { bType.getElementType(), encoding); b = rewriter.create(b.getLoc(), dstType, b); } - auto newDot = rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, retType, a, b, adaptor.c(), adaptor.allowTF32(), adaptor.transA(), adaptor.transB()); return success(); @@ -279,7 +272,7 @@ struct TritonStorePattern : public OpConversionPattern { LogicalResult matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto newOp = rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, adaptor.ptr(), adaptor.value(), adaptor.mask()); return success(); } @@ -340,7 +333,7 @@ struct TritonReducePattern : public OpConversionPattern { LogicalResult matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto newOp = rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, adaptor.redOp(), adaptor.operand(), adaptor.axis()); return success(); } diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 0a0b65406..1b570b289 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -49,7 +49,6 @@ unsigned getElemsPerThread(Type type) { auto tensorType = type.cast(); auto layout = tensorType.getEncoding(); auto shape = tensorType.getShape(); - size_t rank = shape.size(); if (auto blockedLayout = layout.dyn_cast()) { return blockedLayout.getElemsPerThread(shape); } else if (auto sliceLayout = layout.dyn_cast()) { @@ -109,7 +108,7 @@ SmallVector getThreadsPerCTA(const Attribute &layout) { SmallVector getShapePerCTA(const Attribute &layout) { SmallVector shape; if (auto blockedLayout = layout.dyn_cast()) { - for (int d = 0, n = blockedLayout.getOrder().size(); d < n; ++d) + for (unsigned d = 0, n = blockedLayout.getOrder().size(); d < n; ++d) shape.push_back(blockedLayout.getSizePerThread()[d] * blockedLayout.getThreadsPerWarp()[d] * blockedLayout.getWarpsPerCTA()[d]); @@ -117,7 +116,7 @@ SmallVector getShapePerCTA(const Attribute &layout) { unsigned dim = sliceLayout.getDim(); auto parent = sliceLayout.getParent(); if (auto blockedParent = parent.dyn_cast()) { - for (int d = 0, n = blockedParent.getOrder().size(); d < n; ++d) { + for (unsigned d = 0, n = blockedParent.getOrder().size(); d < n; ++d) { if (d == dim) continue; shape.push_back(blockedParent.getSizePerThread()[d] * @@ -258,7 +257,6 @@ SliceEncodingAttr::paddedShape(ArrayRef shape) const { unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef shape) const { size_t rank = shape.size(); auto parent = getParent(); - unsigned dim = getDim(); if (auto blockedParent = parent.dyn_cast()) { assert(rank == blockedParent.getSizePerThread().size() - 1 && "unexpected rank in SliceEncodingAttr::getElemsPerThread"); @@ -512,11 +510,11 @@ mlir::LogicalResult ExtractSliceOp::inferReturnTypes( auto encoding = srcType.getEncoding(); auto srcShape = srcType.getShape(); auto axis = attributes.get("axis").cast().getInt(); - if (axis < 0 || axis > srcShape.size()) + if (axis < 0 || (size_t)axis > srcShape.size()) return failure(); SmallVector dstShape; - for (int i = 0; i < srcShape.size(); i++) - if (i != axis) + for (size_t i = 0; i < srcShape.size(); i++) + if (i != (size_t)axis) dstShape.push_back(srcShape[i]); auto returnType = RankedTensorType::get(dstShape, srcType.getElementType(), encoding); @@ -578,15 +576,17 @@ struct TritonGPUInferLayoutInterface : public triton::DialectInferLayoutInterface { using DialectInferLayoutInterface::DialectInferLayoutInterface; - LogicalResult inferReduceOpEncoding(Attribute operandEncoding, int axis, - Attribute &resultEncoding) const { + LogicalResult + inferReduceOpEncoding(Attribute operandEncoding, unsigned axis, + Attribute &resultEncoding) const override { resultEncoding = SliceEncodingAttr::get(getDialect()->getContext(), axis, operandEncoding); return success(); } - LogicalResult inferExpandDimsOpEncoding(Attribute operandEncoding, int axis, - Attribute &resultEncoding) const { + LogicalResult + inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis, + Attribute &resultEncoding) const override { auto sliceEncoding = operandEncoding.dyn_cast(); if (!sliceEncoding) { llvm::report_fatal_error( diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index 0f363738f..36ea7030f 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -87,7 +87,6 @@ public: if (!llvm::isa(op)) return mlir::failure(); auto convert = llvm::cast(op); - auto srcType = convert.getOperand().getType().cast(); auto dstType = convert.getType().cast(); // we don't handle conversions to DotOperandEncodingAttr // this is a heuristics to accomodate fused attention @@ -219,10 +218,10 @@ Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op, auto typeInfer = dyn_cast(newOp); if (typeInfer) { SmallVector newType; - auto sucess = typeInfer.inferReturnTypes( + auto success = typeInfer.inferReturnTypes( newOp->getContext(), newOp->getLoc(), newOp->getOperands(), newOp->getAttrDictionary(), newOp->getRegions(), newType); - if (success) + if (succeeded(success)) newOp->getResult(0).setType(newType.front()); } return newOp; @@ -364,10 +363,6 @@ public: rematerializeForLoop(mlir::PatternRewriter &rewriter, scf::ForOp &forOp, size_t i, RankedTensorType newType, triton::gpu::ConvertLayoutOp origConversion) const { - - auto newEncoding = newType.cast().getEncoding(); - auto ctx = forOp.getContext(); - auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; }; // Rewrite init argument Type origType = forOp.getInitArgs()[i].getType(); SmallVector newInitArgs = forOp.getInitArgs(); @@ -418,11 +413,10 @@ public: return newResults; } - mlir::LogicalResult matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const { - + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { auto forOp = cast(op); - auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; }; auto iterArgs = forOp.getRegionIterArgs(); for (auto iterArg : llvm::enumerate(iterArgs)) { // if (iterArg.index() != 1) @@ -480,7 +474,6 @@ public: auto forOp = dyn_cast(cvt->getParentOp()); if (!forOp) return mlir::failure(); - auto yieldOp = cast(forOp.getBody()->getTerminator()); auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; }; SetVector cvtSlices; diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index ccb97aa52..abbef2efe 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -17,11 +17,6 @@ using namespace mlir; namespace { class LoopPipeliner { - /// comments on numStages: - /// [0, numStages-1) are in the prologue - /// numStages-1 is appended after the loop body - int numStages; - /// cache forOp we are working on scf::ForOp forOp; @@ -43,6 +38,11 @@ class LoopPipeliner { /// Value loopIterIdx; + /// comments on numStages: + /// [0, numStages-1) are in the prologue + /// numStages-1 is appended after the loop body + int numStages; + /// value (in loop) => value at stage N DenseMap> valueMapping; @@ -58,9 +58,6 @@ class LoopPipeliner { Value lookupOrDefault(Value origin, int stage); - /// return true if this op uses any of `loads` - bool isDirectUserOfAsyncLoad(Operation &op); - /// returns a empty buffer of size triton::gpu::AllocTensorOp allocateEmptyBuffer(Operation *op, OpBuilder &builder); @@ -84,7 +81,7 @@ public: /// create the new ForOp (add new args & insert prefetched ops) scf::ForOp createNewForOp(); - friend class PipelinePass; + friend struct PipelinePass; }; // helpers @@ -123,19 +120,6 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet &deps) { } } -bool LoopPipeliner::isDirectUserOfAsyncLoad(Operation &op) { - for (Value loadOp : loads) { - assert(loadOp.hasOneUse() && - "load should only have one use (ConvertLayout)"); - Value loadUseResult = loadOp.getUsers().begin()->getResult(0); - for (Value opOperand : op.getOperands()) { - if (opOperand == loadUseResult) - return true; - } - } - return false; -} - triton::gpu::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op, OpBuilder &builder) { // allocate a buffer for each pipelined tensor @@ -356,8 +340,8 @@ void LoopPipeliner::emitPrologue() { } // for (int stage = 0; stage < numStages - 1; ++stage) // async.wait & extract_slice - Operation *asyncWait = builder.create( - loads[0].getLoc(), loads.size() * (numStages - 2)); + builder.create(loads[0].getLoc(), + loads.size() * (numStages - 2)); loopIterIdx = builder.create(iv.getLoc(), 0, 32); for (Value loadOp : loads) { Value extractSlice = builder.create( @@ -380,8 +364,7 @@ void LoopPipeliner::emitEpilogue() { OpBuilder builder(forOp); OpBuilder::InsertionGuard g(builder); builder.setInsertionPointAfter(forOp); - Operation *asyncWait = - builder.create(forOp.getLoc(), 0); + builder.create(forOp.getLoc(), 0); } scf::ForOp LoopPipeliner::createNewForOp() { @@ -575,8 +558,8 @@ scf::ForOp LoopPipeliner::createNewForOp() { yieldValues.push_back(loopIterIdx); builder.setInsertionPointToEnd(newForOp.getBody()); - auto test = builder.create( - forOp.getBody()->getTerminator()->getLoc(), yieldValues); + builder.create(forOp.getBody()->getTerminator()->getLoc(), + yieldValues); return newForOp; } diff --git a/lib/Dialect/TritonGPU/Transforms/Swizzle.cpp b/lib/Dialect/TritonGPU/Transforms/Swizzle.cpp index 4ebe393ec..776fc9973 100644 --- a/lib/Dialect/TritonGPU/Transforms/Swizzle.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Swizzle.cpp @@ -30,7 +30,7 @@ struct SwizzlePass : public TritonGPUSwizzleBase { (ty.getElementType().getIntOrFloatBitWidth() / 8)); perPhase = std::max(perPhase, 1); // index of the inner dimension in `order` - int inner = (opIdx == 0) ? 0 : 1; + size_t inner = (opIdx == 0) ? 0 : 1; if (version == 1) { int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase; // TODO: handle rep (see @@ -67,7 +67,6 @@ struct SwizzlePass : public TritonGPUSwizzleBase { void runOnOperation() override { Operation *op = getOperation(); - MLIRContext *context = &getContext(); op->walk([&](triton::DotOp dotOp) -> void { OpBuilder builder(dotOp); auto _retEncoding = diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp index 13f97d577..6bd11de81 100644 --- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp +++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp @@ -73,7 +73,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, // TritonGPUConversionTarget::TritonGPUConversionTarget( MLIRContext &context, TritonGPUTypeConverter &typeConverter) - : ConversionTarget(context), typeConverter(typeConverter) { + : ConversionTarget(context) { // TODO: we should also verify ops of TritonGPUDialect addLegalDialect(); @@ -90,7 +90,7 @@ TritonGPUConversionTarget::TritonGPUConversionTarget( }); // We have requirements for the data layouts - addDynamicallyLegalOp([this](triton::DotOp dotOp) -> bool { + addDynamicallyLegalOp([](triton::DotOp dotOp) -> bool { Attribute aEncoding = dotOp.a().getType().cast().getEncoding(); Attribute bEncoding = diff --git a/lib/Target/PTX/PTXTranslation.cpp b/lib/Target/PTX/PTXTranslation.cpp index 631af81cc..a8266322e 100644 --- a/lib/Target/PTX/PTXTranslation.cpp +++ b/lib/Target/PTX/PTXTranslation.cpp @@ -63,7 +63,7 @@ static bool find_and_replace(std::string &str, const std::string &begin, static std::string llir_to_ptx(llvm::Module *module, int capability, int ptx) { // LLVM version in use may not officially support target hardware int max_nvvm_cc = 75; - int max_nvvm_ptx = 74; + // int max_nvvm_ptx = 74; // options auto options = llvm::cl::getRegisteredOptions(); auto *short_ptr = diff --git a/python/setup.py b/python/setup.py index 8cf97888f..816a89397 100644 --- a/python/setup.py +++ b/python/setup.py @@ -1,11 +1,12 @@ import distutils -import distutils.spawn +import itertools import os import platform import re import shutil import subprocess import sys +import sysconfig import tarfile import tempfile import urllib.request @@ -16,6 +17,74 @@ from setuptools import Extension, setup from setuptools.command.build_ext import build_ext +# Logic from https://github.com/Kitware/CMake/blob/master/Modules/FindPythonLibs.cmake +# Code from https://stackoverflow.com/questions/47423246 +def get_python_library(): + """Get path to the python library associated with the current python + interpreter.""" + # determine direct path to libpython + python_version = sysconfig.get_python_version() + python_library = sysconfig.get_config_var('LIBRARY') + + # if static (or nonexistent), try to find a suitable dynamic libpython + if (python_library is None or os.path.splitext(python_library)[1][-2:] == '.a'): + + candidate_lib_prefixes = ['', 'lib'] + + candidate_extensions = ['.lib', '.so', '.a'] + if sysconfig.get_config_var('WITH_DYLD'): + candidate_extensions.insert(0, '.dylib') + + candidate_versions = [python_version] + if python_version: + candidate_versions.append('') + candidate_versions.insert( + 0, "".join(python_version.split(".")[:2])) + + abiflags = getattr(sys, 'abiflags', '') + candidate_abiflags = [abiflags] + if abiflags: + candidate_abiflags.append('') + + # Ensure the value injected by virtualenv is + # returned on windows. + # Because calling `sysconfig.get_config_var('multiarchsubdir')` + # returns an empty string on Linux, `du_sysconfig` is only used to + # get the value of `LIBDIR`. + libdir = distutils.sysconfig.get_config_var('LIBDIR') + if sysconfig.get_config_var('MULTIARCH'): + masd = sysconfig.get_config_var('multiarchsubdir') + if masd: + if masd.startswith(os.sep): + masd = masd[len(os.sep):] + libdir = os.path.join(libdir, masd) + + if libdir is None: + libdir = os.path.abspath(os.path.join( + sysconfig.get_config_var('LIBDEST'), "..", "libs")) + + candidates = ( + os.path.join( + libdir, + ''.join((pre, 'python', ver, abi, ext)) + ) + for (pre, ext, ver, abi) in itertools.product( + candidate_lib_prefixes, + candidate_extensions, + candidate_versions, + candidate_abiflags + ) + ) + + for candidate in candidates: + if os.path.exists(candidate): + # we found a (likely alternate) libpython + python_library = candidate + break + + return python_library + + # Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py def check_env_flag(name: str, default: str = "") -> bool: return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"] @@ -136,14 +205,19 @@ class CMakeBuild(build_ext): if not os.path.exists(llvm_build_dir): os.makedirs(llvm_build_dir) # python directories - python_include_dirs = [distutils.sysconfig.get_python_inc()] + ['/usr/local/cuda/include'] + python_include_dir = distutils.sysconfig.get_python_inc() + python_link_dir = distutils.sysconfig.get_python_lib() + python_library = get_python_library() cmake_args = [ + "-DLLVM_ENABLE_WERROR=ON", "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, "-DTRITON_BUILD_TUTORIALS=OFF", "-DTRITON_BUILD_PYTHON_MODULE=ON", # '-DPYTHON_EXECUTABLE=' + sys.executable, # '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON', - "-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs), + "-DPYTHON_INCLUDE_DIRS=" + python_include_dir, + "-DPYTHON_LINK_DIRS=" + python_link_dir, + "-DPYTHON_LIBRARIES=" + python_library, "-DLLVM_EXTERNAL_LIT=" + lit_dir ] + thirdparty_cmake_args diff --git a/python/src/triton.cc b/python/src/triton.cc index a3d5a357e..243283006 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -26,6 +26,7 @@ #include "llvm/IR/Module.h" #include "llvm/IR/Verifier.h" #include "llvm/IRReader/IRReader.h" +#include "llvm/Support/FileUtilities.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Support/SourceMgr.h" @@ -1301,39 +1302,36 @@ void init_triton_translation(py::module &m) { py::gil_scoped_release allow_threads; // compile ptx with ptxas - char _fsrc[L_tmpnam]; - char _flog[L_tmpnam]; - std::tmpnam(_fsrc); - std::tmpnam(_flog); - std::string fsrc = _fsrc; - std::string flog = _flog; - std::string fbin = fsrc + ".o"; + llvm::SmallString<64> fsrc; + llvm::SmallString<64> flog; + llvm::sys::fs::createTemporaryFile("compile-ptx-src", "", fsrc); + llvm::sys::fs::createTemporaryFile("compile-ptx-log", "", flog); + std::string fbin = std::string(fsrc) + ".o"; + llvm::FileRemover srcRemover(fsrc); + llvm::FileRemover logRemover(flog); + llvm::FileRemover binRemover(fbin); + const char *_fsrc = fsrc.c_str(); + const char *_flog = flog.c_str(); const char *_fbin = fbin.c_str(); - std::ofstream ofs(fsrc); + std::ofstream ofs(_fsrc); ofs << ptxCode << std::endl; ofs.close(); std::string cmd; int err; cmd = ptxasPath + " -v --gpu-name=sm_" + std::to_string(capability) + - " " + fsrc + " -o " + fsrc + ".o 2> " + flog; + " " + _fsrc + " -o " + _fsrc + ".o 2> " + _flog; err = system(cmd.c_str()); if (err != 0) { std::ifstream _log(_flog); std::string log(std::istreambuf_iterator(_log), {}); - unlink(_fsrc); - unlink(_flog); throw std::runtime_error("Internal Triton PTX codegen error: \n" + log); } std::ifstream _cubin(_fbin, std::ios::binary); std::string cubin(std::istreambuf_iterator(_cubin), {}); _cubin.close(); - unlink(_fsrc); - unlink(_flog); - unlink(_fbin); - py::bytes bytes(cubin); - return bytes; + return std::move(bytes); }); m.def("add_external_libs", @@ -1345,8 +1343,8 @@ void init_triton_translation(py::module &m) { void init_triton(py::module &m) { py::module subm = m.def_submodule("triton"); - // init_triton_codegen(std::move(subm.def_submodule("code_gen"))); - init_triton_runtime(std::move(subm.def_submodule("runtime"))); - init_triton_ir(std::move(subm.def_submodule("ir"))); + // init_triton_codegen(subm.def_submodule("code_gen")); + init_triton_runtime(subm.def_submodule("runtime")); + init_triton_ir(subm.def_submodule("ir")); init_triton_translation(subm); }