Add set_attr(...) to ir.OpState
This commit is contained in:
@@ -647,6 +647,13 @@ void init_triton_ir(py::module &&m) {
|
||||
;
|
||||
|
||||
py::class_<mlir::Value>(m, "value")
|
||||
.def("set_attr", [](mlir::Value &self, std::string &name, mlir::Attribute &attr) -> void {
|
||||
if (mlir::Operation *definingOp = self.getDefiningOp())
|
||||
definingOp->setAttr(name, attr);
|
||||
else {
|
||||
/* issue an warning */
|
||||
}
|
||||
})
|
||||
;
|
||||
py::class_<mlir::BlockArgument, mlir::Value>(m, "block_arguement")
|
||||
;
|
||||
@@ -701,11 +708,15 @@ void init_triton_ir(py::module &&m) {
|
||||
// .value("retune", eattr::retune)
|
||||
// .value("not_implemented", eattr::not_implemented);
|
||||
|
||||
// py::class_<mlir::Attribute>(m, "attribute");
|
||||
// // .def(py::init<eattr, int>());
|
||||
py::class_<mlir::Attribute>(m, "attribute");
|
||||
py::class_<mlir::IntegerAttr, mlir::Attribute>(m, "integer_attr");
|
||||
py::class_<mlir::BoolAttr, mlir::Attribute>(m, "bool_attr");
|
||||
|
||||
// Ops
|
||||
py::class_<mlir::OpState>(m, "OpState")
|
||||
.def("set_attr", [](mlir::OpState &self, std::string &name, mlir::Attribute &attr) -> void {
|
||||
self->setAttr(name, attr);
|
||||
})
|
||||
.def("get_result", [](mlir::OpState &self, unsigned idx) -> mlir::Value {
|
||||
return self->getResult(idx);
|
||||
})
|
||||
@@ -744,7 +755,6 @@ void init_triton_ir(py::module &&m) {
|
||||
py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "CondtionOp");
|
||||
|
||||
py::class_<mlir::ModuleOp, mlir::OpState>(m, "module")
|
||||
// .def("set_attr")
|
||||
.def("dump", [](mlir::ModuleOp &self) -> void {
|
||||
self.dump();
|
||||
})
|
||||
@@ -774,9 +784,7 @@ void init_triton_ir(py::module &&m) {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::ModuleOp>(loc);
|
||||
})
|
||||
// // control flow
|
||||
// .def("br", &ir::builder::create_br, ret::reference)
|
||||
// .def("cond_br", &ir::builder::create_cond_br, ret::reference)
|
||||
// control flow
|
||||
.def("ret_void", [](mlir::OpBuilder &self) {
|
||||
auto loc = self.getUnknownLoc();
|
||||
self.create<mlir::ReturnOp>(loc);
|
||||
@@ -805,6 +813,9 @@ void init_triton_ir(py::module &&m) {
|
||||
// self->set_insert_point(bb);
|
||||
// }
|
||||
// })
|
||||
// Attr
|
||||
.def("get_bool_attr", &mlir::OpBuilder::getBoolAttr)
|
||||
.def("get_int32_attr", &mlir::OpBuilder::getI32IntegerAttr)
|
||||
// Use arith.ConstantOp to create constants
|
||||
// // Constants
|
||||
// .def("get_int1", &ir::builder::get_int1, ret::reference)
|
||||
@@ -1215,10 +1226,11 @@ void init_triton_ir(py::module &&m) {
|
||||
})
|
||||
.def("create_broadcast", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector<int64_t> &shape) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>().getElementType();
|
||||
return self.create<mlir::triton::BroadcastOp>(
|
||||
loc, mlir::RankedTensorType::get(shape, argType), arg
|
||||
);
|
||||
if (auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>())
|
||||
return self.create<mlir::triton::BroadcastOp>(
|
||||
loc, mlir::RankedTensorType::get(shape, argType.getElementType()), arg
|
||||
);
|
||||
throw std::runtime_error("arg is not of RankedTensorType, use create_splat");
|
||||
})
|
||||
.def("create_splat", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector<int64_t> &shape) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
|
@@ -23,6 +23,7 @@ func_ty = builder.get_function_ty([f16_ptr_ty, f16_ptr_ty, f16_ptr_ty], [])
|
||||
func = builder.create_function('foo', func_ty)
|
||||
|
||||
module.push_back(func)
|
||||
module.set_attr("num_warps", builder.get_int32_attr(4))
|
||||
|
||||
# ...
|
||||
entry = func.add_entry_block()
|
||||
@@ -31,23 +32,23 @@ offsets = builder.create_make_range(0, 128)
|
||||
pid = builder.create_get_program_id(0)
|
||||
_128 = builder.get_int32(128)
|
||||
offset = builder.create_add(pid, _128)
|
||||
offset = builder.create_broadcast(offset, [128])
|
||||
offset = builder.create_splat(offset, [128])
|
||||
offsets = builder.create_add(offset, offsets)
|
||||
|
||||
|
||||
a_ptrs = builder.create_broadcast(entry.arg(0), [128])
|
||||
b_ptrs = builder.create_broadcast(entry.arg(1), [128])
|
||||
a_ptrs = builder.create_splat(entry.arg(0), [128])
|
||||
b_ptrs = builder.create_splat(entry.arg(1), [128])
|
||||
|
||||
a_ptrs = builder.create_gep(a_ptrs, offsets)
|
||||
b_ptrs = builder.create_gep(b_ptrs, offsets)
|
||||
|
||||
a = builder.create_load(a_ptrs)
|
||||
b = builder.create_load(b_ptrs)
|
||||
a = builder.create_load(a_ptrs, ir.CACHE_MODIFIER.NONE, ir.EVICTION_POLICY.NORMAL, False)
|
||||
b = builder.create_load(b_ptrs, ir.CACHE_MODIFIER.NONE, ir.EVICTION_POLICY.NORMAL, False)
|
||||
|
||||
c = builder.create_fadd(a, b)
|
||||
# c.set_attr("ieee_rounding", builder.get_bool_attr(True))
|
||||
c.set_attr("ieee_rounding", builder.get_bool_attr(True))
|
||||
|
||||
c_ptrs = builder.create_broadcast(entry.arg(2), [128])
|
||||
c_ptrs = builder.create_splat(entry.arg(2), [128])
|
||||
c_ptrs = builder.create_gep(c_ptrs, offsets)
|
||||
builder.create_store(c_ptrs, c)
|
||||
|
||||
|
Reference in New Issue
Block a user