[Triton-MLIR] Keren/code gen for extract slice and alloc tensor (#692)
Co-authored-by: gzhu <goostavz@outlook.com>
This commit is contained in:
@@ -431,9 +431,10 @@ mlir::LogicalResult ExtractSliceOp::inferReturnTypes(
|
||||
auto axis = attributes.get("axis").cast<IntegerAttr>().getInt();
|
||||
if (axis < 0 || axis > srcShape.size())
|
||||
return failure();
|
||||
// Since we only extract a slice from a certain index on the axis,
|
||||
// the dims before the axis can be dropped.
|
||||
auto dstShape = srcShape.drop_front(axis + 1);
|
||||
SmallVector<int64_t, 4> dstShape;
|
||||
for (int i = 0; i < srcShape.size(); i++)
|
||||
if (i != axis)
|
||||
dstShape.push_back(srcShape[i]);
|
||||
auto returnType =
|
||||
RankedTensorType::get(dstShape, srcType.getElementType(), encoding);
|
||||
inferredReturnTypes.assign({returnType});
|
||||
|
Reference in New Issue
Block a user