Add set_attr(...) to ir.OpState

This commit is contained in:
Yan Da
2022-04-11 12:26:54 +08:00
parent 4eb062f313
commit 7e0fd97965
2 changed files with 30 additions and 17 deletions

View File

@@ -647,6 +647,13 @@ void init_triton_ir(py::module &&m) {
;
py::class_<mlir::Value>(m, "value")
.def("set_attr", [](mlir::Value &self, std::string &name, mlir::Attribute &attr) -> void {
if (mlir::Operation *definingOp = self.getDefiningOp())
definingOp->setAttr(name, attr);
else {
/* issue an warning */
}
})
;
py::class_<mlir::BlockArgument, mlir::Value>(m, "block_arguement")
;
@@ -701,11 +708,15 @@ void init_triton_ir(py::module &&m) {
// .value("retune", eattr::retune)
// .value("not_implemented", eattr::not_implemented);
// py::class_<mlir::Attribute>(m, "attribute");
// // .def(py::init<eattr, int>());
py::class_<mlir::Attribute>(m, "attribute");
py::class_<mlir::IntegerAttr, mlir::Attribute>(m, "integer_attr");
py::class_<mlir::BoolAttr, mlir::Attribute>(m, "bool_attr");
// Ops
py::class_<mlir::OpState>(m, "OpState")
.def("set_attr", [](mlir::OpState &self, std::string &name, mlir::Attribute &attr) -> void {
self->setAttr(name, attr);
})
.def("get_result", [](mlir::OpState &self, unsigned idx) -> mlir::Value {
return self->getResult(idx);
})
@@ -744,7 +755,6 @@ void init_triton_ir(py::module &&m) {
py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "CondtionOp");
py::class_<mlir::ModuleOp, mlir::OpState>(m, "module")
// .def("set_attr")
.def("dump", [](mlir::ModuleOp &self) -> void {
self.dump();
})
@@ -774,9 +784,7 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
return self.create<mlir::ModuleOp>(loc);
})
// // control flow
// .def("br", &ir::builder::create_br, ret::reference)
// .def("cond_br", &ir::builder::create_cond_br, ret::reference)
// control flow
.def("ret_void", [](mlir::OpBuilder &self) {
auto loc = self.getUnknownLoc();
self.create<mlir::ReturnOp>(loc);
@@ -805,6 +813,9 @@ void init_triton_ir(py::module &&m) {
// self->set_insert_point(bb);
// }
// })
// Attr
.def("get_bool_attr", &mlir::OpBuilder::getBoolAttr)
.def("get_int32_attr", &mlir::OpBuilder::getI32IntegerAttr)
// Use arith.ConstantOp to create constants
// // Constants
// .def("get_int1", &ir::builder::get_int1, ret::reference)
@@ -1215,10 +1226,11 @@ void init_triton_ir(py::module &&m) {
})
.def("create_broadcast", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector<int64_t> &shape) -> mlir::Value {
auto loc = self.getUnknownLoc();
auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>().getElementType();
return self.create<mlir::triton::BroadcastOp>(
loc, mlir::RankedTensorType::get(shape, argType), arg
);
if (auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>())
return self.create<mlir::triton::BroadcastOp>(
loc, mlir::RankedTensorType::get(shape, argType.getElementType()), arg
);
throw std::runtime_error("arg is not of RankedTensorType, use create_splat");
})
.def("create_splat", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector<int64_t> &shape) -> mlir::Value {
auto loc = self.getUnknownLoc();

View File

@@ -23,6 +23,7 @@ func_ty = builder.get_function_ty([f16_ptr_ty, f16_ptr_ty, f16_ptr_ty], [])
func = builder.create_function('foo', func_ty)
module.push_back(func)
module.set_attr("num_warps", builder.get_int32_attr(4))
# ...
entry = func.add_entry_block()
@@ -31,23 +32,23 @@ offsets = builder.create_make_range(0, 128)
pid = builder.create_get_program_id(0)
_128 = builder.get_int32(128)
offset = builder.create_add(pid, _128)
offset = builder.create_broadcast(offset, [128])
offset = builder.create_splat(offset, [128])
offsets = builder.create_add(offset, offsets)
a_ptrs = builder.create_broadcast(entry.arg(0), [128])
b_ptrs = builder.create_broadcast(entry.arg(1), [128])
a_ptrs = builder.create_splat(entry.arg(0), [128])
b_ptrs = builder.create_splat(entry.arg(1), [128])
a_ptrs = builder.create_gep(a_ptrs, offsets)
b_ptrs = builder.create_gep(b_ptrs, offsets)
a = builder.create_load(a_ptrs)
b = builder.create_load(b_ptrs)
a = builder.create_load(a_ptrs, ir.CACHE_MODIFIER.NONE, ir.EVICTION_POLICY.NORMAL, False)
b = builder.create_load(b_ptrs, ir.CACHE_MODIFIER.NONE, ir.EVICTION_POLICY.NORMAL, False)
c = builder.create_fadd(a, b)
# c.set_attr("ieee_rounding", builder.get_bool_attr(True))
c.set_attr("ieee_rounding", builder.get_bool_attr(True))
c_ptrs = builder.create_broadcast(entry.arg(2), [128])
c_ptrs = builder.create_splat(entry.arg(2), [128])
c_ptrs = builder.create_gep(c_ptrs, offsets)
builder.create_store(c_ptrs, c)