[TRITON-MLIR][FRONTEND]fix scf.if to run through layernorm tutorial (#938)

Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
donproc
2022-12-02 17:45:29 +08:00
committed by GitHub
parent c280ebda1b
commit 521ff9ad74
2 changed files with 56 additions and 3 deletions

View File

@@ -456,10 +456,55 @@ struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
}
};
// This is borrowed from ConvertFIfOpTypes in
// SCF/Transforms/StructuralTypeConversions.cpp
class SCFIfPattern : public OpConversionPattern<scf::IfOp> {
public:
using OpConversionPattern<scf::IfOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(scf::IfOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// TODO: Generalize this to any type conversion, not just 1:1.
//
// We need to implement something more sophisticated here that tracks which
// types convert to which other types and does the appropriate
// materialization logic.
// For example, it's possible that one result type converts to 0 types and
// another to 2 types, so newResultTypes would at least be the right size to
// not crash in the llvm::zip call below, but then we would set the the
// wrong type on the SSA values! These edge cases are also why we cannot
// safely use the TypeConverter::convertTypes helper here.
SmallVector<Type> newResultTypes;
for (auto type : op.getResultTypes()) {
Type newType = typeConverter->convertType(type);
if (!newType)
return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
newResultTypes.push_back(newType);
}
// See comments in the ForOp pattern for why we clone without regions and
// then inline.
scf::IfOp newOp =
cast<scf::IfOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
newOp.getThenRegion().end());
rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
newOp.getElseRegion().end());
// Update the operands and types.
newOp->setOperands(adaptor.getOperands());
for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
std::get<0>(t).setType(std::get<1>(t));
rewriter.replaceOp(op, newOp.getResults());
return success();
}
};
void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<SCFYieldPattern, SCFForPattern>(typeConverter, context);
patterns.add<SCFYieldPattern, SCFForPattern, SCFIfPattern>(typeConverter,
context);
}
class ConvertTritonToTritonGPU