From 0d139ec460ad9af5da94112817c52ea5301c86cc Mon Sep 17 00:00:00 2001 From: Yan Da Date: Sat, 26 Mar 2022 17:02:32 +0800 Subject: [PATCH] Introducing SCF --- include/triton/ir/Dialect.h | 1 + lib/ir/CMakeLists.txt | 1 + python/src/triton.cc | 18 +++++++++++------- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/include/triton/ir/Dialect.h b/include/triton/ir/Dialect.h index ffc89f730..f26cccac8 100644 --- a/include/triton/ir/Dialect.h +++ b/include/triton/ir/Dialect.h @@ -6,6 +6,7 @@ #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/SCF/SCF.h" // #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "triton/ir/Dialect.h.inc" diff --git a/lib/ir/CMakeLists.txt b/lib/ir/CMakeLists.txt index 59155ddc5..83c61e047 100644 --- a/lib/ir/CMakeLists.txt +++ b/lib/ir/CMakeLists.txt @@ -17,4 +17,5 @@ add_mlir_dialect_library(TritonIR MLIRStandard MLIRTensor + MLIRSCF ) diff --git a/python/src/triton.cc b/python/src/triton.cc index 2bbd0c67d..19c79f3b8 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -834,13 +834,17 @@ void init_triton_ir(py::module &&m) { } throw std::runtime_error("invalid function type"); }) - // // Structured control flow - // .def("create_scf_for", [](mlir::OpBuilder &self) { - // return self.create(/*fill this*/); - // }) - // .def("create_scf_yield") - // .def("create_scf_if") - // .def("create_scf_while") + // Structured control flow + .def("create_for", [](mlir::OpBuilder &self, MlirValue &lb, MlirValue &ub, + MlirValue &step) { + auto loc = self.getUnknownLoc(); + return wrap( + self.create(loc, unwrap(lb), unwrap(ub), unwrap(step)) + ); + }) + // .def("create_yield") + // .def("create_if") + // .def("create_while") // miscellious .def("create_make_range", [](mlir::OpBuilder &self, int start, int end) -> MlirValue {