Introducing SCF
This commit is contained in:
@@ -6,6 +6,7 @@
|
|||||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
// #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
// #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
||||||
|
|
||||||
#include "triton/ir/Dialect.h.inc"
|
#include "triton/ir/Dialect.h.inc"
|
||||||
|
@@ -17,4 +17,5 @@ add_mlir_dialect_library(TritonIR
|
|||||||
MLIRStandard
|
MLIRStandard
|
||||||
|
|
||||||
MLIRTensor
|
MLIRTensor
|
||||||
|
MLIRSCF
|
||||||
)
|
)
|
||||||
|
@@ -834,13 +834,17 @@ void init_triton_ir(py::module &&m) {
|
|||||||
}
|
}
|
||||||
throw std::runtime_error("invalid function type");
|
throw std::runtime_error("invalid function type");
|
||||||
})
|
})
|
||||||
// // Structured control flow
|
// Structured control flow
|
||||||
// .def("create_scf_for", [](mlir::OpBuilder &self) {
|
.def("create_for", [](mlir::OpBuilder &self, MlirValue &lb, MlirValue &ub,
|
||||||
// return self.create<mlir::scf::ForOp>(/*fill this*/);
|
MlirValue &step) {
|
||||||
// })
|
auto loc = self.getUnknownLoc();
|
||||||
// .def("create_scf_yield")
|
return wrap(
|
||||||
// .def("create_scf_if")
|
self.create<mlir::scf::ForOp>(loc, unwrap(lb), unwrap(ub), unwrap(step))
|
||||||
// .def("create_scf_while")
|
);
|
||||||
|
})
|
||||||
|
// .def("create_yield")
|
||||||
|
// .def("create_if")
|
||||||
|
// .def("create_while")
|
||||||
|
|
||||||
// miscellious
|
// miscellious
|
||||||
.def("create_make_range", [](mlir::OpBuilder &self, int start, int end) -> MlirValue {
|
.def("create_make_range", [](mlir::OpBuilder &self, int start, int end) -> MlirValue {
|
||||||
|
Reference in New Issue
Block a user