From 521ff9ad74175b09fc09de7ed4cc0571457f57a8 Mon Sep 17 00:00:00 2001
From: donproc
Date: Fri, 2 Dec 2022 17:45:29 +0800
Subject: [PATCH] [TRITON-MLIR][FRONTEND] fix scf.if to run through layernorm
 tutorial (#938)

Co-authored-by: dongdongl
---
 .../TritonToTritonGPU/TritonToTritonGPU.cpp   | 47 ++++++++++++++++++-
 python/triton/compiler.py                     | 12 ++++-
 2 files changed, 56 insertions(+), 3 deletions(-)

diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
index b9fd481a6..563858362 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
@@ -456,10 +456,55 @@ struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
   }
 };
 
+// This is borrowed from ConvertIfOpTypes in
+// SCF/Transforms/StructuralTypeConversions.cpp
+class SCFIfPattern : public OpConversionPattern<scf::IfOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(scf::IfOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // TODO: Generalize this to any type conversion, not just 1:1.
+    //
+    // We need to implement something more sophisticated here that tracks
+    // which types convert to which other types and does the appropriate
+    // materialization logic.
+    // For example, it's possible that one result type converts to 0 types
+    // and another to 2 types, so newResultTypes would at least be the right
+    // size to not crash in the llvm::zip call below, but then we would set
+    // the wrong type on the SSA values! These edge cases are also why we
+    // cannot safely use the TypeConverter::convertTypes helper here.
+    SmallVector<Type> newResultTypes;
+    for (auto type : op.getResultTypes()) {
+      Type newType = typeConverter->convertType(type);
+      if (!newType)
+        return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
+      newResultTypes.push_back(newType);
+    }
+
+    // See comments in the ForOp pattern for why we clone without regions
+    // and then inline.
+    scf::IfOp newOp =
+        cast<scf::IfOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
+    rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
+                                newOp.getThenRegion().end());
+    rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
+                                newOp.getElseRegion().end());
+
+    // Update the operands and types.
+    newOp->setOperands(adaptor.getOperands());
+    for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
+      std::get<0>(t).setType(std::get<1>(t));
+    rewriter.replaceOp(op, newOp.getResults());
+    return success();
+  }
+};
+
 void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
                          RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
-  patterns.add<SCFYieldPattern, SCFForPattern>(typeConverter, context);
+  patterns.add<SCFYieldPattern, SCFForPattern, SCFIfPattern>(typeConverter,
+                                                             context);
 }
 
 class ConvertTritonToTritonGPU
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index 1df83d4fd..54687170e 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -359,7 +359,7 @@ class CodeGenerator(ast.NodeVisitor):
             cond = cond.to(triton.language.int1, _builder=self.builder)
             with enter_sub_region(self) as sr:
                 liveins, ip_block = sr
-
+                liveins_copy = liveins.copy()
                 then_block = self.builder.create_block()
                 self.builder.set_insertion_point_to_start(then_block)
                 self.visit_compound_statement(node.body)
@@ -394,7 +394,15 @@ class CodeGenerator(ast.NodeVisitor):
                             if then_defs[then_name].type == else_defs[else_name].type:
                                 names.append(then_name)
                                 ret_types.append(then_defs[then_name].type)
-
+
+                # variables defined in the else block but not in the then block:
+                # find them in the parent scope and yield them as well
+                for else_name in else_defs:
+                    if else_name in liveins and else_name not in then_defs:
+                        if else_defs[else_name].type == liveins[else_name].type:
+                            names.append(else_name)
+                            ret_types.append(else_defs[else_name].type)
+                            then_defs[else_name] = liveins_copy[else_name]
                 self.builder.set_insertion_point_to_end(ip_block)
                 if then_defs or node.orelse:  # with else block
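
For reference, a minimal sketch of the kernel shape this compiler.py change targets: a value (`acc` below) is live before a data-dependent `if` and reassigned only in the `else` branch, so the generated scf.if has to yield the parent-scope value on the then path. The kernel and its argument names are illustrative assumptions, not taken from the patch or the layernorm tutorial.

# Illustrative sketch only; kernel and argument names are hypothetical.
import triton
import triton.language as tl


@triton.jit
def _branch_kernel(x_ptr, out_ptr, n_elements, flag, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    acc = tl.zeros([BLOCK], dtype=tl.float32)  # live-in before the branch
    if flag > 0:
        x = x * 2.0      # `x` is redefined in both branches (already handled)
    else:
        x = x * 0.5
        acc = acc + 1.0  # `acc` is redefined only in the else block (new path)
    tl.store(out_ptr + offs, x + acc, mask=mask)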