diff --git a/python/src/triton.cc b/python/src/triton.cc index 0cb133f9f..ee7b4877a 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -647,6 +647,13 @@ void init_triton_ir(py::module &&m) { ; py::class_(m, "value") + .def("set_attr", [](mlir::Value &self, std::string &name, mlir::Attribute &attr) -> void { + if (mlir::Operation *definingOp = self.getDefiningOp()) + definingOp->setAttr(name, attr); + else { + /* issue an warning */ + } + }) ; py::class_(m, "block_arguement") ; @@ -701,11 +708,15 @@ void init_triton_ir(py::module &&m) { // .value("retune", eattr::retune) // .value("not_implemented", eattr::not_implemented); - // py::class_(m, "attribute"); - // // .def(py::init()); + py::class_(m, "attribute"); + py::class_(m, "integer_attr"); + py::class_(m, "bool_attr"); // Ops py::class_(m, "OpState") + .def("set_attr", [](mlir::OpState &self, std::string &name, mlir::Attribute &attr) -> void { + self->setAttr(name, attr); + }) .def("get_result", [](mlir::OpState &self, unsigned idx) -> mlir::Value { return self->getResult(idx); }) @@ -744,7 +755,6 @@ void init_triton_ir(py::module &&m) { py::class_(m, "CondtionOp"); py::class_(m, "module") - // .def("set_attr") .def("dump", [](mlir::ModuleOp &self) -> void { self.dump(); }) @@ -774,9 +784,7 @@ void init_triton_ir(py::module &&m) { auto loc = self.getUnknownLoc(); return self.create(loc); }) - // // control flow - // .def("br", &ir::builder::create_br, ret::reference) - // .def("cond_br", &ir::builder::create_cond_br, ret::reference) + // control flow .def("ret_void", [](mlir::OpBuilder &self) { auto loc = self.getUnknownLoc(); self.create(loc); @@ -805,6 +813,9 @@ void init_triton_ir(py::module &&m) { // self->set_insert_point(bb); // } // }) + // Attr + .def("get_bool_attr", &mlir::OpBuilder::getBoolAttr) + .def("get_int32_attr", &mlir::OpBuilder::getI32IntegerAttr) // Use arith.ConstantOp to create constants // // Constants // .def("get_int1", &ir::builder::get_int1, ret::reference) @@ -1215,10 +1226,11 @@ void init_triton_ir(py::module &&m) { }) .def("create_broadcast", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector &shape) -> mlir::Value { auto loc = self.getUnknownLoc(); - auto argType = arg.getType().dyn_cast().getElementType(); - return self.create( - loc, mlir::RankedTensorType::get(shape, argType), arg - ); + if (auto argType = arg.getType().dyn_cast()) + return self.create( + loc, mlir::RankedTensorType::get(shape, argType.getElementType()), arg + ); + throw std::runtime_error("arg is not of RankedTensorType, use create_splat"); }) .def("create_splat", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector &shape) -> mlir::Value { auto loc = self.getUnknownLoc(); diff --git a/rewrite-test/test_ir.py b/rewrite-test/test_ir.py index a3077f4d5..3ecdbdfd2 100644 --- a/rewrite-test/test_ir.py +++ b/rewrite-test/test_ir.py @@ -23,6 +23,7 @@ func_ty = builder.get_function_ty([f16_ptr_ty, f16_ptr_ty, f16_ptr_ty], []) func = builder.create_function('foo', func_ty) module.push_back(func) +module.set_attr("num_warps", builder.get_int32_attr(4)) # ... entry = func.add_entry_block() @@ -31,23 +32,23 @@ offsets = builder.create_make_range(0, 128) pid = builder.create_get_program_id(0) _128 = builder.get_int32(128) offset = builder.create_add(pid, _128) -offset = builder.create_broadcast(offset, [128]) +offset = builder.create_splat(offset, [128]) offsets = builder.create_add(offset, offsets) -a_ptrs = builder.create_broadcast(entry.arg(0), [128]) -b_ptrs = builder.create_broadcast(entry.arg(1), [128]) +a_ptrs = builder.create_splat(entry.arg(0), [128]) +b_ptrs = builder.create_splat(entry.arg(1), [128]) a_ptrs = builder.create_gep(a_ptrs, offsets) b_ptrs = builder.create_gep(b_ptrs, offsets) -a = builder.create_load(a_ptrs) -b = builder.create_load(b_ptrs) +a = builder.create_load(a_ptrs, ir.CACHE_MODIFIER.NONE, ir.EVICTION_POLICY.NORMAL, False) +b = builder.create_load(b_ptrs, ir.CACHE_MODIFIER.NONE, ir.EVICTION_POLICY.NORMAL, False) c = builder.create_fadd(a, b) -# c.set_attr("ieee_rounding", builder.get_bool_attr(True)) +c.set_attr("ieee_rounding", builder.get_bool_attr(True)) -c_ptrs = builder.create_broadcast(entry.arg(2), [128]) +c_ptrs = builder.create_splat(entry.arg(2), [128]) c_ptrs = builder.create_gep(c_ptrs, offsets) builder.create_store(c_ptrs, c)