Reverts back to MLIR 14 & updates CMakeLists

This commit is contained in:
Yan Da
2022-03-20 16:41:48 +08:00
parent a2c31ff434
commit 419bbe0f6e
14 changed files with 107 additions and 61 deletions

View File

@@ -141,19 +141,24 @@ if(BUILD_PYTHON_MODULE)
endif()
# Triton
file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
if (WIN32 AND BUILD_PYTHON_MODULE)
find_package(Python3 REQUIRED COMPONENTS Development)
Python3_add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
set_target_properties(triton PROPERTIES SUFFIX ".pyd")
set_target_properties(triton PROPERTIES PREFIX "lib")
else()
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
endif()
# # Triton
# file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
# if (WIN32 AND BUILD_PYTHON_MODULE)
# find_package(Python3 REQUIRED COMPONENTS Development)
# Python3_add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
# set_target_properties(triton PROPERTIES SUFFIX ".pyd")
# set_target_properties(triton PROPERTIES PREFIX "lib")
# else()
# add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
# endif()
# MLIR
find_package(MLIR 14 REQUIRED CONFIG)
find_package(MLIR REQUIRED CONFIG PATHS ${LLVM_LIBRARY_DIR}/cmake/mlir)
list(APPEND CMAKE_MODULE_PATH ${LLVM_LIBRARY_DIR}/cmake/llvm)
list(APPEND CMAKE_MODULE_PATH ${LLVM_LIBRARY_DIR}/cmake/mlir)
include(TableGen) # required by AddMLIR
include(AddLLVM)
include(AddMLIR)
@@ -162,17 +167,17 @@ include(HandleLLVMOptions) # human-friendly error message
include_directories(${MLIR_INCLUDE_DIRS})
include_directories(${PROJECT_SOURCE_DIR}/include)
include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files
# link_directories(${LLVM_LIBRARY_DIR})
add_subdirectory(include)
add_subdirectory(lib)
# lib
add_library(triton)
# add_subdirectory(ir)
add_library(triton SHARED ${PYTHON_SRC})
target_link_libraries(triton
PUBLIC
TRITONIR
# # optimizations
# MLIRPass
# MLIRTransforms
TritonIR
TritonDriver
TritonCodeGen
)
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})

View File

@@ -1 +1 @@
add_subdirectory(ir)
add_subdirectory(triton/ir)

View File

@@ -4,15 +4,15 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
// #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "triton/Dialect.h.inc"
#include "triton/ir/Dialect.h.inc"
#include "triton/OpsEnums.h.inc"
#include "triton/ir/OpsEnums.h.inc"
#define GET_OP_CLASSES
#include "triton/Ops.h.inc"
#include "triton/ir/Ops.h.inc"
#endif // TRITON_IR_DIALECT_H_

View File

@@ -27,8 +27,11 @@ def Triton_Dialect : Dialect {
let dependentDialects = [
"arith::ArithmeticDialect",
"tensor::TensorDialect",
"cf::ControlFlowDialect",
"func::FuncDialect"
"StandardOpsDialect",
// Since LLVM 15
// "cf::ControlFlowDialect",
// "func::FuncDialect"
];
let extraClassDeclaration = [{

3
lib/CMakeLists.txt Normal file
View File

@@ -0,0 +1,3 @@
# add_subdirectory(codegen)
add_subdirectory(driver)
add_subdirectory(ir)

View File

@@ -0,0 +1,5 @@
file(GLOB_RECURSE CODEGEN_SRC *.cc)
add_library(TritonCodeGen
${CODEGEN_SRC}
)

View File

@@ -0,0 +1,5 @@
add_library(TritonDriver
dispatch.cc
error.cc
llvm.cc
)

View File

@@ -1,12 +1,34 @@
add_mlir_dialect_library(TRITONIR
add_mlir_dialect_library(TritonIR
Dialect.cpp
Ops.cpp
Types.cpp
DEPENDS
TritonTableGen
LINK_LIBS PUBLIC
MLIRIR
MLIRArithmetic
MLIRControlFlow
MLIRFunc
# Since LLVM 15
# MLIRControlFlow
# MLIRFunc
# else
MLIRStandard
MLIRTensor
)
# add_library(TritonIR
# Dialect.cpp
# Ops.cpp
# Types.cpp
# )
# target_link_libraries(TritonIR PUBLIC
# MLIRIR
# MLIRArithmetic
# MLIRControlFlow
# MLIRFunc
# MLIRTensor
# )

View File

@@ -1,5 +1,5 @@
#include "triton/Dialect.h"
#include "triton/Types.h"
#include "triton/ir/Dialect.h"
#include "triton/ir/Types.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -8,7 +8,7 @@
#include "mlir/IR/DialectImplementation.h"
#include "triton/Dialect.cpp.inc"
#include "triton/ir/Dialect.cpp.inc"
using namespace mlir;
using namespace mlir::triton;
@@ -18,7 +18,7 @@ void TritonDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "triton/Ops.cpp.inc"
#include "triton/ir/Ops.cpp.inc"
>();
// We can also add interface here.

View File

@@ -1,15 +1,16 @@
#include "triton/ir/Dialect.h"
#include "triton/ir/Types.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OperationSupport.h"
#include "triton/Dialect.h"
#include "triton/Types.h"
#define GET_OP_CLASSES
#include "triton/Ops.cpp.inc"
#include "triton/ir/Ops.cpp.inc"
// enum attribute definitions
#include "triton/OpsEnums.cpp.inc"
#include "triton/ir/OpsEnums.cpp.inc"
namespace mlir {
namespace triton {

View File

@@ -1,5 +1,5 @@
#include "triton/Dialect.h"
#include "triton/Types.h"
#include "triton/ir/Dialect.h"
#include "triton/ir/Types.h"
using namespace mlir;
using namespace mlir::triton;

View File

@@ -17,7 +17,7 @@ from setuptools.command.build_ext import build_ext
def get_llvm():
# tries to find system LLVM
versions = ['-11.0', '-11', '-11-64']
versions = ['-14.0', '-14', '-14-64']
supported = ['llvm-config{v}'.format(v=v) for v in versions]
paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
paths = [p for p in paths if p is not None]
@@ -26,7 +26,8 @@ def get_llvm():
if platform.system() == "Windows":
return '', ''
# download if nothing is installed
name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04'
# name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04'
name = 'clang+llvm-14.0.0-rc2-x86_64-linux-gnu-ubuntu-18.04'
dir = '/tmp'
llvm_include_dir = '{dir}/{name}/include'.format(dir=dir, name=name)
llvm_library_dir = '{dir}/{name}/lib'.format(dir=dir, name=name)
@@ -35,7 +36,7 @@ def get_llvm():
shutil.rmtree(os.path.join(dir, name))
except Exception:
pass
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/{name}.tar.xz".format(name=name)
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-14.0.0-rc2/{name}.tar.xz".format(name=name)
print('downloading and extracting ' + url + '...')
ftpstream = urllib.request.urlopen(url)
file = tarfile.open(fileobj=ftpstream, mode="r|xz")

View File

@@ -36,13 +36,13 @@ std::vector<int> segment_blocks(tensor_3d &layout, tensor_3d &idx, int max_width
std::vector<int> current(H, 0);
int num = 0;
std::vector<int> lut(H * M * N * 4);
for (size_t h = 0; h < H; h++) {
for (ssize_t h = 0; h < H; h++) {
// surrounding indices
std::vector<int> ii_left(max_width, -1);
std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));
// start the dynamic programming algorithm
for (size_t m = 0; m < M; m++) {
for (size_t n = 0; n < N; n++) {
for (ssize_t m = 0; m < M; m++) {
for (ssize_t n = 0; n < N; n++) {
int v = layout(h, m, n);
if (v == 0)
continue;
@@ -70,8 +70,8 @@ std::vector<int> segment_blocks(tensor_3d &layout, tensor_3d &idx, int max_width
if (width != max_width)
continue;
// retained blocks are set to zeros
for (size_t km = 0; km < max_width; km++)
for (size_t kn = 0; kn < max_width; kn++) {
for (ssize_t km = 0; km < max_width; km++)
for (ssize_t kn = 0; kn < max_width; kn++) {
int mm = ii_top[km][n];
int nn = ii_left[kn];
if (mm < 0 || nn < 0)

View File

@@ -2,11 +2,15 @@
#include "triton/codegen/target.h"
#include "triton/driver/error.h"
#include "triton/driver/llvm.h"
#include "triton/ir/builder.h"
#include "triton/ir/enums.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/print.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Verifier.h"
#include <optional>
#include <pybind11/buffer_info.h>
#include <pybind11/functional.h>
@@ -18,9 +22,6 @@
#include <sstream>
#include <stdexcept>
#include <string>
#include "llvm/IR/Module.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Verifier.h"
namespace py = pybind11;
namespace ir = triton::ir;
@@ -748,7 +749,7 @@ void init_triton_ir(py::module &&m) {
}, ret::reference)
.def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference);
py::class_<ir::builder>(m, "builder", py::dynamic_attr())
py::class_<mlir::OpBuilder>(m, "builder", py::dynamic_attr())
.def(py::init<ir::context &>())
// getters
.def_property_readonly("context", &ir::builder::get_context, ret::reference)
@@ -788,10 +789,10 @@ void init_triton_ir(py::module &&m) {
.def("get_range", &ir::builder::get_range, ret::reference)
// Types
.def("get_void_ty", &ir::builder::get_void_ty, ret::reference)
.def("get_int1_ty", &ir::builder::get_int1_ty, ret::reference)
.def("get_int8_ty", &ir::builder::get_int8_ty, ret::reference)
.def("get_int1_ty", &mlir::OpBuilder::getI1Type, ret::reference)
.def("get_int8_ty", &mlir::OpBuilder::getI8Type, ret::reference)
.def("get_int16_ty", &ir::builder::get_int16_ty, ret::reference)
.def("get_int32_ty", &ir::builder::get_int32_ty, ret::reference)
.def("get_int32_ty", &mlir::OpBuilder::getI32Type, ret::reference)
.def("get_int64_ty", &ir::builder::get_int64_ty, ret::reference)
.def("get_fp8_ty", &ir::builder::get_fp8_ty, ret::reference)
.def("get_half_ty", &ir::builder::get_half_ty, ret::reference)