Reverts back to MLIR 14 & updates CMakeLists

This commit is contained in:
Yan Da
2022-03-20 16:41:48 +08:00
parent a2c31ff434
commit 419bbe0f6e
14 changed files with 107 additions and 61 deletions

View File

@@ -17,7 +17,7 @@ from setuptools.command.build_ext import build_ext
def get_llvm():
# tries to find system LLVM
versions = ['-11.0', '-11', '-11-64']
versions = ['-14.0', '-14', '-14-64']
supported = ['llvm-config{v}'.format(v=v) for v in versions]
paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
paths = [p for p in paths if p is not None]
@@ -26,7 +26,8 @@ def get_llvm():
if platform.system() == "Windows":
return '', ''
# download if nothing is installed
name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04'
# name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04'
name = 'clang+llvm-14.0.0-rc2-x86_64-linux-gnu-ubuntu-18.04'
dir = '/tmp'
llvm_include_dir = '{dir}/{name}/include'.format(dir=dir, name=name)
llvm_library_dir = '{dir}/{name}/lib'.format(dir=dir, name=name)
@@ -35,7 +36,7 @@ def get_llvm():
shutil.rmtree(os.path.join(dir, name))
except Exception:
pass
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/{name}.tar.xz".format(name=name)
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-14.0.0-rc2/{name}.tar.xz".format(name=name)
print('downloading and extracting ' + url + '...')
ftpstream = urllib.request.urlopen(url)
file = tarfile.open(fileobj=ftpstream, mode="r|xz")

View File

@@ -36,13 +36,13 @@ std::vector<int> segment_blocks(tensor_3d &layout, tensor_3d &idx, int max_width
std::vector<int> current(H, 0);
int num = 0;
std::vector<int> lut(H * M * N * 4);
for (size_t h = 0; h < H; h++) {
for (ssize_t h = 0; h < H; h++) {
// surrounding indices
std::vector<int> ii_left(max_width, -1);
std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));
// start the dynamic programming algorithm
for (size_t m = 0; m < M; m++) {
for (size_t n = 0; n < N; n++) {
for (ssize_t m = 0; m < M; m++) {
for (ssize_t n = 0; n < N; n++) {
int v = layout(h, m, n);
if (v == 0)
continue;
@@ -70,8 +70,8 @@ std::vector<int> segment_blocks(tensor_3d &layout, tensor_3d &idx, int max_width
if (width != max_width)
continue;
// retained blocks are set to zeros
for (size_t km = 0; km < max_width; km++)
for (size_t kn = 0; kn < max_width; kn++) {
for (ssize_t km = 0; km < max_width; km++)
for (ssize_t kn = 0; kn < max_width; kn++) {
int mm = ii_top[km][n];
int nn = ii_left[kn];
if (mm < 0 || nn < 0)
@@ -116,4 +116,4 @@ std::vector<lut_t> superblock(uintptr_t LAYOUT, int H, int M, int N, int start_w
void init_superblocking(pybind11::module &m) {
m.def("superblock", &superblock, "super-blocking for block-sparse matrix multiplication");
}
}

View File

@@ -2,11 +2,15 @@
#include "triton/codegen/target.h"
#include "triton/driver/error.h"
#include "triton/driver/llvm.h"
#include "triton/ir/builder.h"
#include "triton/ir/enums.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/print.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Verifier.h"
#include <optional>
#include <pybind11/buffer_info.h>
#include <pybind11/functional.h>
@@ -18,9 +22,6 @@
#include <sstream>
#include <stdexcept>
#include <string>
#include "llvm/IR/Module.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Verifier.h"
namespace py = pybind11;
namespace ir = triton::ir;
@@ -748,7 +749,7 @@ void init_triton_ir(py::module &&m) {
}, ret::reference)
.def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference);
py::class_<ir::builder>(m, "builder", py::dynamic_attr())
py::class_<mlir::OpBuilder>(m, "builder", py::dynamic_attr())
.def(py::init<ir::context &>())
// getters
.def_property_readonly("context", &ir::builder::get_context, ret::reference)
@@ -788,10 +789,10 @@ void init_triton_ir(py::module &&m) {
.def("get_range", &ir::builder::get_range, ret::reference)
// Types
.def("get_void_ty", &ir::builder::get_void_ty, ret::reference)
.def("get_int1_ty", &ir::builder::get_int1_ty, ret::reference)
.def("get_int8_ty", &ir::builder::get_int8_ty, ret::reference)
.def("get_int1_ty", &mlir::OpBuilder::getI1Type, ret::reference)
.def("get_int8_ty", &mlir::OpBuilder::getI8Type, ret::reference)
.def("get_int16_ty", &ir::builder::get_int16_ty, ret::reference)
.def("get_int32_ty", &ir::builder::get_int32_ty, ret::reference)
.def("get_int32_ty", &mlir::OpBuilder::getI32Type, ret::reference)
.def("get_int64_ty", &ir::builder::get_int64_ty, ret::reference)
.def("get_fp8_ty", &ir::builder::get_fp8_ty, ret::reference)
.def("get_half_ty", &ir::builder::get_half_ty, ret::reference)