Reverts back to MLIR 14 & updates CMakeLists
This commit is contained in:
@@ -141,19 +141,24 @@ if(BUILD_PYTHON_MODULE)
|
||||
endif()
|
||||
|
||||
|
||||
# Triton
|
||||
file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
|
||||
if (WIN32 AND BUILD_PYTHON_MODULE)
|
||||
find_package(Python3 REQUIRED COMPONENTS Development)
|
||||
Python3_add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
|
||||
set_target_properties(triton PROPERTIES SUFFIX ".pyd")
|
||||
set_target_properties(triton PROPERTIES PREFIX "lib")
|
||||
else()
|
||||
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
|
||||
endif()
|
||||
# # Triton
|
||||
# file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
|
||||
# if (WIN32 AND BUILD_PYTHON_MODULE)
|
||||
# find_package(Python3 REQUIRED COMPONENTS Development)
|
||||
# Python3_add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
|
||||
# set_target_properties(triton PROPERTIES SUFFIX ".pyd")
|
||||
# set_target_properties(triton PROPERTIES PREFIX "lib")
|
||||
# else()
|
||||
# add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
|
||||
# endif()
|
||||
|
||||
|
||||
# MLIR
|
||||
find_package(MLIR 14 REQUIRED CONFIG)
|
||||
find_package(MLIR REQUIRED CONFIG PATHS ${LLVM_LIBRARY_DIR}/cmake/mlir)
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH ${LLVM_LIBRARY_DIR}/cmake/llvm)
|
||||
list(APPEND CMAKE_MODULE_PATH ${LLVM_LIBRARY_DIR}/cmake/mlir)
|
||||
|
||||
include(TableGen) # required by AddMLIR
|
||||
include(AddLLVM)
|
||||
include(AddMLIR)
|
||||
@@ -162,17 +167,17 @@ include(HandleLLVMOptions) # human-friendly error message
|
||||
include_directories(${MLIR_INCLUDE_DIRS})
|
||||
include_directories(${PROJECT_SOURCE_DIR}/include)
|
||||
include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files
|
||||
# link_directories(${LLVM_LIBRARY_DIR})
|
||||
|
||||
add_subdirectory(include)
|
||||
add_subdirectory(lib)
|
||||
# lib
|
||||
add_library(triton)
|
||||
# add_subdirectory(ir)
|
||||
|
||||
add_library(triton SHARED ${PYTHON_SRC})
|
||||
|
||||
target_link_libraries(triton
|
||||
PUBLIC
|
||||
TRITONIR
|
||||
# # optimizations
|
||||
# MLIRPass
|
||||
# MLIRTransforms
|
||||
TritonIR
|
||||
TritonDriver
|
||||
TritonCodeGen
|
||||
)
|
||||
|
||||
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
|
||||
|
@@ -1 +1 @@
|
||||
add_subdirectory(ir)
|
||||
add_subdirectory(triton/ir)
|
||||
|
@@ -4,15 +4,15 @@
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
||||
// #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
||||
|
||||
#include "triton/Dialect.h.inc"
|
||||
#include "triton/ir/Dialect.h.inc"
|
||||
|
||||
#include "triton/OpsEnums.h.inc"
|
||||
#include "triton/ir/OpsEnums.h.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/Ops.h.inc"
|
||||
#include "triton/ir/Ops.h.inc"
|
||||
|
||||
#endif // TRITON_IR_DIALECT_H_
|
||||
|
@@ -27,8 +27,11 @@ def Triton_Dialect : Dialect {
|
||||
let dependentDialects = [
|
||||
"arith::ArithmeticDialect",
|
||||
"tensor::TensorDialect",
|
||||
"cf::ControlFlowDialect",
|
||||
"func::FuncDialect"
|
||||
"StandardOpsDialect",
|
||||
|
||||
// Since LLVM 15
|
||||
// "cf::ControlFlowDialect",
|
||||
// "func::FuncDialect"
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
|
3
lib/CMakeLists.txt
Normal file
3
lib/CMakeLists.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
# add_subdirectory(codegen)
|
||||
add_subdirectory(driver)
|
||||
add_subdirectory(ir)
|
5
lib/codegen/CMakeLists.txt
Normal file
5
lib/codegen/CMakeLists.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
file(GLOB_RECURSE CODEGEN_SRC *.cc)
|
||||
|
||||
add_library(TritonCodeGen
|
||||
${CODEGEN_SRC}
|
||||
)
|
5
lib/driver/CMakeLists.txt
Normal file
5
lib/driver/CMakeLists.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
add_library(TritonDriver
|
||||
dispatch.cc
|
||||
error.cc
|
||||
llvm.cc
|
||||
)
|
@@ -1,12 +1,34 @@
|
||||
add_mlir_dialect_library(TRITONIR
|
||||
add_mlir_dialect_library(TritonIR
|
||||
Dialect.cpp
|
||||
Ops.cpp
|
||||
Types.cpp
|
||||
|
||||
DEPENDS
|
||||
TritonTableGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRArithmetic
|
||||
MLIRControlFlow
|
||||
MLIRFunc
|
||||
|
||||
# Since LLVM 15
|
||||
# MLIRControlFlow
|
||||
# MLIRFunc
|
||||
# else
|
||||
MLIRStandard
|
||||
|
||||
MLIRTensor
|
||||
)
|
||||
|
||||
# add_library(TritonIR
|
||||
# Dialect.cpp
|
||||
# Ops.cpp
|
||||
# Types.cpp
|
||||
# )
|
||||
|
||||
# target_link_libraries(TritonIR PUBLIC
|
||||
# MLIRIR
|
||||
# MLIRArithmetic
|
||||
# MLIRControlFlow
|
||||
# MLIRFunc
|
||||
# MLIRTensor
|
||||
# )
|
||||
|
@@ -1,5 +1,5 @@
|
||||
#include "triton/Dialect.h"
|
||||
#include "triton/Types.h"
|
||||
#include "triton/ir/Dialect.h"
|
||||
#include "triton/ir/Types.h"
|
||||
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
@@ -8,7 +8,7 @@
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
|
||||
|
||||
#include "triton/Dialect.cpp.inc"
|
||||
#include "triton/ir/Dialect.cpp.inc"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
@@ -18,7 +18,7 @@ void TritonDialect::initialize() {
|
||||
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "triton/Ops.cpp.inc"
|
||||
#include "triton/ir/Ops.cpp.inc"
|
||||
>();
|
||||
|
||||
// We can also add interface here.
|
||||
|
@@ -1,15 +1,16 @@
|
||||
#include "triton/ir/Dialect.h"
|
||||
#include "triton/ir/Types.h"
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "triton/Dialect.h"
|
||||
#include "triton/Types.h"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/Ops.cpp.inc"
|
||||
#include "triton/ir/Ops.cpp.inc"
|
||||
|
||||
// enum attribute definitions
|
||||
#include "triton/OpsEnums.cpp.inc"
|
||||
#include "triton/ir/OpsEnums.cpp.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
@@ -1,5 +1,5 @@
|
||||
#include "triton/Dialect.h"
|
||||
#include "triton/Types.h"
|
||||
#include "triton/ir/Dialect.h"
|
||||
#include "triton/ir/Types.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
@@ -17,7 +17,7 @@ from setuptools.command.build_ext import build_ext
|
||||
|
||||
def get_llvm():
|
||||
# tries to find system LLVM
|
||||
versions = ['-11.0', '-11', '-11-64']
|
||||
versions = ['-14.0', '-14', '-14-64']
|
||||
supported = ['llvm-config{v}'.format(v=v) for v in versions]
|
||||
paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
|
||||
paths = [p for p in paths if p is not None]
|
||||
@@ -26,7 +26,8 @@ def get_llvm():
|
||||
if platform.system() == "Windows":
|
||||
return '', ''
|
||||
# download if nothing is installed
|
||||
name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04'
|
||||
# name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04'
|
||||
name = 'clang+llvm-14.0.0-rc2-x86_64-linux-gnu-ubuntu-18.04'
|
||||
dir = '/tmp'
|
||||
llvm_include_dir = '{dir}/{name}/include'.format(dir=dir, name=name)
|
||||
llvm_library_dir = '{dir}/{name}/lib'.format(dir=dir, name=name)
|
||||
@@ -35,7 +36,7 @@ def get_llvm():
|
||||
shutil.rmtree(os.path.join(dir, name))
|
||||
except Exception:
|
||||
pass
|
||||
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/{name}.tar.xz".format(name=name)
|
||||
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-14.0.0-rc2/{name}.tar.xz".format(name=name)
|
||||
print('downloading and extracting ' + url + '...')
|
||||
ftpstream = urllib.request.urlopen(url)
|
||||
file = tarfile.open(fileobj=ftpstream, mode="r|xz")
|
||||
|
@@ -36,13 +36,13 @@ std::vector<int> segment_blocks(tensor_3d &layout, tensor_3d &idx, int max_width
|
||||
std::vector<int> current(H, 0);
|
||||
int num = 0;
|
||||
std::vector<int> lut(H * M * N * 4);
|
||||
for (size_t h = 0; h < H; h++) {
|
||||
for (ssize_t h = 0; h < H; h++) {
|
||||
// surrounding indices
|
||||
std::vector<int> ii_left(max_width, -1);
|
||||
std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));
|
||||
// start the dynamic programming algorithm
|
||||
for (size_t m = 0; m < M; m++) {
|
||||
for (size_t n = 0; n < N; n++) {
|
||||
for (ssize_t m = 0; m < M; m++) {
|
||||
for (ssize_t n = 0; n < N; n++) {
|
||||
int v = layout(h, m, n);
|
||||
if (v == 0)
|
||||
continue;
|
||||
@@ -70,8 +70,8 @@ std::vector<int> segment_blocks(tensor_3d &layout, tensor_3d &idx, int max_width
|
||||
if (width != max_width)
|
||||
continue;
|
||||
// retained blocks are set to zeros
|
||||
for (size_t km = 0; km < max_width; km++)
|
||||
for (size_t kn = 0; kn < max_width; kn++) {
|
||||
for (ssize_t km = 0; km < max_width; km++)
|
||||
for (ssize_t kn = 0; kn < max_width; kn++) {
|
||||
int mm = ii_top[km][n];
|
||||
int nn = ii_left[kn];
|
||||
if (mm < 0 || nn < 0)
|
||||
@@ -116,4 +116,4 @@ std::vector<lut_t> superblock(uintptr_t LAYOUT, int H, int M, int N, int start_w
|
||||
|
||||
void init_superblocking(pybind11::module &m) {
|
||||
m.def("superblock", &superblock, "super-blocking for block-sparse matrix multiplication");
|
||||
}
|
||||
}
|
||||
|
@@ -2,11 +2,15 @@
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/driver/error.h"
|
||||
#include "triton/driver/llvm.h"
|
||||
#include "triton/ir/builder.h"
|
||||
#include "triton/ir/enums.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/print.h"
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
|
||||
#include <optional>
|
||||
#include <pybind11/buffer_info.h>
|
||||
#include <pybind11/functional.h>
|
||||
@@ -18,9 +22,6 @@
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace ir = triton::ir;
|
||||
@@ -748,7 +749,7 @@ void init_triton_ir(py::module &&m) {
|
||||
}, ret::reference)
|
||||
.def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference);
|
||||
|
||||
py::class_<ir::builder>(m, "builder", py::dynamic_attr())
|
||||
py::class_<mlir::OpBuilder>(m, "builder", py::dynamic_attr())
|
||||
.def(py::init<ir::context &>())
|
||||
// getters
|
||||
.def_property_readonly("context", &ir::builder::get_context, ret::reference)
|
||||
@@ -788,10 +789,10 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("get_range", &ir::builder::get_range, ret::reference)
|
||||
// Types
|
||||
.def("get_void_ty", &ir::builder::get_void_ty, ret::reference)
|
||||
.def("get_int1_ty", &ir::builder::get_int1_ty, ret::reference)
|
||||
.def("get_int8_ty", &ir::builder::get_int8_ty, ret::reference)
|
||||
.def("get_int1_ty", &mlir::OpBuilder::getI1Type, ret::reference)
|
||||
.def("get_int8_ty", &mlir::OpBuilder::getI8Type, ret::reference)
|
||||
.def("get_int16_ty", &ir::builder::get_int16_ty, ret::reference)
|
||||
.def("get_int32_ty", &ir::builder::get_int32_ty, ret::reference)
|
||||
.def("get_int32_ty", &mlir::OpBuilder::getI32Type, ret::reference)
|
||||
.def("get_int64_ty", &ir::builder::get_int64_ty, ret::reference)
|
||||
.def("get_fp8_ty", &ir::builder::get_fp8_ty, ret::reference)
|
||||
.def("get_half_ty", &ir::builder::get_half_ty, ret::reference)
|
||||
|
Reference in New Issue
Block a user