[TRITON-MLIR][FRONTEND]fix scf.if to run through layernorm tutorial (#938)
Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
@@ -456,10 +456,55 @@ struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
|
||||
}
|
||||
};
|
||||
|
||||
// This is borrowed from ConvertFIfOpTypes in
|
||||
// SCF/Transforms/StructuralTypeConversions.cpp
|
||||
class SCFIfPattern : public OpConversionPattern<scf::IfOp> {
|
||||
public:
|
||||
using OpConversionPattern<scf::IfOp>::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(scf::IfOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// TODO: Generalize this to any type conversion, not just 1:1.
|
||||
//
|
||||
// We need to implement something more sophisticated here that tracks which
|
||||
// types convert to which other types and does the appropriate
|
||||
// materialization logic.
|
||||
// For example, it's possible that one result type converts to 0 types and
|
||||
// another to 2 types, so newResultTypes would at least be the right size to
|
||||
// not crash in the llvm::zip call below, but then we would set the the
|
||||
// wrong type on the SSA values! These edge cases are also why we cannot
|
||||
// safely use the TypeConverter::convertTypes helper here.
|
||||
SmallVector<Type> newResultTypes;
|
||||
for (auto type : op.getResultTypes()) {
|
||||
Type newType = typeConverter->convertType(type);
|
||||
if (!newType)
|
||||
return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
|
||||
newResultTypes.push_back(newType);
|
||||
}
|
||||
|
||||
// See comments in the ForOp pattern for why we clone without regions and
|
||||
// then inline.
|
||||
scf::IfOp newOp =
|
||||
cast<scf::IfOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
|
||||
rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
|
||||
newOp.getThenRegion().end());
|
||||
rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
|
||||
newOp.getElseRegion().end());
|
||||
|
||||
// Update the operands and types.
|
||||
newOp->setOperands(adaptor.getOperands());
|
||||
for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
|
||||
std::get<0>(t).setType(std::get<1>(t));
|
||||
rewriter.replaceOp(op, newOp.getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
patterns.add<SCFYieldPattern, SCFForPattern>(typeConverter, context);
|
||||
patterns.add<SCFYieldPattern, SCFForPattern, SCFIfPattern>(typeConverter,
|
||||
context);
|
||||
}
|
||||
|
||||
class ConvertTritonToTritonGPU
|
||||
|
@@ -359,7 +359,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
cond = cond.to(triton.language.int1, _builder=self.builder)
|
||||
with enter_sub_region(self) as sr:
|
||||
liveins, ip_block = sr
|
||||
|
||||
liveins_copy = liveins.copy()
|
||||
then_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(then_block)
|
||||
self.visit_compound_statement(node.body)
|
||||
@@ -394,7 +394,15 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if then_defs[then_name].type == else_defs[else_name].type:
|
||||
names.append(then_name)
|
||||
ret_types.append(then_defs[then_name].type)
|
||||
|
||||
|
||||
# defined in else block but not in then block
|
||||
# to find in parent scope and yield them
|
||||
for else_name in else_defs:
|
||||
if else_name in liveins and else_name not in then_defs:
|
||||
if else_defs[else_name].type == liveins[else_name].type:
|
||||
names.append(else_name)
|
||||
ret_types.append(else_defs[else_name].type)
|
||||
then_defs[else_name] = liveins_copy[else_name]
|
||||
self.builder.set_insertion_point_to_end(ip_block)
|
||||
|
||||
if then_defs or node.orelse: # with else block
|
||||
|
Reference in New Issue
Block a user