[TRITON-MLIR][FRONTEND]fix scf.if to run through layernorm tutorial (#938)
Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
@@ -456,10 +456,55 @@ struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// This is borrowed from ConvertFIfOpTypes in
|
||||||
|
// SCF/Transforms/StructuralTypeConversions.cpp
|
||||||
|
class SCFIfPattern : public OpConversionPattern<scf::IfOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<scf::IfOp>::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(scf::IfOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
// TODO: Generalize this to any type conversion, not just 1:1.
|
||||||
|
//
|
||||||
|
// We need to implement something more sophisticated here that tracks which
|
||||||
|
// types convert to which other types and does the appropriate
|
||||||
|
// materialization logic.
|
||||||
|
// For example, it's possible that one result type converts to 0 types and
|
||||||
|
// another to 2 types, so newResultTypes would at least be the right size to
|
||||||
|
// not crash in the llvm::zip call below, but then we would set the the
|
||||||
|
// wrong type on the SSA values! These edge cases are also why we cannot
|
||||||
|
// safely use the TypeConverter::convertTypes helper here.
|
||||||
|
SmallVector<Type> newResultTypes;
|
||||||
|
for (auto type : op.getResultTypes()) {
|
||||||
|
Type newType = typeConverter->convertType(type);
|
||||||
|
if (!newType)
|
||||||
|
return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
|
||||||
|
newResultTypes.push_back(newType);
|
||||||
|
}
|
||||||
|
|
||||||
|
// See comments in the ForOp pattern for why we clone without regions and
|
||||||
|
// then inline.
|
||||||
|
scf::IfOp newOp =
|
||||||
|
cast<scf::IfOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
|
||||||
|
rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
|
||||||
|
newOp.getThenRegion().end());
|
||||||
|
rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
|
||||||
|
newOp.getElseRegion().end());
|
||||||
|
|
||||||
|
// Update the operands and types.
|
||||||
|
newOp->setOperands(adaptor.getOperands());
|
||||||
|
for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
|
||||||
|
std::get<0>(t).setType(std::get<1>(t));
|
||||||
|
rewriter.replaceOp(op, newOp.getResults());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
|
void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
|
||||||
RewritePatternSet &patterns) {
|
RewritePatternSet &patterns) {
|
||||||
MLIRContext *context = patterns.getContext();
|
MLIRContext *context = patterns.getContext();
|
||||||
patterns.add<SCFYieldPattern, SCFForPattern>(typeConverter, context);
|
patterns.add<SCFYieldPattern, SCFForPattern, SCFIfPattern>(typeConverter,
|
||||||
|
context);
|
||||||
}
|
}
|
||||||
|
|
||||||
class ConvertTritonToTritonGPU
|
class ConvertTritonToTritonGPU
|
||||||
|
@@ -359,7 +359,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
cond = cond.to(triton.language.int1, _builder=self.builder)
|
cond = cond.to(triton.language.int1, _builder=self.builder)
|
||||||
with enter_sub_region(self) as sr:
|
with enter_sub_region(self) as sr:
|
||||||
liveins, ip_block = sr
|
liveins, ip_block = sr
|
||||||
|
liveins_copy = liveins.copy()
|
||||||
then_block = self.builder.create_block()
|
then_block = self.builder.create_block()
|
||||||
self.builder.set_insertion_point_to_start(then_block)
|
self.builder.set_insertion_point_to_start(then_block)
|
||||||
self.visit_compound_statement(node.body)
|
self.visit_compound_statement(node.body)
|
||||||
@@ -395,6 +395,14 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
names.append(then_name)
|
names.append(then_name)
|
||||||
ret_types.append(then_defs[then_name].type)
|
ret_types.append(then_defs[then_name].type)
|
||||||
|
|
||||||
|
# defined in else block but not in then block
|
||||||
|
# to find in parent scope and yield them
|
||||||
|
for else_name in else_defs:
|
||||||
|
if else_name in liveins and else_name not in then_defs:
|
||||||
|
if else_defs[else_name].type == liveins[else_name].type:
|
||||||
|
names.append(else_name)
|
||||||
|
ret_types.append(else_defs[else_name].type)
|
||||||
|
then_defs[else_name] = liveins_copy[else_name]
|
||||||
self.builder.set_insertion_point_to_end(ip_block)
|
self.builder.set_insertion_point_to_end(ip_block)
|
||||||
|
|
||||||
if then_defs or node.orelse: # with else block
|
if then_defs or node.orelse: # with else block
|
||||||
|
Reference in New Issue
Block a user