[Analysis] Added Axis Info Analysis (#8)
This commit is contained in:
@@ -788,6 +788,11 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("add_entry_block", [](mlir::FuncOp &self) -> mlir::Block* {
|
||||
return self.addEntryBlock();
|
||||
}, ret::reference)
|
||||
.def("set_arg_attr", [](mlir::FuncOp &self, int arg_no, const std::string& name, int val){
|
||||
// set arg attributes "name" to value "val"
|
||||
auto attrTy = mlir::IntegerType::get(self.getContext(), 32);
|
||||
self.setArgAttr(arg_no, name, mlir::IntegerAttr::get(attrTy, val));
|
||||
}, ret::reference)
|
||||
.def("reset_type", &mlir::FuncOp::setType)
|
||||
;
|
||||
|
||||
@@ -1265,9 +1270,10 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("create_splat", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector<int64_t> &shape) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
auto argType = arg.getType();
|
||||
return self.create<mlir::triton::BroadcastOp>(
|
||||
auto ret = self.createOrFold<mlir::triton::SplatOp>(
|
||||
loc, mlir::RankedTensorType::get(shape, argType), arg
|
||||
);
|
||||
return ret;
|
||||
})
|
||||
// // atomic
|
||||
.def("create_atomic_cas", [](mlir::OpBuilder &self, mlir::Value &ptr,
|
||||
@@ -1337,6 +1343,12 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("run", [](mlir::PassManager &self, mlir::ModuleOp &mod) -> bool {
|
||||
return mlir::succeeded(self.run(mod.getOperation()));
|
||||
})
|
||||
.def("add_sccp_pass", [](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createSCCPPass());
|
||||
})
|
||||
.def("add_symbol_dce_pass", [](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createSymbolDCEPass());
|
||||
})
|
||||
.def("add_inliner_pass", [](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createInlinerPass());
|
||||
})
|
||||
|
@@ -199,14 +199,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
arg_values.append(cst)
|
||||
else:
|
||||
pass
|
||||
# TODO: ...
|
||||
# if i in self.attributes:
|
||||
# is_ptr = fn.args[idx].type.is_ptr()
|
||||
# attr = 'aligned' if is_ptr else 'multiple_of'
|
||||
# attr = getattr(_triton.ir.attribute_kind, attr)
|
||||
# attr = _triton.ir.attribute(attr, self.attributes[i])
|
||||
# fn.add_attr(idx + 1, attr)
|
||||
# fn.args[idx].name = arg_name
|
||||
if i in self.attributes:
|
||||
fn.set_arg_attr(idx, "tt.divisibility", self.attributes[i])
|
||||
arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx]))
|
||||
idx += 1
|
||||
|
||||
@@ -1307,8 +1301,13 @@ class JITFunction:
|
||||
raise CompilationError(self.src, node) from e
|
||||
# cache num_warps & num_stages
|
||||
self.num_warps, self.num_stages = num_warps, num_stages
|
||||
# run simple SCCP and DCE here to clean-up the generated IR
|
||||
mod = generator.module
|
||||
pm = _triton.ir.pass_manager(context)
|
||||
pm.add_canonicalizer_pass()
|
||||
pm.run(mod)
|
||||
# FIXME: now we need to return context, otherwise it will be deleted
|
||||
return generator.module, context
|
||||
return mod, context
|
||||
|
||||
def compile_ttir_to_llir(self, mod, ctx):
|
||||
num_warps, num_stages = self.num_warps, self.num_stages
|
||||
|
Reference in New Issue
Block a user