Now vecadd works

This commit is contained in:
Yan Da
2022-03-30 20:21:47 +08:00
parent e381dc72c5
commit 2041b67fbf
5 changed files with 285 additions and 386 deletions

View File

@@ -114,7 +114,9 @@ def TT_EvictionPolicyAttr : I32EnumAttr<
def TT_LoadOp : TT_Op<"load", [SameOperandsAndResultShape]> { def TT_LoadOp : TT_Op<"load", [SameOperandsAndResultShape]> {
let summary = "load"; let summary = "load";
let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, TT_Type:$other); let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, TT_Type:$other,
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
BoolAttr:$isVolatile);
let results = (outs TT_Type:$result); let results = (outs TT_Type:$result);

View File

@@ -641,21 +641,15 @@ void init_triton_ir(py::module &&m) {
// // py::class_<ir::undef_value, ir::constant>(m, "undef") // // py::class_<ir::undef_value, ir::constant>(m, "undef")
// // .def("get", &ir::undef_value::get, ret::reference); // // .def("get", &ir::undef_value::get, ret::reference);
py::class_<MlirModule>(m, "module") py::class_<mlir::ModuleOp>(m, "module")
// .def("set_attr") // .def("set_attr")
.def("dump", [](MlirModule &self) -> void { .def("dump", [](mlir::ModuleOp &self) -> void {
unwrap(self).dump(); self.dump();
}) })
.def("push_back", [](MlirModule &self, MlirOperation &funcOperation) { .def("push_back", [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void {
if (auto info = unwrap(funcOperation)->getRegisteredInfo()) { self.push_back(funcOp);
if (mlir::TypeID::get<mlir::FuncOp>() == info->getTypeID()) {
auto funcOp = mlir::FuncOp::getFromOpaquePointer(unwrap(funcOperation));
unwrap(self).push_back(funcOp);
} else
throw std::runtime_error("Only FuncOp can call push_back");
} else
throw std::runtime_error("Unknown error");
}) })
.def("get_context", &mlir::ModuleOp::getContext)
; ;
py::class_<MlirType>(m, "type") py::class_<MlirType>(m, "type")
@@ -667,23 +661,6 @@ void init_triton_ir(py::module &&m) {
}) })
; ;
py::class_<MlirOperation>(m, "operation")
.def("add_entry_block", [](MlirOperation &self) -> mlir::Block {
if (auto info = unwrap(self)->getRegisteredInfo()) {
if (mlir::TypeID::get<mlir::FuncOp>() == info->getTypeID()) {
auto funcOp = mlir::FuncOp::getFromOpaquePointer(unwrap(self));
mlir::Block *entry = funcOp.addEntryBlock();
return *entry;
}
throw std::runtime_error("Only FuncOp can call add_entry_block");
} else
throw std::runtime_error("Unknown error");
}) // this should be automatic?
.def("dump", [](MlirOperation &self) -> void {
unwrap(self)->dump();
})
;
py::class_<mlir::Value>(m, "value") py::class_<mlir::Value>(m, "value")
; ;
py::class_<mlir::BlockArgument, mlir::Value>(m, "block_arguement") py::class_<mlir::BlockArgument, mlir::Value>(m, "block_arguement")
@@ -693,6 +670,7 @@ void init_triton_ir(py::module &&m) {
.def("arg", [](mlir::Block &self, int index) -> mlir::BlockArgument { .def("arg", [](mlir::Block &self, int index) -> mlir::BlockArgument {
return self.getArgument(index); return self.getArgument(index);
}) })
.def("dump", &mlir::Block::dump)
; ;
// py::class_<mlir::ModuleOp>(m, "module") // py::class_<mlir::ModuleOp>(m, "module")
@@ -720,29 +698,34 @@ void init_triton_ir(py::module &&m) {
// py::class_<mlir::Attribute>(m, "attribute"); // py::class_<mlir::Attribute>(m, "attribute");
// // .def(py::init<eattr, int>()); // // .def(py::init<eattr, int>());
// py::class_<mlir::FuncOp>(m, "function") py::class_<mlir::FuncOp>(m, "function")
// .def_property_readonly("args", &ir::function::args) // .def_property_readonly("args", &ir::function::args)
// .def_property_readonly("attrs", &ir::function::attrs) // .def_property_readonly("attrs", &ir::function::attrs)
// .def("add_attr", &ir::function::add_attr); // .def("add_attr", &ir::function::add_attr);
.def("args", [](mlir::FuncOp &self, unsigned idx) -> mlir::BlockArgument {
// // // We don't need to expose mlir::Block (?) return self.getArgument(idx);
// // py::class_<mlir::Block>(m, "basic_block") })
// // // .def("create", &ir::basic_block::create, ret::reference) .def("add_entry_block", [](mlir::FuncOp &self) -> mlir::Block* {
// // .def("get_predecessors", &ir::basic_block::get_predecessors, ret::reference) return self.addEntryBlock();
// // .def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference); }, ret::reference)
.def("dump", [](mlir::FuncOp &self) { self.dump(); })
;
py::class_<mlir::OpBuilder>(m, "builder", py::dynamic_attr()) py::class_<mlir::OpBuilder>(m, "builder", py::dynamic_attr())
.def(py::init<mlir::MLIRContext *>()) .def(py::init<mlir::MLIRContext *>())
// // getters // // getters
// .def_property_readonly("context", &ir::builder::get_context, ret::reference); // .def_property_readonly("context", &ir::builder::get_context, ret::reference);
.def("create_module", [](mlir::OpBuilder &self) -> MlirModule { .def("create_module", [](mlir::OpBuilder &self) -> mlir::ModuleOp {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(self.create<mlir::ModuleOp>(loc)); return self.create<mlir::ModuleOp>(loc);
}) })
// // control flow // // control flow
// .def("br", &ir::builder::create_br, ret::reference) // .def("br", &ir::builder::create_br, ret::reference)
// .def("cond_br", &ir::builder::create_cond_br, ret::reference) // .def("cond_br", &ir::builder::create_cond_br, ret::reference)
// .def("ret_void", &ir::builder::create_ret_void, ret::reference) .def("ret_void", [](mlir::OpBuilder &self) {
auto loc = self.getUnknownLoc();
self.create<mlir::ReturnOp>(loc);
}, ret::reference)
// insertion block/point // insertion block/point
.def("set_insertion_point_to_start", [](mlir::OpBuilder &self, mlir::Block &block) -> void { .def("set_insertion_point_to_start", [](mlir::OpBuilder &self, mlir::Block &block) -> void {
self.setInsertionPointToStart(&block); self.setInsertionPointToStart(&block);
@@ -750,8 +733,8 @@ void init_triton_ir(py::module &&m) {
.def("set_insertion_point_to_end", [](mlir::OpBuilder &self, mlir::Block &block) { .def("set_insertion_point_to_end", [](mlir::OpBuilder &self, mlir::Block &block) {
self.setInsertionPointToEnd(&block); self.setInsertionPointToEnd(&block);
}) })
.def("get_insertion_block", [](mlir::OpBuilder &self) -> mlir::Block & { .def("get_insertion_block", [](mlir::OpBuilder &self) -> mlir::Block* {
return *self.getInsertionBlock(); return self.getInsertionBlock();
}, ret::reference) }, ret::reference)
// .def("get_insert_point", [](ir::builder *self) { // .def("get_insert_point", [](ir::builder *self) {
// ir::basic_block *bb = self->get_insert_block(); // ir::basic_block *bb = self->get_insert_block();
@@ -784,8 +767,10 @@ void init_triton_ir(py::module &&m) {
// .def("get_int64", [](ir::builder *self, int64_t v) { return self->get_int64((uint64_t)v); }, ret::reference) // .def("get_int64", [](ir::builder *self, int64_t v) { return self->get_int64((uint64_t)v); }, ret::reference)
// .def("get_uint64", &ir::builder::get_int64, ret::reference) // .def("get_uint64", &ir::builder::get_int64, ret::reference)
// .def("get_float16", &ir::builder::get_float16, ret::reference) // .def("get_float16", &ir::builder::get_float16, ret::reference)
// .def("get_float32", &ir::builder::get_float32, ret::reference) .def("get_float32", [](mlir::OpBuilder &self, float v) -> mlir::Value {
// .def("get_range", &ir::builder::get_range, ret::reference) auto loc = self.getUnknownLoc();
return self.create<mlir::arith::ConstantOp>(loc, self.getF32FloatAttr(v));
})
// Types // Types
.def("get_void_ty", [](mlir::OpBuilder &self) ->MlirType { .def("get_void_ty", [](mlir::OpBuilder &self) ->MlirType {
@@ -846,22 +831,22 @@ void init_triton_ir(py::module &&m) {
}) })
// Ops // Ops
.def("create_function", [](mlir::OpBuilder &self, std::string name, MlirType funcType) -> MlirOperation { .def("create_function", [](mlir::OpBuilder &self, std::string name, MlirType funcType) -> mlir::FuncOp {
// TODO: loc // TODO: loc
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
if (auto funcTy = unwrap(funcType).dyn_cast<mlir::FunctionType>()) { if (auto funcTy = unwrap(funcType).dyn_cast<mlir::FunctionType>()) {
return wrap(self.create<mlir::FuncOp>(loc, name, funcTy)); return self.create<mlir::FuncOp>(loc, name, funcTy);
} }
throw std::runtime_error("invalid function type"); throw std::runtime_error("invalid function type");
}) })
// // Structured control flow // // Structured control flow
// .def("create_for", [](mlir::OpBuilder &self, MlirValue &lb, MlirValue &ub, // .def("create_for", [](mlir::OpBuilder &self, mlir::Value &lb, mlir::Value &ub,
// MlirValue &step, std::vector<MlirValue> &initArgs) -> MlirOperation { // mlir::Value &step, std::vector<mlir::Value> &initArgs) -> MlirOperation {
// auto loc = self.getUnknownLoc(); // auto loc = self.getUnknownLoc();
// return wrap(self.create<mlir::scf::ForOp>( // return wrap(self.create<mlir::scf::ForOp>(
// loc, unwrap(lb), unwrap(ub), unwrap(step)).getOperation()); // loc, unwrap(lb), unwrap(ub), unwrap(step)).getOperation());
// }) // })
// .def("create_if", [](mlir::OpBuilder &self, MlirValue &condition) -> MlirOperation { // .def("create_if", [](mlir::OpBuilder &self, mlir::Value &condition) -> MlirOperation {
// auto loc = self.getUnknownLoc(); // auto loc = self.getUnknownLoc();
// return wrap(self.create<mlir::scf::IfOp>(loc, unwrap(condition)).getOperation()); // return wrap(self.create<mlir::scf::IfOp>(loc, unwrap(condition)).getOperation());
// }) // })
@@ -872,428 +857,334 @@ void init_triton_ir(py::module &&m) {
// // .def("create_while") // // .def("create_while")
// miscellious // miscellious
.def("create_make_range", [](mlir::OpBuilder &self, int start, int end) -> MlirValue { .def("create_make_range", [](mlir::OpBuilder &self, int start, int end) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
auto retType = mlir::RankedTensorType::get({end-start}, self.getI32Type()); auto retType = mlir::RankedTensorType::get({end-start}, self.getI32Type());
return wrap( return self.create<mlir::triton::MakeRangeOp>(loc, retType, start, end);
mlir::Value(self.create<mlir::triton::MakeRangeOp>(loc, retType, start, end))
);
}) })
.def("create_get_program_id", [](mlir::OpBuilder &self, int axis) -> MlirValue { .def("create_get_program_id", [](mlir::OpBuilder &self, int axis) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap( return self.create<mlir::triton::GetProgramIdOp>(loc, self.getI32Type(), axis);
mlir::Value(self.create<mlir::triton::GetProgramIdOp>(loc, self.getI32Type(), axis))
);
}) })
// Cast instructions // Cast instructions
.def("create_bitcast", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue { .def("create_bitcast", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value( return self.create<mlir::arith::BitcastOp>(loc, unwrap(dstType), src);
self.create<mlir::arith::BitcastOp>(loc, unwrap(dstType), unwrap(src))
));
}) })
// .def("create_cast", &ir::builder::create_cast) // .def("create_cast", &ir::builder::create_cast)
// .def("create_ptr_to_int", &ir::builder::create_ptr_to_int) // .def("create_ptr_to_int", &ir::builder::create_ptr_to_int)
.def("create_si_to_fp", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue { .def("create_si_to_fp", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value( return self.create<mlir::arith::SIToFPOp>(loc, unwrap(dstType), src);
self.create<mlir::arith::SIToFPOp>(loc, unwrap(dstType), unwrap(src))
));
}) })
.def("create_ui_to_fp", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue { .def("create_ui_to_fp", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value( return self.create<mlir::arith::UIToFPOp>(loc, unwrap(dstType), src);
self.create<mlir::arith::UIToFPOp>(loc, unwrap(dstType), unwrap(src))
));
}) })
.def("create_fp_to_si", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue { .def("create_fp_to_si", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value( return self.create<mlir::arith::FPToSIOp>(loc, unwrap(dstType), src);
self.create<mlir::arith::FPToSIOp>(loc, unwrap(dstType), unwrap(src))
));
}) })
.def("create_fp_to_ui", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue { .def("create_fp_to_ui", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value( return self.create<mlir::arith::FPToUIOp>(loc, unwrap(dstType), src);
self.create<mlir::arith::FPToUIOp>(loc, unwrap(dstType), unwrap(src))
));
}) })
.def("create_fp_ext", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue { .def("create_fp_ext", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value( return self.create<mlir::arith::ExtFOp>(loc, unwrap(dstType), src);
self.create<mlir::arith::ExtFOp>(loc, unwrap(dstType), unwrap(src))
));
}) })
.def("create_fp_trunc", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue { .def("create_fp_trunc", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value( return self.create<mlir::arith::TruncFOp>(loc, unwrap(dstType), src);
self.create<mlir::arith::TruncFOp>(loc, unwrap(dstType), unwrap(src))
));
}) })
// .def("create_int_cast", &ir::builder::create_int_cast) // .def("create_int_cast", &ir::builder::create_int_cast)
// .def("create_downcast", &ir::builder::create_downcast) // .def("create_downcast", &ir::builder::create_downcast)
.def("create_fmul", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_fmul", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value( return self.create<mlir::arith::MulFOp>(loc, lhs, rhs);
self.create<mlir::arith::MulFOp>(loc, unwrap(lhs), unwrap(rhs))
));
}) })
.def("create_fdiv", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_fdiv", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value( return self.create<mlir::arith::DivFOp>(loc, lhs, rhs);
self.create<mlir::arith::DivFOp>(loc, unwrap(lhs), unwrap(rhs))
));
}) })
.def("create_frem", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_frem", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value( return self.create<mlir::arith::RemFOp>(loc, lhs, rhs);
self.create<mlir::arith::RemFOp>(loc, unwrap(lhs), unwrap(rhs))
));
}) })
.def("create_fadd", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_fadd", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value( return self.create<mlir::arith::AddFOp>(loc, lhs, rhs);
self.create<mlir::arith::AddFOp>(loc, unwrap(lhs), unwrap(rhs))
));
}) })
.def("create_fsub", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_fsub", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value( return self.create<mlir::arith::SubFOp>(loc, lhs, rhs);
self.create<mlir::arith::SubFOp>(loc, unwrap(lhs), unwrap(rhs))
));
}) })
.def("create_mul", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_mul", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
// Check lhs & rhs have single result (?) return self.create<mlir::arith::MulIOp>(loc, lhs, rhs);
return wrap(
mlir::Value(self.create<mlir::arith::MulIOp>(loc, unwrap(lhs), unwrap(rhs)))
);
}) })
.def("create_sdiv", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_sdiv", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap( return self.create<mlir::arith::DivSIOp>(loc, lhs, rhs);
mlir::Value(self.create<mlir::arith::DivSIOp>(loc, unwrap(lhs), unwrap(rhs)))
);
}) })
.def("create_udiv", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_udiv", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap( return self.create<mlir::arith::DivUIOp>(loc, lhs, rhs);
mlir::Value(self.create<mlir::arith::DivUIOp>(loc, unwrap(lhs), unwrap(rhs)))
);
}) })
.def("create_srem", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_srem", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap( return self.create<mlir::arith::RemSIOp>(loc, lhs, rhs);
mlir::Value(self.create<mlir::arith::RemSIOp>(loc, unwrap(lhs), unwrap(rhs)))
);
}) })
.def("create_urem", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_urem", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap( return self.create<mlir::arith::RemUIOp>(loc, lhs, rhs);
mlir::Value(self.create<mlir::arith::RemUIOp>(loc, unwrap(lhs), unwrap(rhs)))
);
}) })
.def("create_add", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_add", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap( return self.create<mlir::arith::AddIOp>(loc, lhs, rhs);
mlir::Value(self.create<mlir::arith::AddIOp>(loc, unwrap(lhs), unwrap(rhs)))
);
}) })
.def("create_sub", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_sub", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap( return mlir::Value(self.create<mlir::arith::SubIOp>(loc, lhs, rhs));
mlir::Value(self.create<mlir::arith::SubIOp>(loc, unwrap(lhs), unwrap(rhs)))
);
}) })
.def("create_shl", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_shl", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap( return mlir::Value(self.create<mlir::arith::ShLIOp>(loc, lhs, rhs));
mlir::Value(self.create<mlir::arith::ShLIOp>(loc, unwrap(lhs), unwrap(rhs)))
);
}) })
.def("create_lshr", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_lshr", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap( return mlir::Value(self.create<mlir::arith::ShRUIOp>(loc, lhs, rhs));
mlir::Value(self.create<mlir::arith::ShRUIOp>(loc, unwrap(lhs), unwrap(rhs)))
);
}) })
.def("create_ashr", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_ashr", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap( return mlir::Value(self.create<mlir::arith::ShRSIOp>(loc, lhs, rhs));
mlir::Value(self.create<mlir::arith::ShRSIOp>(loc, unwrap(lhs), unwrap(rhs)))
);
}) })
// GEP // GEP
.def("create_gep", [](mlir::OpBuilder &self, MlirValue &ptr, MlirValue &offset) -> MlirValue { .def("create_gep", [](mlir::OpBuilder &self, mlir::Value &ptr, mlir::Value &offset) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap( return self.create<mlir::triton::GEPOp>(loc, ptr.getType(), ptr, offset);
mlir::Value(self.create<mlir::triton::GEPOp>(loc, unwrap(ptr).getType(), unwrap(ptr), unwrap(offset)))
);
}) })
// Comparison (int) // Comparison (int)
.def("create_icmpSLE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_icmpSLE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>( return self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::sle, loc, mlir::arith::CmpIPredicate::sle, lhs, rhs);
unwrap(lhs), unwrap(rhs)
)));
}) })
.def("create_icmpSLT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_icmpSLT", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>( return self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::slt, loc, mlir::arith::CmpIPredicate::slt, lhs, rhs);
unwrap(lhs), unwrap(rhs)
)));
}) })
.def("create_icmpSGE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_icmpSGE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>( return self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::sge, loc, mlir::arith::CmpIPredicate::sge, lhs, rhs);
unwrap(lhs), unwrap(rhs)
)));
}) })
.def("create_icmpSGT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_icmpSGT", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>( return self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::sgt, loc, mlir::arith::CmpIPredicate::sgt, lhs, rhs);
unwrap(lhs), unwrap(rhs)
)));
}) })
.def("create_icmpULE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_icmpULE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>( return self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::ule, loc, mlir::arith::CmpIPredicate::ule, lhs, rhs);
unwrap(lhs), unwrap(rhs)
)));
}) })
.def("create_icmpULT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_icmpULT", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>( return self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::ult, loc, mlir::arith::CmpIPredicate::ult, lhs, rhs);
unwrap(lhs), unwrap(rhs)
)));
}) })
.def("create_icmpUGE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_icmpUGE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>( return self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::uge, loc, mlir::arith::CmpIPredicate::uge, lhs, rhs);
unwrap(lhs), unwrap(rhs)
)));
}) })
.def("create_icmpUGT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_icmpUGT", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>( return self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::ugt, loc, mlir::arith::CmpIPredicate::ugt, lhs, rhs);
unwrap(lhs), unwrap(rhs)
)));
}) })
.def("create_icmpEQ", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_icmpEQ", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>( return self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::eq, loc, mlir::arith::CmpIPredicate::eq, lhs, rhs);
unwrap(lhs), unwrap(rhs)
)));
}) })
.def("create_icmpNE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_icmpNE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>( return self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::ne, loc, mlir::arith::CmpIPredicate::ne, lhs, rhs);
unwrap(lhs), unwrap(rhs)
)));
}) })
// Comparison (float) // Comparison (float)
.def("create_fcmpOLT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_fcmpOLT", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>( return self.create<mlir::arith::CmpFOp>(
loc, mlir::arith::CmpFPredicate::OLT, loc, mlir::arith::CmpFPredicate::OLT, lhs, rhs);
unwrap(lhs), unwrap(rhs)
)));
}) })
.def("create_fcmpOGT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_fcmpOGT", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>( return self.create<mlir::arith::CmpFOp>(
loc, mlir::arith::CmpFPredicate::OGT, loc, mlir::arith::CmpFPredicate::OGT, lhs, rhs);
unwrap(lhs), unwrap(rhs)
)));
}) })
.def("create_fcmpOLE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_fcmpOLE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>( return self.create<mlir::arith::CmpFOp>(
loc, mlir::arith::CmpFPredicate::OLE, loc, mlir::arith::CmpFPredicate::OLE, lhs, rhs);
unwrap(lhs), unwrap(rhs)
)));
}) })
.def("create_fcmpOGE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_fcmpOGE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>( return self.create<mlir::arith::CmpFOp>(
loc, mlir::arith::CmpFPredicate::OGE, loc, mlir::arith::CmpFPredicate::OGE, lhs, rhs);
unwrap(lhs), unwrap(rhs)
)));
}) })
.def("create_fcmpOEQ", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_fcmpOEQ", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>( return self.create<mlir::arith::CmpFOp>(
loc, mlir::arith::CmpFPredicate::OEQ, loc, mlir::arith::CmpFPredicate::OEQ, lhs, rhs);
unwrap(lhs), unwrap(rhs)
)));
}) })
.def("create_fcmpONE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_fcmpONE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>( return self.create<mlir::arith::CmpFOp>(
loc, mlir::arith::CmpFPredicate::ONE, loc, mlir::arith::CmpFPredicate::ONE, lhs, rhs);
unwrap(lhs), unwrap(rhs)
)));
}) })
.def("create_fcmpULT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_fcmpULT", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>( return self.create<mlir::arith::CmpFOp>(
loc, mlir::arith::CmpFPredicate::ULT, loc, mlir::arith::CmpFPredicate::ULT, lhs, rhs);
unwrap(lhs), unwrap(rhs)
)));
}) })
.def("create_fcmpUGT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_fcmpUGT", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>( return self.create<mlir::arith::CmpFOp>(
loc, mlir::arith::CmpFPredicate::UGT, loc, mlir::arith::CmpFPredicate::UGT, lhs, rhs);
unwrap(lhs), unwrap(rhs)
)));
}) })
.def("create_fcmpULE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_fcmpULE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>( return self.create<mlir::arith::CmpFOp>(
loc, mlir::arith::CmpFPredicate::ULE, loc, mlir::arith::CmpFPredicate::ULE, lhs, rhs);
unwrap(lhs), unwrap(rhs)
)));
}) })
.def("create_fcmpUGE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_fcmpUGE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>( return self.create<mlir::arith::CmpFOp>(
loc, mlir::arith::CmpFPredicate::UGE, loc, mlir::arith::CmpFPredicate::UGE, lhs, rhs);
unwrap(lhs), unwrap(rhs)
)));
}) })
.def("create_fcmpUEQ", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_fcmpUEQ", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>( return self.create<mlir::arith::CmpFOp>(
loc, mlir::arith::CmpFPredicate::UEQ, loc, mlir::arith::CmpFPredicate::UEQ, lhs, rhs);
unwrap(lhs), unwrap(rhs)
)));
}) })
.def("create_fcmpUNE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_fcmpUNE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>( return self.create<mlir::arith::CmpFOp>(
loc, mlir::arith::CmpFPredicate::UNE, loc, mlir::arith::CmpFPredicate::UNE, lhs, rhs);
unwrap(lhs), unwrap(rhs)
)));
}) })
// // Logical // // Logical
.def("create_and", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_and", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::AndIOp>( return self.create<mlir::arith::AndIOp>(loc, lhs, rhs);
loc, unwrap(lhs), unwrap(rhs)
)));
}) })
.def("create_xor", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_xor", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::XOrIOp>( return self.create<mlir::arith::XOrIOp>(loc, lhs, rhs);
loc, unwrap(lhs), unwrap(rhs)
)));
}) })
.def("create_or", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_or", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::OrIOp>( return self.create<mlir::arith::OrIOp>(loc, lhs, rhs);
loc, unwrap(lhs), unwrap(rhs)
)));
}) })
// // Input/Output // // Input/Output
.def("create_load", [](mlir::OpBuilder &self, MlirValue &ptrs) -> MlirValue { .def("create_load", [](mlir::OpBuilder &self, mlir::Value &ptrs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value( return self.create<mlir::triton::LoadOp>(loc, ptrs);
self.create<mlir::triton::LoadOp>(loc, unwrap(ptrs))
));
}) })
.def("create_store", [](mlir::OpBuilder &self, MlirValue &ptrs, MlirValue &value) -> void { .def("create_store", [](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &value) -> void {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
self.create<mlir::triton::StoreOp>(loc, unwrap(ptrs), unwrap(value)); self.create<mlir::triton::StoreOp>(loc, ptrs, value);
}) })
.def("create_masked_load", [](mlir::OpBuilder &self, MlirValue &ptrs, MlirValue &mask, MlirValue &other) -> MlirValue { .def("create_masked_load", [](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &mask, mlir::Value &other,
mlir::triton::CacheModifier cacheModifier,
mlir::triton::EvictionPolicy evictionPolicy,
bool isVolatile) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
auto ptrType = unwrap(ptrs).getType().dyn_cast<mlir::RankedTensorType>(); auto ptrType = ptrs.getType().dyn_cast<mlir::RankedTensorType>();
std::vector<int64_t> shape = ptrType.getShape(); std::vector<int64_t> shape = ptrType.getShape();
mlir::Type elementType = ptrType.getElementType().dyn_cast<mlir::triton::PointerType>().getPointeeType(); mlir::Type elementType = ptrType.getElementType().dyn_cast<mlir::triton::PointerType>().getPointeeType();
return wrap(mlir::Value(self.create<mlir::triton::LoadOp>( return self.create<mlir::triton::LoadOp>(
loc, mlir::RankedTensorType::get(shape, elementType), unwrap(ptrs), unwrap(mask), unwrap(other)) loc, mlir::RankedTensorType::get(shape, elementType), ptrs, mask, other,
)); cacheModifier, evictionPolicy, isVolatile);
}) })
.def("create_masked_store", [](mlir::OpBuilder &self, MlirValue &ptrs, MlirValue &val, MlirValue &mask) -> void { .def("create_masked_store", [](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &val, mlir::Value &mask) -> void {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
self.create<mlir::triton::StoreOp>(loc, unwrap(ptrs), unwrap(val), unwrap(mask)); self.create<mlir::triton::StoreOp>(loc, ptrs, val, mask);
}) })
// Block instruction // Block instruction
.def("create_reshape", [](mlir::OpBuilder &self, MlirValue &arg, std::vector<int64_t> &shape) -> MlirValue { .def("create_reshape", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector<int64_t> &shape) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
auto argType = unwrap(arg).getType().dyn_cast<mlir::RankedTensorType>(); auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>().getElementType();
return wrap(mlir::Value(self.create<mlir::triton::ReshapeOp>( return self.create<mlir::triton::ReshapeOp>(
loc, mlir::RankedTensorType::get(shape, argType), unwrap(arg), self.getI64ArrayAttr(shape) loc, mlir::RankedTensorType::get(shape, argType), arg, self.getI64ArrayAttr(shape)
))); );
}) })
.def("create_cat", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { .def("create_cat", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
auto lhsType = unwrap(lhs).getType().dyn_cast<mlir::RankedTensorType>(); auto lhsType = lhs.getType().dyn_cast<mlir::RankedTensorType>();
auto rhsType = unwrap(rhs).getType().dyn_cast<mlir::RankedTensorType>(); auto rhsType = rhs.getType().dyn_cast<mlir::RankedTensorType>();
if (!(lhsType.getShape().size() == 1 && rhsType.getShape().size() == 1)) if (!(lhsType.getShape().size() == 1 && rhsType.getShape().size() == 1))
throw std::runtime_error("shape not supported by cat. Expecting rank-1 inputs"); throw std::runtime_error("shape not supported by cat. Expecting rank-1 inputs");
std::vector<int64_t> shape {lhsType.getShape()[0] + rhsType.getShape()[0]}; std::vector<int64_t> shape {lhsType.getShape()[0] + rhsType.getShape()[0]};
return wrap(mlir::Value(self.create<mlir::triton::CatOp>( return self.create<mlir::triton::CatOp>(
loc, mlir::RankedTensorType::get(shape, lhsType.getElementType()), unwrap(lhs), unwrap(rhs) loc, mlir::RankedTensorType::get(shape, lhsType.getElementType()), lhs, rhs
))); );
}) })
.def("create_broadcast", [](mlir::OpBuilder &self, MlirValue &arg, std::vector<int64_t> &shape) -> MlirValue { .def("create_broadcast", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector<int64_t> &shape) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
auto argType = unwrap(arg).getType(); // TODO: should be scalar type here
return wrap(mlir::Value(self.create<mlir::triton::BroadcastOp>( auto argType = arg.getType();
loc, mlir::RankedTensorType::get(shape, argType), unwrap(arg) return self.create<mlir::triton::BroadcastOp>(
))); loc, mlir::RankedTensorType::get(shape, argType), arg
);
})
.def("create_splat", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector<int64_t> &shape) -> mlir::Value {
auto loc = self.getUnknownLoc();
auto argType = arg.getType();
return self.create<mlir::triton::BroadcastOp>(
loc, mlir::RankedTensorType::get(shape, argType), arg
);
}) })
// // atomic // // atomic
.def("create_atomic_cas", [](mlir::OpBuilder &self, MlirValue &ptr, .def("create_atomic_cas", [](mlir::OpBuilder &self, mlir::Value &ptr,
MlirValue &cmp, MlirValue &val) -> MlirValue { mlir::Value &cmp, mlir::Value &val) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
auto ptrType = unwrap(ptr).getType().dyn_cast<mlir::triton::PointerType>(); auto ptrType = ptr.getType().dyn_cast<mlir::triton::PointerType>();
mlir::Type dstType = ptrType.getPointeeType(); mlir::Type dstType = ptrType.getPointeeType();
return wrap(mlir::Value(self.create<mlir::triton::AtomicCASOp>( return self.create<mlir::triton::AtomicCASOp>(
loc, dstType, unwrap(ptr), unwrap(cmp), unwrap(val) loc, dstType, ptr, cmp, val
))); );
}) })
.def("create_atomic_rmw", [](mlir::OpBuilder &self, mlir::triton::RMWOp rmwOp, .def("create_atomic_rmw", [](mlir::OpBuilder &self, mlir::triton::RMWOp rmwOp,
MlirValue &ptr, MlirValue &val, MlirValue &mask) -> MlirValue { mlir::Value &ptr, mlir::Value &val, mlir::Value &mask) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
auto ptrType = unwrap(ptr).getType().dyn_cast<mlir::triton::PointerType>(); auto ptrType = ptr.getType().dyn_cast<mlir::triton::PointerType>();
mlir::Type dstType = ptrType.getPointeeType(); mlir::Type dstType = ptrType.getPointeeType();
return wrap(mlir::Value(self.create<mlir::triton::AtomicRMWOp>( return self.create<mlir::triton::AtomicRMWOp>(
loc, dstType, rmwOp, unwrap(ptr), unwrap(val), unwrap(mask) loc, dstType, rmwOp, ptr, val, mask
))); );
}) })
// Built-in instruction // Built-in instruction
.def("create_get_program_id", [](mlir::OpBuilder &self, int axis) -> MlirValue { .def("create_get_program_id", [](mlir::OpBuilder &self, int axis) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::triton::GetProgramIdOp>( return self.create<mlir::triton::GetProgramIdOp>(
loc, self.getI32Type(), self.getI32IntegerAttr(axis) loc, self.getI32Type(), self.getI32IntegerAttr(axis)
))); );
}) })
.def("create_get_num_programs", [](mlir::OpBuilder &self, int axis) -> MlirValue { .def("create_get_num_programs", [](mlir::OpBuilder &self, int axis) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::triton::GetNumProgramsOp>( return self.create<mlir::triton::GetNumProgramsOp>(
loc, self.getI32Type(), self.getI32IntegerAttr(axis) loc, self.getI32Type(), self.getI32IntegerAttr(axis)
))); );
}) })
.def("create_dot", [](mlir::OpBuilder &self, MlirValue &a, MlirValue &b, MlirValue &c) -> MlirValue { .def("create_dot", [](mlir::OpBuilder &self, mlir::Value &a, mlir::Value &b, mlir::Value &c) -> mlir::Value {
auto loc = self.getUnknownLoc(); auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::triton::DotOp>( return self.create<mlir::triton::DotOp>(loc, c.getType(), a, b, c);
loc, unwrap(c).getType(), unwrap(a), unwrap(b), unwrap(c)
)));
}) })
// .def("create_exp", &ir::builder::create_exp, ret::reference) // .def("create_exp", &ir::builder::create_exp, ret::reference)
// .def("create_cos", &ir::builder::create_cos, ret::reference) // .def("create_cos", &ir::builder::create_cos, ret::reference)

View File

@@ -67,9 +67,10 @@ class CodeGenerator(ast.NodeVisitor):
elif name in self.builtins: elif name in self.builtins:
ret = self.builtins[name] ret = self.builtins[name]
else: else:
print(self.lscope)
raise ValueError(f'{name} is not defined') raise ValueError(f'{name} is not defined')
if self.is_triton_tensor(ret): if self.is_triton_tensor(ret):
return self._get_tensor(name) return self._get_tensor(name, self.builder.get_insertion_block())
return ret return ret
def set_value(self, name: str, def set_value(self, name: str,
@@ -86,12 +87,15 @@ class CodeGenerator(ast.NodeVisitor):
# #
# SSA-construction # SSA-construction
# #
def _get_tensor(self, name: str, bb: Optional[_triton.ir.basic_block] = None) -> triton.language.tensor: def _get_tensor(self, name: str, bb: _triton.ir.basic_block) -> triton.language.tensor:
if not bb: if not bb:
bb = self.builder.get_insertion_block() bb = self.builder.get_insertion_block()
# local value numbering # local value numbering
if (name, bb) in self.lvalues: if (name, bb) in self.lvalues:
return self.lvalues[(name, bb)] return self.lvalues[(name, bb)]
# param. FIXME: should delete this
if (name, None) in self.lvalues:
return self.lvalues[(name, None)]
print(self.lvalues) print(self.lvalues)
assert False, f'Cannot find {name} in {bb}' assert False, f'Cannot find {name} in {bb}'
# global value numbering # global value numbering
@@ -217,10 +221,15 @@ class CodeGenerator(ast.NodeVisitor):
self.lscope[kwarg_names] = self.kwargs self.lscope[kwarg_names] = self.kwargs
# initialize function # initialize function
if inline: if inline:
pass for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value)
self.visit_compound_statement(node.body)
return self.last_ret
else: else:
fn = self.builder.create_function(node.name, self.prototype.to_ir(self.builder)) fn = self.builder.create_function(node.name, self.prototype.to_ir(self.builder))
self.module.push_back(fn) self.module.push_back(fn)
entry = fn.add_entry_block()
self._seal_block(entry)
arg_values = [] arg_values = []
idx = 0 idx = 0
for i, arg_name in enumerate(arg_names): for i, arg_name in enumerate(arg_names):
@@ -239,17 +248,11 @@ class CodeGenerator(ast.NodeVisitor):
# attr = _triton.ir.attribute(attr, self.attributes[i]) # attr = _triton.ir.attribute(attr, self.attributes[i])
# fn.add_attr(idx + 1, attr) # fn.add_attr(idx + 1, attr)
# fn.args[idx].name = arg_name # fn.args[idx].name = arg_name
# arg_values.append(triton.language.tensor(fn.args[idx], self.prototype.param_types[idx])) arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx]))
# idx += 1 idx += 1
for arg_name, arg_value in zip(arg_names, arg_values): for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value) self.set_value(arg_name, arg_value)
if inline:
self.visit_compound_statement(node.body)
return self.last_ret
else:
entry = fn.add_entry_block()
self._seal_block(entry)
self.builder.set_insertion_point_to_start(entry) self.builder.set_insertion_point_to_start(entry)
# visit function body # visit function body
self.visit_compound_statement(node.body) self.visit_compound_statement(node.body)
@@ -821,50 +824,6 @@ class Kernel:
return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names, device, stream, return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names, device, stream,
self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid) self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid)
# Compile to ttir, for the propose of testing MLIR rewriting
def compile_to_ttir(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
# TODO: share code with _compile & __call__
# preparing args
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
# attributes
attributes = dict()
for i, arg in enumerate(wargs):
if i in self.fn.do_not_specialize:
continue
if isinstance(arg, int):
attributes[i] = Kernel.pow2_divisor(arg)
elif i in tensor_idxs:
addr = arg.data_ptr()
range_size = _triton.runtime.get_pointer_range_size(addr)
attributes[i] = min(Kernel.pow2_divisor(addr),
Kernel.pow2_divisor(range_size))
# transforms ints whose value is one into constants for just-in-time compilation
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize}
constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
constants.update({i: None for i, arg in enumerate(wargs) if arg is None})
arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants]
# create IR module
context = _triton.ir.context()
context.load_triton()
# get just-in-time proto-type of kernel
arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types]
ret_type = triton.language.void
prototype = triton.language.function_type([ret_type], arg_types)
# generate Triton-IR
# export symbols visible from self into code-generator object
gscope = self.__globals__
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
try:
generator.visit(self.parse())
except Exception as e:
node = generator.last_node
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
raise e
raise CompilationError(self.src, node) from e
return generator.module
class Launcher: class Launcher:
def __init__(self, kernel, grid): def __init__(self, kernel, grid):
@@ -1209,6 +1168,53 @@ class JITFunction:
raise OutOfResources(shared_mem, max_shared_memory, "shared memory") raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
return Binary(backend, name, asm, shared_mem, num_warps) return Binary(backend, name, asm, shared_mem, num_warps)
# Compile to ttir, for the propose of testing MLIR rewriting
def compile_to_ttir(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
# TODO: share code with _compile & __call__
# preparing args
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
# attributes
attributes = dict()
for i, arg in enumerate(wargs):
if isinstance(arg, int):
attributes[i] = Kernel.pow2_divisor(arg)
elif i in tensor_idxs:
addr = arg.data_ptr()
range_size = _triton.runtime.get_pointer_range_size(addr)
attributes[i] = min(Kernel.pow2_divisor(addr),
Kernel.pow2_divisor(range_size))
# transforms ints whose value is one into constants for just-in-time compilation
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize}
constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
constants.update({i: None for i, arg in enumerate(wargs) if arg is None})
arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants]
print(f'wargs: {wargs}')
print(f'constants: {constants}')
print(f'arg_types: {arg_types}')
# create IR module
context = _triton.ir.context()
context.load_triton()
# get just-in-time proto-type of kernel
arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types]
ret_type = triton.language.void
prototype = triton.language.function_type([ret_type], arg_types)
# generate Triton-IR
# export symbols visible from self into code-generator object
gscope = self.__globals__
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
try:
generator.visit(self.parse())
except Exception as e:
node = generator.last_node
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
raise e
raise CompilationError(self.src, node) from e
# FIXME: now we need to return context, otherwise it will be deleted
return generator.module, context
def __getitem__(self, grid): def __getitem__(self, grid):
return Launcher(self._init_kernel(), grid) return Launcher(self._init_kernel(), grid)

View File

@@ -280,7 +280,7 @@ class function_type(dtype):
self.param_types = param_types self.param_types = param_types
def __str__(self): def __str__(self):
return f'fn ({self.param_types}) -> {self.ret_type}' return f'fn ({self.param_types}) -> {self.ret_types}'
def to_ir(self, builder: ir.builder): def to_ir(self, builder: ir.builder):
ir_param_types = [ty.to_ir(builder) for ty in self.param_types] ir_param_types = [ty.to_ir(builder) for ty in self.param_types]

View File

@@ -121,7 +121,7 @@ def add(input: tl.tensor,
if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr(): if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr():
input, other = other, input input, other = other, input
if input_scalar_ty.is_ptr(): if input_scalar_ty.is_ptr():
return tl.tensor(builder.create_gep(input.handle, [other.handle]), input.type) return tl.tensor(builder.create_gep(input.handle, other.handle), input.type)
# float + float # float + float
elif input_scalar_ty.is_floating(): elif input_scalar_ty.is_floating():
return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type) return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type)
@@ -138,7 +138,7 @@ def sub(input: tl.tensor,
scalar_ty = input.type.scalar scalar_ty = input.type.scalar
# ptr - offset # ptr - offset
if scalar_ty.is_ptr(): if scalar_ty.is_ptr():
return tl.tensor(builder.create_gep(input.handle, [minus(other, builder).handle]), return tl.tensor(builder.create_gep(input.handle, minus(other, builder).handle),
input.type) input.type)
# float - float # float - float
if scalar_ty.is_floating(): if scalar_ty.is_floating():
@@ -438,7 +438,7 @@ def not_equal(input: tl.tensor,
def arange(start: int, end: int, builder: ir.builder) -> tl.tensor: def arange(start: int, end: int, builder: ir.builder) -> tl.tensor:
shape = [end - start] shape = [end - start]
ret_ty = tl.block_type(tl.int32, shape) ret_ty = tl.block_type(tl.int32, shape)
return tl.tensor(builder.get_range(start, end), ret_ty) return tl.tensor(builder.create_make_range(start, end), ret_ty)
def zeros(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor: def zeros(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor: