Replace MlirType with mlir::Type

This commit is contained in:
Yan Da
2022-04-01 18:46:46 +08:00
parent 4ad432f1fc
commit bde103fab0
3 changed files with 93 additions and 239 deletions

View File

@@ -4,11 +4,8 @@
#include "triton/driver/llvm.h"
#include "mlir/IR/Builders.h"
#include "mlir-c/IR.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir/CAPI/IR.h"
// #include "mlir/IR/BuiltinOps.h"
// #include "mlir/IR/MLIRContext.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "triton/ir/Dialect.h"
#include "triton/ir/Types.h"
@@ -652,13 +649,9 @@ void init_triton_ir(py::module &&m) {
.def("get_context", &mlir::ModuleOp::getContext)
;
py::class_<MlirType>(m, "type")
.def("is_integer", [](MlirType &self) -> bool {
return mlirTypeIsAInteger(self);
})
.def("is_fp16", [](MlirType &self) -> bool {
return mlirTypeIsABF16(self);
})
py::class_<mlir::Type>(m, "type")
.def("is_integer", &mlir::Type::isInteger)
.def("is_fp16", &mlir::Type::isF16)
;
py::class_<mlir::Value>(m, "value")
@@ -782,74 +775,77 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
return self.create<mlir::arith::ConstantOp>(loc, self.getF32FloatAttr(v));
})
.def("get_null_value", [](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {
auto loc = self.getUnknownLoc();
if (type.isa<mlir::FloatType>())
return self.create<mlir::arith::ConstantOp>(loc, self.getF32FloatAttr(0.0));
else
throw std::runtime_error("Not implemented");
})
// Types
.def("get_void_ty", [](mlir::OpBuilder &self) ->MlirType {
return wrap(self.getNoneType());
.def("get_void_ty", [](mlir::OpBuilder &self) -> mlir::Type {
return self.getNoneType();
})
.def("get_int1_ty", [](mlir::OpBuilder &self) -> MlirType {
return wrap(self.getI1Type());
.def("get_int1_ty", [](mlir::OpBuilder &self) -> mlir::Type {
return self.getI1Type();
}) // or ret::copy?
.def("get_int8_ty", [](mlir::OpBuilder &self) -> MlirType {
return wrap(self.getI8Type());
.def("get_int8_ty", [](mlir::OpBuilder &self) -> mlir::Type {
return self.getI8Type();
})
.def("get_int16_ty", [](mlir::OpBuilder &self) -> MlirType {
return wrap(self.getType<mlir::IntegerType>(16));
.def("get_int16_ty", [](mlir::OpBuilder &self) -> mlir::Type {
return self.getType<mlir::IntegerType>(16);
})
.def("get_int32_ty", [](mlir::OpBuilder &self) -> MlirType {
return wrap(self.getI32Type());
.def("get_int32_ty", [](mlir::OpBuilder &self) -> mlir::Type {
return self.getI32Type();
})
.def("get_int64_ty", [](mlir::OpBuilder &self) -> MlirType {
return wrap(self.getI64Type());
.def("get_int64_ty", [](mlir::OpBuilder &self) -> mlir::Type {
return self.getI64Type();
})
.def("get_fp8_ty", [](mlir::OpBuilder &self) -> MlirType {
return wrap(self.getType<mlir::triton::Float8Type>());
.def("get_fp8_ty", [](mlir::OpBuilder &self) -> mlir::Type {
return self.getType<mlir::triton::Float8Type>();
})
.def("get_bf8_ty", [](mlir::OpBuilder &self) -> MlirType {
return wrap(self.getType<mlir::triton::BFloat8Type>());
.def("get_bf8_ty", [](mlir::OpBuilder &self) -> mlir::Type {
return self.getType<mlir::triton::BFloat8Type>();
})
.def("get_half_ty", [](mlir::OpBuilder &self) -> MlirType {
return wrap(self.getF16Type());
.def("get_half_ty", [](mlir::OpBuilder &self) -> mlir::Type {
return self.getF16Type();
})
.def("get_bf16_ty", [](mlir::OpBuilder &self) -> MlirType {
return wrap(self.getBF16Type());
.def("get_bf16_ty", [](mlir::OpBuilder &self) -> mlir::Type {
return self.getBF16Type();
})
.def("get_float_ty", [](mlir::OpBuilder &self) -> MlirType {
return wrap(self.getF32Type());
.def("get_float_ty", [](mlir::OpBuilder &self) -> mlir::Type {
return self.getF32Type();
})
.def("get_double_ty", [](mlir::OpBuilder &self) -> MlirType {
return wrap(self.getF64Type());
.def("get_double_ty", [](mlir::OpBuilder &self) -> mlir::Type {
return self.getF64Type();
})
.def("get_ptr_ty", [](mlir::OpBuilder &self, MlirType &type, int addrSpace) -> MlirType {
return wrap(
mlir::triton::PointerType::get(unwrap(type), addrSpace)
);
.def("get_ptr_ty", [](mlir::OpBuilder &self, mlir::Type &type, int addrSpace) -> mlir::Type {
return mlir::triton::PointerType::get(type, addrSpace);
})
.def("get_block_ty", [](mlir::OpBuilder &self, MlirType &elementType,
std::vector<int64_t> &shape) -> MlirType {
return wrap(
mlir::RankedTensorType::get(shape, unwrap(elementType))
);
.def("get_block_ty", [](mlir::OpBuilder &self, mlir::Type &elementType,
std::vector<int64_t> &shape) -> mlir::Type {
return mlir::RankedTensorType::get(shape, elementType);
})
.def("get_function_ty", [](mlir::OpBuilder &self,
std::vector<MlirType> inTypes,
std::vector<MlirType> outTypes) -> MlirType {
llvm::SmallVector<mlir::Type, 4> inputsTypeList;
llvm::SmallVector<mlir::Type, 4> resultsTypeList;
(void)unwrapList(inTypes.size(), inTypes.data(), inputsTypeList);
(void)unwrapList(outTypes.size(), outTypes.data(), resultsTypeList);
return wrap(self.getFunctionType(inputsTypeList, resultsTypeList));
std::vector<mlir::Type> inTypes,
std::vector<mlir::Type> outTypes) -> mlir::Type {
return self.getFunctionType(inTypes, outTypes);
})
// Ops
.def("create_function", [](mlir::OpBuilder &self, std::string name, MlirType funcType) -> mlir::FuncOp {
.def("create_function", [](mlir::OpBuilder &self, std::string name, mlir::Type &funcType) -> mlir::FuncOp {
// TODO: loc
auto loc = self.getUnknownLoc();
if (auto funcTy = unwrap(funcType).dyn_cast<mlir::FunctionType>()) {
if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) {
return self.create<mlir::FuncOp>(loc, name, funcTy);
}
throw std::runtime_error("invalid function type");
})
.def("create_block", [](mlir::OpBuilder &self) -> mlir::Block* {
mlir::Region *parent = self.getBlock()->getParent();
return self.createBlock(parent);
}, ret::reference)
// Structured control flow
.def("create_for_op", [](mlir::OpBuilder &self, mlir::Value &lb, mlir::Value &ub,
mlir::Value &step, std::vector<mlir::Value> &initArgs) -> mlir::scf::ForOp {
@@ -878,35 +874,35 @@ void init_triton_ir(py::module &&m) {
})
// Cast instructions
.def("create_bitcast", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
.def("create_bitcast", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::arith::BitcastOp>(loc, unwrap(dstType), src);
return self.create<mlir::arith::BitcastOp>(loc, dstType, src);
})
// .def("create_cast", &ir::builder::create_cast)
// .def("create_ptr_to_int", &ir::builder::create_ptr_to_int)
.def("create_si_to_fp", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
.def("create_si_to_fp", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::arith::SIToFPOp>(loc, unwrap(dstType), src);
return self.create<mlir::arith::SIToFPOp>(loc, dstType, src);
})
.def("create_ui_to_fp", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
.def("create_ui_to_fp", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::arith::UIToFPOp>(loc, unwrap(dstType), src);
return self.create<mlir::arith::UIToFPOp>(loc, dstType, src);
})
.def("create_fp_to_si", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
.def("create_fp_to_si", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::arith::FPToSIOp>(loc, unwrap(dstType), src);
return self.create<mlir::arith::FPToSIOp>(loc, dstType, src);
})
.def("create_fp_to_ui", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
.def("create_fp_to_ui", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::arith::FPToUIOp>(loc, unwrap(dstType), src);
return self.create<mlir::arith::FPToUIOp>(loc, dstType, src);
})
.def("create_fp_ext", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
.def("create_fp_ext", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::arith::ExtFOp>(loc, unwrap(dstType), src);
return self.create<mlir::arith::ExtFOp>(loc, dstType, src);
})
.def("create_fp_trunc", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
.def("create_fp_trunc", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::arith::TruncFOp>(loc, unwrap(dstType), src);
return self.create<mlir::arith::TruncFOp>(loc, dstType, src);
})
// .def("create_int_cast", &ir::builder::create_int_cast)
// .def("create_downcast", &ir::builder::create_downcast)