[Analysis] Added Axis Info Analysis (#8)

This commit is contained in:
Philippe Tillet
2022-07-19 13:38:48 -07:00
committed by GitHub
parent df940aaab0
commit a633d2b403
20 changed files with 582 additions and 13 deletions

View File

@@ -788,6 +788,11 @@ void init_triton_ir(py::module &&m) {
.def("add_entry_block", [](mlir::FuncOp &self) -> mlir::Block* {
return self.addEntryBlock();
}, ret::reference)
.def("set_arg_attr", [](mlir::FuncOp &self, int arg_no, const std::string& name, int val){
// set arg attributes "name" to value "val"
auto attrTy = mlir::IntegerType::get(self.getContext(), 32);
self.setArgAttr(arg_no, name, mlir::IntegerAttr::get(attrTy, val));
}, ret::reference)
.def("reset_type", &mlir::FuncOp::setType)
;
@@ -1265,9 +1270,10 @@ void init_triton_ir(py::module &&m) {
.def("create_splat", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector<int64_t> &shape) -> mlir::Value {
auto loc = self.getUnknownLoc();
auto argType = arg.getType();
return self.create<mlir::triton::BroadcastOp>(
auto ret = self.createOrFold<mlir::triton::SplatOp>(
loc, mlir::RankedTensorType::get(shape, argType), arg
);
return ret;
})
// // atomic
.def("create_atomic_cas", [](mlir::OpBuilder &self, mlir::Value &ptr,
@@ -1337,6 +1343,12 @@ void init_triton_ir(py::module &&m) {
.def("run", [](mlir::PassManager &self, mlir::ModuleOp &mod) -> bool {
return mlir::succeeded(self.run(mod.getOperation()));
})
.def("add_sccp_pass", [](mlir::PassManager &self) {
self.addPass(mlir::createSCCPPass());
})
.def("add_symbol_dce_pass", [](mlir::PassManager &self) {
self.addPass(mlir::createSymbolDCEPass());
})
.def("add_inliner_pass", [](mlir::PassManager &self) {
self.addPass(mlir::createInlinerPass());
})

View File

@@ -199,14 +199,8 @@ class CodeGenerator(ast.NodeVisitor):
arg_values.append(cst)
else:
pass
# TODO: ...
# if i in self.attributes:
# is_ptr = fn.args[idx].type.is_ptr()
# attr = 'aligned' if is_ptr else 'multiple_of'
# attr = getattr(_triton.ir.attribute_kind, attr)
# attr = _triton.ir.attribute(attr, self.attributes[i])
# fn.add_attr(idx + 1, attr)
# fn.args[idx].name = arg_name
if i in self.attributes:
fn.set_arg_attr(idx, "tt.divisibility", self.attributes[i])
arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx]))
idx += 1
@@ -1307,8 +1301,13 @@ class JITFunction:
raise CompilationError(self.src, node) from e
# cache num_warps & num_stages
self.num_warps, self.num_stages = num_warps, num_stages
# run simple SCCP and DCE here to clean-up the generated IR
mod = generator.module
pm = _triton.ir.pass_manager(context)
pm.add_canonicalizer_pass()
pm.run(mod)
# FIXME: now we need to return context, otherwise it will be deleted
return generator.module, context
return mod, context
def compile_ttir_to_llir(self, mod, ctx):
num_warps, num_stages = self.num_warps, self.num_stages