diff --git a/CMakeLists.txt b/CMakeLists.txt index 3b04f69d6..7971d3aed 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -37,7 +37,11 @@ if(WIN32) add_subdirectory(deps/dlfcn-win32/src ${CMAKE_BINARY_DIR}/dlfcn-win32) endif() -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17") +set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17 -fvisibility=hidden -fvisibility-inlines-hidden") +if(APPLE) + set(CMAKE_OSX_DEPLOYMENT_TARGET 11.6) +endif() + ########## diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 127f8366e..1413ea8fb 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -11,7 +11,7 @@ using namespace mlir::triton::gpu; // parse an array of integers static LogicalResult parseIntArrayAttr(AsmParser &parser, const NamedAttribute &attr, - /*SmallVector*/ auto &res, + SmallVector &res, StringRef desc) { auto arrayAttr = attr.getValue().dyn_cast(); if (!arrayAttr) { @@ -84,7 +84,8 @@ static Attribute parseBlocked(AsmParser &parser, Type type) { broadcastAxis); } -static void printBlocked(AsmPrinter &printer, auto *attr) { +template +static void printBlocked(AsmPrinter &printer, const T *attr) { printer << "<{" << "threadTileSize = [" << attr->getThreadTileSize() << "]" << ", warpTileSize = [" << attr->getWarpTileSize() << "]" @@ -95,7 +96,7 @@ static void printBlocked(AsmPrinter &printer, auto *attr) { } Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) { - parseBlocked(parser, type); + return parseBlocked(parser, type); } void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const { @@ -104,7 +105,7 @@ void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const { Attribute TritonGPUBlockedMulticastEncodingAttr::parse(AsmParser &parser, Type type) { - parseBlocked(parser, type); + return parseBlocked(parser, type); } void TritonGPUBlockedMulticastEncodingAttr::print(AsmPrinter &printer) const { @@ -163,7 +164,7 @@ static Attribute parseMma(AsmParser &parser, Type type) { shapePerTile, repetitions, contigPerThread, broadcastAxis); } -static void printMma(AsmPrinter &printer, auto *attr) { +template static void printMma(AsmPrinter &printer, T *attr) { printer << "<{" << "fragmentPerWarp = [" << attr->getFragmentPerWarp() << "]" << ", shapePerWarp = [" << attr->getShapePerWarp() << "]" @@ -276,12 +277,14 @@ public: attr.dyn_cast()) { os << "blocked_multicast"; TritonGPUOpAsmInterface::printBlocked(blockedMulticastAttr, os); + return AliasResult::FinalAlias; } OpAsmDialectInterface::getAlias(attr, os); + return AliasResult::FinalAlias; } private: - static void printMma(const auto &attr, raw_ostream &os) { + static void printMma(const TritonGPUMmaEncodingAttr &attr, raw_ostream &os) { TritonGPUOpAsmInterface::printArray(attr.getFragmentPerWarp(), os); TritonGPUOpAsmInterface::printArray(attr.getShapePerWarp(), os); TritonGPUOpAsmInterface::printArray(attr.getWarpPerTile(), os); @@ -290,14 +293,15 @@ private: TritonGPUOpAsmInterface::printArray(attr.getContigPerThread(), os); } - static void printShared(const auto &attr, raw_ostream &os) { + static void printShared(const TritonGPUSharedEncodingAttr &attr, + raw_ostream &os) { os << "_" << attr.getVec(); os << "_" << attr.getPerPhase(); os << "_" << attr.getMaxPhase(); TritonGPUOpAsmInterface::printArray(attr.getOrder(), os); } - static void printBlocked(const auto &attr, raw_ostream &os) { + template static void printBlocked(const T &attr, raw_ostream &os) { TritonGPUOpAsmInterface::printArray(attr.getThreadTileSize(), os); TritonGPUOpAsmInterface::printArray(attr.getWarpTileSize(), os); TritonGPUOpAsmInterface::printArray(attr.getBlockTileSize(), os); @@ -305,7 +309,7 @@ private: TritonGPUOpAsmInterface::printArray(attr.getBroadcastAxis(), os); } - static void printArray(const auto &array, raw_ostream &os, + static void printArray(const ArrayRef &array, raw_ostream &os, const std::string &delimiter = "x") { os << "_"; if (array.empty()) { diff --git a/python/setup.py b/python/setup.py index 1e17d26f2..7a90f0951 100644 --- a/python/setup.py +++ b/python/setup.py @@ -17,7 +17,9 @@ from setuptools.command.build_ext import build_ext def get_llvm(): # download if nothing is installed - name = 'clang+llvm-14.0.0-x86_64-linux-gnu-ubuntu-18.04' + system = platform.system() + suffix = {"Linux": "linux-gnu-ubuntu-18.04", "Darwin": "apple-darwin"}[system] + name = f'clang+llvm-14.0.0-x86_64-{suffix}' dir = '/tmp' llvm_include_dir = '{dir}/{name}/include'.format(dir=dir, name=name) llvm_library_dir = '{dir}/{name}/lib'.format(dir=dir, name=name)