[BACKEND] Make flash attention forward pass work (#928)
This also simplifies BroadcastOp codegen
This commit is contained in:
@@ -11,6 +11,7 @@
|
||||
#include "mlir/Parser.h"
|
||||
#include "mlir/Support/FileUtilities.h"
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "triton/Analysis/Allocation.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
@@ -115,6 +116,10 @@ void init_triton_ir(py::module &&m) {
|
||||
.def(py::init<>())
|
||||
.def("load_triton", [](mlir::MLIRContext &self) {
|
||||
self.getOrLoadDialect<mlir::triton::TritonDialect>();
|
||||
// we load LLVM because the frontend uses LLVM.undef for
|
||||
// some placeholders
|
||||
self.getOrLoadDialect<mlir::triton::TritonDialect>();
|
||||
self.getOrLoadDialect<mlir::LLVM::LLVMDialect>();
|
||||
});
|
||||
// .def(py::init([](){
|
||||
// mlir::MLIRContext context;
|
||||
@@ -350,6 +355,7 @@ void init_triton_ir(py::module &&m) {
|
||||
"parse_mlir_module",
|
||||
[](const std::string &inputFilename, mlir::MLIRContext &context) {
|
||||
// initialize registry
|
||||
// note: we initialize llvm for undef
|
||||
mlir::DialectRegistry registry;
|
||||
registry.insert<mlir::triton::TritonDialect,
|
||||
mlir::triton::gpu::TritonGPUDialect,
|
||||
@@ -1243,7 +1249,14 @@ void init_triton_ir(py::module &&m) {
|
||||
mlir::StringAttr::get(self.getContext(),
|
||||
llvm::StringRef(prefix)),
|
||||
values);
|
||||
});
|
||||
})
|
||||
// Undef
|
||||
.def("create_undef",
|
||||
[](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<::mlir::LLVM::UndefOp>(loc, type);
|
||||
})
|
||||
;
|
||||
|
||||
py::class_<mlir::PassManager>(m, "pass_manager")
|
||||
.def(py::init<mlir::MLIRContext *>())
|
||||
|
Reference in New Issue
Block a user