[TRITON-MLIR][FRONTEND]fix scf.if to run through layernorm tutorial (#938)
Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
@@ -456,10 +456,55 @@ struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
|
||||
}
|
||||
};
|
||||
|
||||
// This is borrowed from ConvertFIfOpTypes in
|
||||
// SCF/Transforms/StructuralTypeConversions.cpp
|
||||
class SCFIfPattern : public OpConversionPattern<scf::IfOp> {
|
||||
public:
|
||||
using OpConversionPattern<scf::IfOp>::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(scf::IfOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// TODO: Generalize this to any type conversion, not just 1:1.
|
||||
//
|
||||
// We need to implement something more sophisticated here that tracks which
|
||||
// types convert to which other types and does the appropriate
|
||||
// materialization logic.
|
||||
// For example, it's possible that one result type converts to 0 types and
|
||||
// another to 2 types, so newResultTypes would at least be the right size to
|
||||
// not crash in the llvm::zip call below, but then we would set the the
|
||||
// wrong type on the SSA values! These edge cases are also why we cannot
|
||||
// safely use the TypeConverter::convertTypes helper here.
|
||||
SmallVector<Type> newResultTypes;
|
||||
for (auto type : op.getResultTypes()) {
|
||||
Type newType = typeConverter->convertType(type);
|
||||
if (!newType)
|
||||
return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
|
||||
newResultTypes.push_back(newType);
|
||||
}
|
||||
|
||||
// See comments in the ForOp pattern for why we clone without regions and
|
||||
// then inline.
|
||||
scf::IfOp newOp =
|
||||
cast<scf::IfOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
|
||||
rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
|
||||
newOp.getThenRegion().end());
|
||||
rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
|
||||
newOp.getElseRegion().end());
|
||||
|
||||
// Update the operands and types.
|
||||
newOp->setOperands(adaptor.getOperands());
|
||||
for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
|
||||
std::get<0>(t).setType(std::get<1>(t));
|
||||
rewriter.replaceOp(op, newOp.getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
patterns.add<SCFYieldPattern, SCFForPattern>(typeConverter, context);
|
||||
patterns.add<SCFYieldPattern, SCFForPattern, SCFIfPattern>(typeConverter,
|
||||
context);
|
||||
}
|
||||
|
||||
class ConvertTritonToTritonGPU
|
||||
|
Reference in New Issue
Block a user