[FRONTEND] Added ExpandDimsOp primitive (#36)

This commit is contained in:
Philippe Tillet
2022-08-04 18:41:06 -07:00
committed by GitHub
parent a7b49b3227
commit 78ebbe24c7
8 changed files with 98 additions and 41 deletions

View File

@@ -1471,7 +1471,6 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
self.create<mlir::triton::StoreOp>(loc, ptrs, val, mask);
})
// Block instruction
.def("create_view",
[](mlir::OpBuilder &self, mlir::Value &arg,
std::vector<int64_t> &shape) -> mlir::Value {
@@ -1482,6 +1481,18 @@ void init_triton_ir(py::module &&m) {
return self.create<mlir::triton::ViewOp>(
loc, mlir::RankedTensorType::get(shape, argType), arg);
})
.def(
"create_expand_dims",
[](mlir::OpBuilder &self, mlir::Value &arg, int axis) -> mlir::Value {
auto loc = self.getUnknownLoc();
auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>();
auto argEltType = argType.getElementType();
std::vector<int64_t> retShape = argType.getShape();
retShape.insert(retShape.begin() + axis, 1);
return self.create<mlir::triton::ExpandDimsOp>(
loc, mlir::RankedTensorType::get(retShape, argEltType), arg,
axis);
})
.def("create_cat",
[](mlir::OpBuilder &self, mlir::Value &lhs,
mlir::Value &rhs) -> mlir::Value {

View File

@@ -556,18 +556,22 @@ class tensor:
def __getitem__(self, slices, _builder=None):
if isinstance(slices, slice):
slices = [slices]
src_shape = self.shape
dst_shape = []
curr = 0
for sl in slices:
ret = self
n_inserted = 0
for dim, sl in enumerate(slices):
if isinstance(sl, constexpr) and sl.value is None:
dst_shape.append(1)
ret = semantic.expand_dims(ret, dim + n_inserted, _builder)
n_inserted += 1
elif sl == slice(None, None, None):
dst_shape.append(src_shape[curr].value)
curr += 1
ret = semantic.view(self, dst_shape, _builder)
pass
else:
assert False, "unsupported"
return ret
# x[:, None, :, None]
# x = expand_dims(x, axis=1)
# x = expand_dims(x, axis=2)
@builtin
def to(self, dtype, bitcast=False, _builder=None):
if isinstance(bitcast, constexpr):

View File

@@ -463,6 +463,13 @@ def view(input: tl.tensor,
return tl.tensor(builder.create_view(input.handle, dst_shape), ret_ty)
def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
dst_shape = [s for s in input.type.shape]
dst_shape.insert(axis, 1)
ret_ty = tl.block_type(input.type.scalar, dst_shape)
return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty)
def cat(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor:
# TODO: check types
return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), lhs.type)