[TRITON-MLIR][FRONTEND]fix scf.if to run through layernorm tutorial (#938)

Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
donproc
2022-12-02 17:45:29 +08:00
committed by GitHub
parent c280ebda1b
commit 521ff9ad74
2 changed files with 56 additions and 3 deletions

View File

@@ -456,10 +456,55 @@ struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
}
};
// This is borrowed from ConvertFIfOpTypes in
// SCF/Transforms/StructuralTypeConversions.cpp
class SCFIfPattern : public OpConversionPattern<scf::IfOp> {
public:
using OpConversionPattern<scf::IfOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(scf::IfOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// TODO: Generalize this to any type conversion, not just 1:1.
//
// We need to implement something more sophisticated here that tracks which
// types convert to which other types and does the appropriate
// materialization logic.
// For example, it's possible that one result type converts to 0 types and
// another to 2 types, so newResultTypes would at least be the right size to
// not crash in the llvm::zip call below, but then we would set the the
// wrong type on the SSA values! These edge cases are also why we cannot
// safely use the TypeConverter::convertTypes helper here.
SmallVector<Type> newResultTypes;
for (auto type : op.getResultTypes()) {
Type newType = typeConverter->convertType(type);
if (!newType)
return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
newResultTypes.push_back(newType);
}
// See comments in the ForOp pattern for why we clone without regions and
// then inline.
scf::IfOp newOp =
cast<scf::IfOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
newOp.getThenRegion().end());
rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
newOp.getElseRegion().end());
// Update the operands and types.
newOp->setOperands(adaptor.getOperands());
for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
std::get<0>(t).setType(std::get<1>(t));
rewriter.replaceOp(op, newOp.getResults());
return success();
}
};
void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<SCFYieldPattern, SCFForPattern>(typeConverter, context);
patterns.add<SCFYieldPattern, SCFForPattern, SCFIfPattern>(typeConverter,
context);
}
class ConvertTritonToTritonGPU

View File

@@ -359,7 +359,7 @@ class CodeGenerator(ast.NodeVisitor):
cond = cond.to(triton.language.int1, _builder=self.builder)
with enter_sub_region(self) as sr:
liveins, ip_block = sr
liveins_copy = liveins.copy()
then_block = self.builder.create_block()
self.builder.set_insertion_point_to_start(then_block)
self.visit_compound_statement(node.body)
@@ -394,7 +394,15 @@ class CodeGenerator(ast.NodeVisitor):
if then_defs[then_name].type == else_defs[else_name].type:
names.append(then_name)
ret_types.append(then_defs[then_name].type)
# defined in else block but not in then block
# to find in parent scope and yield them
for else_name in else_defs:
if else_name in liveins and else_name not in then_defs:
if else_defs[else_name].type == liveins[else_name].type:
names.append(else_name)
ret_types.append(else_defs[else_name].type)
then_defs[else_name] = liveins_copy[else_name]
self.builder.set_insertion_point_to_end(ip_block)
if then_defs or node.orelse: # with else block