[FRONTEND][BACKEND] Added trans instruction; made flash attention bwd pass work (#943)

This commit is contained in:
Philippe Tillet
2022-12-03 09:58:24 -08:00
committed by GitHub
parent 4d64589b22
commit 8edfe813a5
12 changed files with 310 additions and 143 deletions

View File

@@ -1085,6 +1085,16 @@ void init_triton_ir(py::module &&m) {
mlir::RankedTensorType::get(shape, lhsType.getElementType()),
lhs, rhs);
})
// Builds a `triton::TransOp` whose result tensor has the operand's
// dimensions in reverse order and the operand's element type.
.def("create_trans",
     [](mlir::OpBuilder &self, mlir::Value &arg) -> mlir::Value {
       auto loc = self.getUnknownLoc();
       // `dyn_cast` yields a null type when `arg` is not a ranked tensor;
       // raise an exception instead of dereferencing that null type below.
       auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>();
       if (!argType)
         throw py::type_error(
             "create_trans: operand must be a ranked tensor");
       auto argEltType = argType.getElementType();
       // Copy the shape so it can be reversed in place for the result type.
       std::vector<int64_t> retShape = argType.getShape();
       std::reverse(retShape.begin(), retShape.end());
       return self.create<mlir::triton::TransOp>(
           loc, mlir::RankedTensorType::get(retShape, argEltType), arg);
     })
.def("create_broadcast",
[](mlir::OpBuilder &self, mlir::Value &arg,
std::vector<int64_t> &shape) -> mlir::Value {