[FRONTEND] Added ExpandDimsOp primitive (#36)
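This commit wires NumPy-style `None` indexing through all three layers of the MLIR-based frontend: a `create_expand_dims` binding in `init_triton_ir`, `None` handling in `tensor.__getitem__`, and a `semantic.expand_dims` helper. The user-visible effect is dimension insertion on block tensors. A minimal sketch of the intended usage against today's stable Triton API (the kernel and tensor names are illustrative, not part of this commit; at the time of this commit the triton-mlir rewrite was still a work in progress):

import torch
import triton
import triton.language as tl

@triton.jit
def outer_add_kernel(x_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)         # block of shape (BLOCK,)
    x = tl.load(x_ptr + offs)          # (BLOCK,)
    rows = offs[:, None]               # lowers to expand_dims(offs, 1) -> (BLOCK, 1)
    cols = offs[None, :]               # lowers to expand_dims(offs, 0) -> (1, BLOCK)
    out = x[:, None] + x[None, :]      # broadcasts to (BLOCK, BLOCK)
    tl.store(out_ptr + rows * BLOCK + cols, out)

x = torch.randn(64, device="cuda")
out = torch.empty(64, 64, device="cuda")
outer_add_kernel[(1,)](x, out, BLOCK=64)
assert torch.allclose(out, x[:, None] + x[None, :])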
@@ -1471,7 +1471,6 @@ void init_triton_ir(py::module &&m) {
              auto loc = self.getUnknownLoc();
              self.create<mlir::triton::StoreOp>(loc, ptrs, val, mask);
            })
       // Block instruction
       .def("create_view",
            [](mlir::OpBuilder &self, mlir::Value &arg,
               std::vector<int64_t> &shape) -> mlir::Value {
@@ -1482,6 +1481,18 @@ void init_triton_ir(py::module &&m) {
              return self.create<mlir::triton::ViewOp>(
                  loc, mlir::RankedTensorType::get(shape, argType), arg);
            })
+      .def(
+          "create_expand_dims",
+          [](mlir::OpBuilder &self, mlir::Value &arg, int axis) -> mlir::Value {
+            auto loc = self.getUnknownLoc();
+            auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>();
+            auto argEltType = argType.getElementType();
+            std::vector<int64_t> retShape = argType.getShape();
+            retShape.insert(retShape.begin() + axis, 1);
+            return self.create<mlir::triton::ExpandDimsOp>(
+                loc, mlir::RankedTensorType::get(retShape, argEltType), arg,
+                axis);
+          })
       .def("create_cat",
            [](mlir::OpBuilder &self, mlir::Value &lhs,
               mlir::Value &rhs) -> mlir::Value {
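The binding computes the result type on the C++ side: it copies the operand's ranked-tensor shape and inserts a 1 at `axis`, so the result always has rank `rank(arg) + 1`. Note there is no bounds check on `axis` here; `retShape.begin() + axis` is only valid for `0 <= axis <= rank`. A small Python model of the shape rule (the helper name is made up for illustration):

def expand_dims_shape(shape, axis):
    # Mirrors retShape.insert(retShape.begin() + axis, 1) in the binding.
    assert 0 <= axis <= len(shape), "axis must lie in [0, rank]"
    ret = list(shape)
    ret.insert(axis, 1)
    return ret

assert expand_dims_shape([4, 8], 0) == [1, 4, 8]
assert expand_dims_shape([4, 8], 1) == [4, 1, 8]
assert expand_dims_shape([4, 8], 2) == [4, 8, 1]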
@@ -556,18 +556,22 @@ class tensor:
     def __getitem__(self, slices, _builder=None):
         if isinstance(slices, slice):
             slices = [slices]
-        src_shape = self.shape
-        dst_shape = []
-        curr = 0
-        for sl in slices:
+        ret = self
+        n_inserted = 0
+        for dim, sl in enumerate(slices):
             if isinstance(sl, constexpr) and sl.value is None:
-                dst_shape.append(1)
+                ret = semantic.expand_dims(ret, dim + n_inserted, _builder)
+                n_inserted += 1
             elif sl == slice(None, None, None):
-                dst_shape.append(src_shape[curr].value)
-                curr += 1
-        ret = semantic.view(self, dst_shape, _builder)
+                pass
+            else:
+                assert False, "unsupported"
         return ret
 
+    # x[:, None, :, None]
+    # x = expand_dims(x, axis=1)
+    # x = expand_dims(x, axis=2)
+
     @builtin
     def to(self, dtype, bitcast=False, _builder=None):
         if isinstance(bitcast, constexpr):
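The indexing rewrite replaces the old approach of accumulating a `dst_shape` and issuing a single `view` with one `expand_dims` per `None` in the index tuple, which makes each inserted dimension an explicit IR op. A NumPy model of the intended loop (plain `None` stands in for the frontend's `constexpr(None)`, and the model passes `dim` directly since `enumerate` already counts previously inserted positions):

import numpy as np

def getitem_model(x, slices):
    # Model of tensor.__getitem__ after this commit.
    ret = x
    for dim, sl in enumerate(slices):
        if sl is None:
            ret = np.expand_dims(ret, dim)   # insert a size-1 dim at `dim`
        elif sl == slice(None, None, None):
            pass                             # ':' keeps this dimension
        else:
            raise ValueError("unsupported")
    return ret

x = np.zeros((4, 8))
assert getitem_model(x, (slice(None), None, slice(None), None)).shape == (4, 1, 8, 1)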
@@ -463,6 +463,13 @@ def view(input: tl.tensor,
     return tl.tensor(builder.create_view(input.handle, dst_shape), ret_ty)
 
 
+def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
+    dst_shape = [s for s in input.type.shape]
+    dst_shape.insert(axis, 1)
+    ret_ty = tl.block_type(input.type.scalar, dst_shape)
+    return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty)
+
+
 def cat(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor:
     # TODO: check types
     return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), lhs.type)
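Taken together, the three hunks form a single lowering chain: `tensor.__getitem__` calls `semantic.expand_dims`, which computes the frontend block type in Python and delegates to the bound `create_expand_dims`, which recomputes the same shape on the C++ side for the MLIR result type. A hedged trace for `x[None, :]` on a 1-D block of 64 floats (the tensor types shown are an assumption about the printed IR, not taken from this commit):

# core.py      ret = semantic.expand_dims(ret, 0, _builder)
# semantic.py  dst_shape = [64]; dst_shape.insert(0, 1)          # -> [1, 64]
#              ret_ty = tl.block_type(fp32, [1, 64])
# bindings     retShape = {64}; retShape.insert(begin() + 0, 1)  # -> {1, 64}
#              ExpandDimsOp : tensor<64xf32> -> tensor<1x64xf32>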