[FRONTEND] Added ExpandDimsOp
primitive (#36)
This commit is contained in:
@@ -1471,7 +1471,6 @@ void init_triton_ir(py::module &&m) {
|
||||
auto loc = self.getUnknownLoc();
|
||||
self.create<mlir::triton::StoreOp>(loc, ptrs, val, mask);
|
||||
})
|
||||
// Block instruction
|
||||
.def("create_view",
|
||||
[](mlir::OpBuilder &self, mlir::Value &arg,
|
||||
std::vector<int64_t> &shape) -> mlir::Value {
|
||||
@@ -1482,6 +1481,18 @@ void init_triton_ir(py::module &&m) {
|
||||
return self.create<mlir::triton::ViewOp>(
|
||||
loc, mlir::RankedTensorType::get(shape, argType), arg);
|
||||
})
|
||||
.def(
|
||||
"create_expand_dims",
|
||||
[](mlir::OpBuilder &self, mlir::Value &arg, int axis) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>();
|
||||
auto argEltType = argType.getElementType();
|
||||
std::vector<int64_t> retShape = argType.getShape();
|
||||
retShape.insert(retShape.begin() + axis, 1);
|
||||
return self.create<mlir::triton::ExpandDimsOp>(
|
||||
loc, mlir::RankedTensorType::get(retShape, argEltType), arg,
|
||||
axis);
|
||||
})
|
||||
.def("create_cat",
|
||||
[](mlir::OpBuilder &self, mlir::Value &lhs,
|
||||
mlir::Value &rhs) -> mlir::Value {
|
||||
|
Reference in New Issue
Block a user