[Triton-MLIR] Keren/code gen for extract slice and alloc tensor (#692)

Co-authored-by: gzhu <goostavz@outlook.com>
This commit is contained in:
Keren Zhou
2022-09-23 12:38:14 -07:00
committed by GitHub
parent c56f0198dd
commit ecd1bc33df
5 changed files with 134 additions and 58 deletions

View File

@@ -431,9 +431,10 @@ mlir::LogicalResult ExtractSliceOp::inferReturnTypes(
auto axis = attributes.get("axis").cast<IntegerAttr>().getInt();
if (axis < 0 || axis > srcShape.size())
return failure();
// Since we only extract a slice from a certain index on the axis,
// the dims before the axis can be dropped.
auto dstShape = srcShape.drop_front(axis + 1);
SmallVector<int64_t, 4> dstShape;
for (int i = 0; i < srcShape.size(); i++)
if (i != axis)
dstShape.push_back(srcShape[i]);
auto returnType =
RankedTensorType::get(dstShape, srcType.getElementType(), encoding);
inferredReturnTypes.assign({returnType});