[Triton-IR] Added type inference and verifier for Triton-IR operations (#767)
This commit is contained in:
@@ -185,9 +185,16 @@ struct TritonExpandDimsPattern
|
||||
// return type
|
||||
RankedTensorType retType =
|
||||
RankedTensorType::get(retShape, argType.getElementType(), retEncoding);
|
||||
// convert operand to slice of return type
|
||||
Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get(
|
||||
getContext(), op.axis(), retEncoding);
|
||||
RankedTensorType newArgType = RankedTensorType::get(
|
||||
argType.getShape(), argType.getElementType(), newArgEncoding);
|
||||
// construct new op
|
||||
rewriter.replaceOpWithNewOp<triton::ExpandDimsOp>(
|
||||
op, retType, adaptor.src(), adaptor.axis());
|
||||
auto newSrc = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
op.getLoc(), newArgType, adaptor.src());
|
||||
rewriter.replaceOpWithNewOp<triton::ExpandDimsOp>(op, newSrc,
|
||||
adaptor.axis());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -310,9 +317,8 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
auto newOp = rewriter.replaceOpWithNewOp<triton::ReduceOp>(
|
||||
op, retType, adaptor.redOp(), adaptor.operand(), adaptor.axis());
|
||||
op, adaptor.redOp(), adaptor.operand(), adaptor.axis());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
Reference in New Issue
Block a user