[Triton-IR] Added type inference and verifier for Triton-IR operations (#767)

This commit is contained in:
Philippe Tillet
2022-10-11 18:16:41 -07:00
committed by GitHub
parent b6e5a231e5
commit 623c99609f
27 changed files with 494 additions and 348 deletions

View File

@@ -185,9 +185,16 @@ struct TritonExpandDimsPattern
// return type
RankedTensorType retType =
RankedTensorType::get(retShape, argType.getElementType(), retEncoding);
// convert operand to slice of return type
Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get(
getContext(), op.axis(), retEncoding);
RankedTensorType newArgType = RankedTensorType::get(
argType.getShape(), argType.getElementType(), newArgEncoding);
// construct new op
rewriter.replaceOpWithNewOp<triton::ExpandDimsOp>(
op, retType, adaptor.src(), adaptor.axis());
auto newSrc = rewriter.create<triton::gpu::ConvertLayoutOp>(
op.getLoc(), newArgType, adaptor.src());
rewriter.replaceOpWithNewOp<triton::ExpandDimsOp>(op, newSrc,
adaptor.axis());
return success();
}
};
@@ -310,9 +317,8 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
LogicalResult
matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
auto newOp = rewriter.replaceOpWithNewOp<triton::ReduceOp>(
op, retType, adaptor.redOp(), adaptor.operand(), adaptor.axis());
op, adaptor.redOp(), adaptor.operand(), adaptor.axis());
return success();
}
};