[FRONTEND] Added ExpandDimsOp primitive (#36)

This commit is contained in:
Philippe Tillet
2022-08-04 18:41:06 -07:00
committed by GitHub
parent a7b49b3227
commit 78ebbe24c7
8 changed files with 98 additions and 41 deletions

View File

@@ -1471,7 +1471,6 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
self.create<mlir::triton::StoreOp>(loc, ptrs, val, mask);
})
// Block instruction
.def("create_view",
[](mlir::OpBuilder &self, mlir::Value &arg,
std::vector<int64_t> &shape) -> mlir::Value {
@@ -1482,6 +1481,18 @@ void init_triton_ir(py::module &&m) {
return self.create<mlir::triton::ViewOp>(
loc, mlir::RankedTensorType::get(shape, argType), arg);
})
.def(
"create_expand_dims",
[](mlir::OpBuilder &self, mlir::Value &arg, int axis) -> mlir::Value {
auto loc = self.getUnknownLoc();
auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>();
auto argEltType = argType.getElementType();
std::vector<int64_t> retShape = argType.getShape();
retShape.insert(retShape.begin() + axis, 1);
return self.create<mlir::triton::ExpandDimsOp>(
loc, mlir::RankedTensorType::get(retShape, argEltType), arg,
axis);
})
.def("create_cat",
[](mlir::OpBuilder &self, mlir::Value &lhs,
mlir::Value &rhs) -> mlir::Value {