Reverts back to MLIR 14 & updates CMakeLists
This commit is contained in:
@@ -141,19 +141,24 @@ if(BUILD_PYTHON_MODULE)
|
|||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
||||||
# Triton
|
# # Triton
|
||||||
file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
|
# file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
|
||||||
if (WIN32 AND BUILD_PYTHON_MODULE)
|
# if (WIN32 AND BUILD_PYTHON_MODULE)
|
||||||
find_package(Python3 REQUIRED COMPONENTS Development)
|
# find_package(Python3 REQUIRED COMPONENTS Development)
|
||||||
Python3_add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
|
# Python3_add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
|
||||||
set_target_properties(triton PROPERTIES SUFFIX ".pyd")
|
# set_target_properties(triton PROPERTIES SUFFIX ".pyd")
|
||||||
set_target_properties(triton PROPERTIES PREFIX "lib")
|
# set_target_properties(triton PROPERTIES PREFIX "lib")
|
||||||
else()
|
# else()
|
||||||
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
|
# add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
|
||||||
endif()
|
# endif()
|
||||||
|
|
||||||
|
|
||||||
# MLIR
|
# MLIR
|
||||||
find_package(MLIR 14 REQUIRED CONFIG)
|
find_package(MLIR REQUIRED CONFIG PATHS ${LLVM_LIBRARY_DIR}/cmake/mlir)
|
||||||
|
|
||||||
|
list(APPEND CMAKE_MODULE_PATH ${LLVM_LIBRARY_DIR}/cmake/llvm)
|
||||||
|
list(APPEND CMAKE_MODULE_PATH ${LLVM_LIBRARY_DIR}/cmake/mlir)
|
||||||
|
|
||||||
include(TableGen) # required by AddMLIR
|
include(TableGen) # required by AddMLIR
|
||||||
include(AddLLVM)
|
include(AddLLVM)
|
||||||
include(AddMLIR)
|
include(AddMLIR)
|
||||||
@@ -162,17 +167,17 @@ include(HandleLLVMOptions) # human-friendly error message
|
|||||||
include_directories(${MLIR_INCLUDE_DIRS})
|
include_directories(${MLIR_INCLUDE_DIRS})
|
||||||
include_directories(${PROJECT_SOURCE_DIR}/include)
|
include_directories(${PROJECT_SOURCE_DIR}/include)
|
||||||
include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files
|
include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files
|
||||||
|
# link_directories(${LLVM_LIBRARY_DIR})
|
||||||
|
|
||||||
|
add_subdirectory(include)
|
||||||
add_subdirectory(lib)
|
add_subdirectory(lib)
|
||||||
# lib
|
|
||||||
add_library(triton)
|
add_library(triton SHARED ${PYTHON_SRC})
|
||||||
# add_subdirectory(ir)
|
|
||||||
target_link_libraries(triton
|
target_link_libraries(triton
|
||||||
PUBLIC
|
TritonIR
|
||||||
TRITONIR
|
TritonDriver
|
||||||
# # optimizations
|
TritonCodeGen
|
||||||
# MLIRPass
|
|
||||||
# MLIRTransforms
|
|
||||||
)
|
)
|
||||||
|
|
||||||
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
|
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
|
||||||
|
@@ -1 +1 @@
|
|||||||
add_subdirectory(ir)
|
add_subdirectory(triton/ir)
|
||||||
|
@@ -4,15 +4,15 @@
|
|||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
// #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
||||||
|
|
||||||
#include "triton/Dialect.h.inc"
|
#include "triton/ir/Dialect.h.inc"
|
||||||
|
|
||||||
#include "triton/OpsEnums.h.inc"
|
#include "triton/ir/OpsEnums.h.inc"
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "triton/Ops.h.inc"
|
#include "triton/ir/Ops.h.inc"
|
||||||
|
|
||||||
#endif // TRITON_IR_DIALECT_H_
|
#endif // TRITON_IR_DIALECT_H_
|
||||||
|
@@ -27,8 +27,11 @@ def Triton_Dialect : Dialect {
|
|||||||
let dependentDialects = [
|
let dependentDialects = [
|
||||||
"arith::ArithmeticDialect",
|
"arith::ArithmeticDialect",
|
||||||
"tensor::TensorDialect",
|
"tensor::TensorDialect",
|
||||||
"cf::ControlFlowDialect",
|
"StandardOpsDialect",
|
||||||
"func::FuncDialect"
|
|
||||||
|
// Since LLVM 15
|
||||||
|
// "cf::ControlFlowDialect",
|
||||||
|
// "func::FuncDialect"
|
||||||
];
|
];
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
|
3
lib/CMakeLists.txt
Normal file
3
lib/CMakeLists.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# add_subdirectory(codegen)
|
||||||
|
add_subdirectory(driver)
|
||||||
|
add_subdirectory(ir)
|
5
lib/codegen/CMakeLists.txt
Normal file
5
lib/codegen/CMakeLists.txt
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
file(GLOB_RECURSE CODEGEN_SRC *.cc)
|
||||||
|
|
||||||
|
add_library(TritonCodeGen
|
||||||
|
${CODEGEN_SRC}
|
||||||
|
)
|
5
lib/driver/CMakeLists.txt
Normal file
5
lib/driver/CMakeLists.txt
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
add_library(TritonDriver
|
||||||
|
dispatch.cc
|
||||||
|
error.cc
|
||||||
|
llvm.cc
|
||||||
|
)
|
@@ -1,12 +1,34 @@
|
|||||||
add_mlir_dialect_library(TRITONIR
|
add_mlir_dialect_library(TritonIR
|
||||||
Dialect.cpp
|
Dialect.cpp
|
||||||
Ops.cpp
|
Ops.cpp
|
||||||
Types.cpp
|
Types.cpp
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
TritonTableGen
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
MLIRIR
|
MLIRIR
|
||||||
MLIRArithmetic
|
MLIRArithmetic
|
||||||
MLIRControlFlow
|
|
||||||
MLIRFunc
|
# Since LLVM 15
|
||||||
|
# MLIRControlFlow
|
||||||
|
# MLIRFunc
|
||||||
|
# else
|
||||||
|
MLIRStandard
|
||||||
|
|
||||||
MLIRTensor
|
MLIRTensor
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# add_library(TritonIR
|
||||||
|
# Dialect.cpp
|
||||||
|
# Ops.cpp
|
||||||
|
# Types.cpp
|
||||||
|
# )
|
||||||
|
|
||||||
|
# target_link_libraries(TritonIR PUBLIC
|
||||||
|
# MLIRIR
|
||||||
|
# MLIRArithmetic
|
||||||
|
# MLIRControlFlow
|
||||||
|
# MLIRFunc
|
||||||
|
# MLIRTensor
|
||||||
|
# )
|
||||||
|
@@ -1,5 +1,5 @@
|
|||||||
#include "triton/Dialect.h"
|
#include "triton/ir/Dialect.h"
|
||||||
#include "triton/Types.h"
|
#include "triton/ir/Types.h"
|
||||||
|
|
||||||
#include "llvm/ADT/StringSwitch.h"
|
#include "llvm/ADT/StringSwitch.h"
|
||||||
#include "llvm/ADT/TypeSwitch.h"
|
#include "llvm/ADT/TypeSwitch.h"
|
||||||
@@ -8,7 +8,7 @@
|
|||||||
#include "mlir/IR/DialectImplementation.h"
|
#include "mlir/IR/DialectImplementation.h"
|
||||||
|
|
||||||
|
|
||||||
#include "triton/Dialect.cpp.inc"
|
#include "triton/ir/Dialect.cpp.inc"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::triton;
|
using namespace mlir::triton;
|
||||||
@@ -18,7 +18,7 @@ void TritonDialect::initialize() {
|
|||||||
|
|
||||||
addOperations<
|
addOperations<
|
||||||
#define GET_OP_LIST
|
#define GET_OP_LIST
|
||||||
#include "triton/Ops.cpp.inc"
|
#include "triton/ir/Ops.cpp.inc"
|
||||||
>();
|
>();
|
||||||
|
|
||||||
// We can also add interface here.
|
// We can also add interface here.
|
||||||
|
@@ -1,15 +1,16 @@
|
|||||||
|
#include "triton/ir/Dialect.h"
|
||||||
|
#include "triton/ir/Types.h"
|
||||||
|
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/OperationSupport.h"
|
#include "mlir/IR/OperationSupport.h"
|
||||||
#include "triton/Dialect.h"
|
|
||||||
#include "triton/Types.h"
|
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "triton/Ops.cpp.inc"
|
#include "triton/ir/Ops.cpp.inc"
|
||||||
|
|
||||||
// enum attribute definitions
|
// enum attribute definitions
|
||||||
#include "triton/OpsEnums.cpp.inc"
|
#include "triton/ir/OpsEnums.cpp.inc"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace triton {
|
namespace triton {
|
||||||
|
@@ -1,5 +1,5 @@
|
|||||||
#include "triton/Dialect.h"
|
#include "triton/ir/Dialect.h"
|
||||||
#include "triton/Types.h"
|
#include "triton/ir/Types.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::triton;
|
using namespace mlir::triton;
|
||||||
|
@@ -17,7 +17,7 @@ from setuptools.command.build_ext import build_ext
|
|||||||
|
|
||||||
def get_llvm():
|
def get_llvm():
|
||||||
# tries to find system LLVM
|
# tries to find system LLVM
|
||||||
versions = ['-11.0', '-11', '-11-64']
|
versions = ['-14.0', '-14', '-14-64']
|
||||||
supported = ['llvm-config{v}'.format(v=v) for v in versions]
|
supported = ['llvm-config{v}'.format(v=v) for v in versions]
|
||||||
paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
|
paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
|
||||||
paths = [p for p in paths if p is not None]
|
paths = [p for p in paths if p is not None]
|
||||||
@@ -26,7 +26,8 @@ def get_llvm():
|
|||||||
if platform.system() == "Windows":
|
if platform.system() == "Windows":
|
||||||
return '', ''
|
return '', ''
|
||||||
# download if nothing is installed
|
# download if nothing is installed
|
||||||
name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04'
|
# name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04'
|
||||||
|
name = 'clang+llvm-14.0.0-rc2-x86_64-linux-gnu-ubuntu-18.04'
|
||||||
dir = '/tmp'
|
dir = '/tmp'
|
||||||
llvm_include_dir = '{dir}/{name}/include'.format(dir=dir, name=name)
|
llvm_include_dir = '{dir}/{name}/include'.format(dir=dir, name=name)
|
||||||
llvm_library_dir = '{dir}/{name}/lib'.format(dir=dir, name=name)
|
llvm_library_dir = '{dir}/{name}/lib'.format(dir=dir, name=name)
|
||||||
@@ -35,7 +36,7 @@ def get_llvm():
|
|||||||
shutil.rmtree(os.path.join(dir, name))
|
shutil.rmtree(os.path.join(dir, name))
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/{name}.tar.xz".format(name=name)
|
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-14.0.0-rc2/{name}.tar.xz".format(name=name)
|
||||||
print('downloading and extracting ' + url + '...')
|
print('downloading and extracting ' + url + '...')
|
||||||
ftpstream = urllib.request.urlopen(url)
|
ftpstream = urllib.request.urlopen(url)
|
||||||
file = tarfile.open(fileobj=ftpstream, mode="r|xz")
|
file = tarfile.open(fileobj=ftpstream, mode="r|xz")
|
||||||
|
@@ -36,13 +36,13 @@ std::vector<int> segment_blocks(tensor_3d &layout, tensor_3d &idx, int max_width
|
|||||||
std::vector<int> current(H, 0);
|
std::vector<int> current(H, 0);
|
||||||
int num = 0;
|
int num = 0;
|
||||||
std::vector<int> lut(H * M * N * 4);
|
std::vector<int> lut(H * M * N * 4);
|
||||||
for (size_t h = 0; h < H; h++) {
|
for (ssize_t h = 0; h < H; h++) {
|
||||||
// surrounding indices
|
// surrounding indices
|
||||||
std::vector<int> ii_left(max_width, -1);
|
std::vector<int> ii_left(max_width, -1);
|
||||||
std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));
|
std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));
|
||||||
// start the dynamic programming algorithm
|
// start the dynamic programming algorithm
|
||||||
for (size_t m = 0; m < M; m++) {
|
for (ssize_t m = 0; m < M; m++) {
|
||||||
for (size_t n = 0; n < N; n++) {
|
for (ssize_t n = 0; n < N; n++) {
|
||||||
int v = layout(h, m, n);
|
int v = layout(h, m, n);
|
||||||
if (v == 0)
|
if (v == 0)
|
||||||
continue;
|
continue;
|
||||||
@@ -70,8 +70,8 @@ std::vector<int> segment_blocks(tensor_3d &layout, tensor_3d &idx, int max_width
|
|||||||
if (width != max_width)
|
if (width != max_width)
|
||||||
continue;
|
continue;
|
||||||
// retained blocks are set to zeros
|
// retained blocks are set to zeros
|
||||||
for (size_t km = 0; km < max_width; km++)
|
for (ssize_t km = 0; km < max_width; km++)
|
||||||
for (size_t kn = 0; kn < max_width; kn++) {
|
for (ssize_t kn = 0; kn < max_width; kn++) {
|
||||||
int mm = ii_top[km][n];
|
int mm = ii_top[km][n];
|
||||||
int nn = ii_left[kn];
|
int nn = ii_left[kn];
|
||||||
if (mm < 0 || nn < 0)
|
if (mm < 0 || nn < 0)
|
||||||
@@ -116,4 +116,4 @@ std::vector<lut_t> superblock(uintptr_t LAYOUT, int H, int M, int N, int start_w
|
|||||||
|
|
||||||
void init_superblocking(pybind11::module &m) {
|
void init_superblocking(pybind11::module &m) {
|
||||||
m.def("superblock", &superblock, "super-blocking for block-sparse matrix multiplication");
|
m.def("superblock", &superblock, "super-blocking for block-sparse matrix multiplication");
|
||||||
}
|
}
|
||||||
|
@@ -2,11 +2,15 @@
|
|||||||
#include "triton/codegen/target.h"
|
#include "triton/codegen/target.h"
|
||||||
#include "triton/driver/error.h"
|
#include "triton/driver/error.h"
|
||||||
#include "triton/driver/llvm.h"
|
#include "triton/driver/llvm.h"
|
||||||
#include "triton/ir/builder.h"
|
|
||||||
#include "triton/ir/enums.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "triton/ir/function.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "triton/ir/module.h"
|
#include "mlir/IR/MLIRContext.h"
|
||||||
#include "triton/ir/print.h"
|
|
||||||
|
#include "llvm/IR/Module.h"
|
||||||
|
#include "llvm/IR/LegacyPassManager.h"
|
||||||
|
#include "llvm/IR/Verifier.h"
|
||||||
|
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <pybind11/buffer_info.h>
|
#include <pybind11/buffer_info.h>
|
||||||
#include <pybind11/functional.h>
|
#include <pybind11/functional.h>
|
||||||
@@ -18,9 +22,6 @@
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "llvm/IR/Module.h"
|
|
||||||
#include "llvm/IR/LegacyPassManager.h"
|
|
||||||
#include "llvm/IR/Verifier.h"
|
|
||||||
|
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
namespace ir = triton::ir;
|
namespace ir = triton::ir;
|
||||||
@@ -748,7 +749,7 @@ void init_triton_ir(py::module &&m) {
|
|||||||
}, ret::reference)
|
}, ret::reference)
|
||||||
.def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference);
|
.def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference);
|
||||||
|
|
||||||
py::class_<ir::builder>(m, "builder", py::dynamic_attr())
|
py::class_<mlir::OpBuilder>(m, "builder", py::dynamic_attr())
|
||||||
.def(py::init<ir::context &>())
|
.def(py::init<ir::context &>())
|
||||||
// getters
|
// getters
|
||||||
.def_property_readonly("context", &ir::builder::get_context, ret::reference)
|
.def_property_readonly("context", &ir::builder::get_context, ret::reference)
|
||||||
@@ -788,10 +789,10 @@ void init_triton_ir(py::module &&m) {
|
|||||||
.def("get_range", &ir::builder::get_range, ret::reference)
|
.def("get_range", &ir::builder::get_range, ret::reference)
|
||||||
// Types
|
// Types
|
||||||
.def("get_void_ty", &ir::builder::get_void_ty, ret::reference)
|
.def("get_void_ty", &ir::builder::get_void_ty, ret::reference)
|
||||||
.def("get_int1_ty", &ir::builder::get_int1_ty, ret::reference)
|
.def("get_int1_ty", &mlir::OpBuilder::getI1Type, ret::reference)
|
||||||
.def("get_int8_ty", &ir::builder::get_int8_ty, ret::reference)
|
.def("get_int8_ty", &mlir::OpBuilder::getI8Type, ret::reference)
|
||||||
.def("get_int16_ty", &ir::builder::get_int16_ty, ret::reference)
|
.def("get_int16_ty", &ir::builder::get_int16_ty, ret::reference)
|
||||||
.def("get_int32_ty", &ir::builder::get_int32_ty, ret::reference)
|
.def("get_int32_ty", &mlir::OpBuilder::getI32Type, ret::reference)
|
||||||
.def("get_int64_ty", &ir::builder::get_int64_ty, ret::reference)
|
.def("get_int64_ty", &ir::builder::get_int64_ty, ret::reference)
|
||||||
.def("get_fp8_ty", &ir::builder::get_fp8_ty, ret::reference)
|
.def("get_fp8_ty", &ir::builder::get_fp8_ty, ret::reference)
|
||||||
.def("get_half_ty", &ir::builder::get_half_ty, ret::reference)
|
.def("get_half_ty", &ir::builder::get_half_ty, ret::reference)
|
||||||
|
Reference in New Issue
Block a user