[BACKEND] Make flash attention forward pass work (#928)

This also simplifies BroadcastOp codegen
This commit is contained in:
Philippe Tillet
2022-11-30 11:13:24 +01:00
committed by GitHub
parent 4e6a8209ed
commit 6461254fb5
7 changed files with 326 additions and 205 deletions

View File

@@ -11,6 +11,7 @@
#include "mlir/Parser.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
@@ -115,6 +116,10 @@ void init_triton_ir(py::module &&m) {
.def(py::init<>())
.def("load_triton", [](mlir::MLIRContext &self) {
self.getOrLoadDialect<mlir::triton::TritonDialect>();
// we load LLVM because the frontend uses LLVM.undef for
// some placeholders
self.getOrLoadDialect<mlir::triton::TritonDialect>();
self.getOrLoadDialect<mlir::LLVM::LLVMDialect>();
});
// .def(py::init([](){
// mlir::MLIRContext context;
@@ -350,6 +355,7 @@ void init_triton_ir(py::module &&m) {
"parse_mlir_module",
[](const std::string &inputFilename, mlir::MLIRContext &context) {
// initialize registry
// note: we initialize llvm for undef
mlir::DialectRegistry registry;
registry.insert<mlir::triton::TritonDialect,
mlir::triton::gpu::TritonGPUDialect,
@@ -1243,7 +1249,14 @@ void init_triton_ir(py::module &&m) {
mlir::StringAttr::get(self.getContext(),
llvm::StringRef(prefix)),
values);
});
})
// Undef
.def("create_undef",
[](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<::mlir::LLVM::UndefOp>(loc, type);
})
;
py::class_<mlir::PassManager>(m, "pass_manager")
.def(py::init<mlir::MLIRContext *>())