More progress on SCF type conversion

This commit is contained in:
Yan Da
2022-05-05 20:56:55 +08:00
parent 26c59e4718
commit 0c5319eed9
2 changed files with 40 additions and 11 deletions

View File

@@ -4,8 +4,8 @@
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
// #include "mlir/IR/BlockAndValueMapping.h"
#include "../PassDetail.h"
#include <llvm-6.0/llvm/Support/ErrorHandling.h>
using namespace mlir;
using namespace mlir::triton;
@@ -222,16 +222,38 @@ void populateTritonPatterns(
//
/// Rewrites an `scf::ForOp` so that its result types, the types inside its
/// loop body region, and its operands all use the types produced by this
/// pattern's type converter.
///
/// The original loop is cloned without its regions, the original body is
/// moved into the clone, and the clone is then retyped in place before it
/// replaces the original op.
struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
  using OpConversionPattern<scf::ForOp>::OpConversionPattern;

  // Ref: ConvertForOpTypes
  LogicalResult
  matchAndRewrite(scf::ForOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Convert every result type first; this loop only queries the type
    // converter, so we can bail out before creating any new operation.
    SmallVector<Type> newResultTypes;
    for (Type type : op.getResultTypes()) {
      Type newType = getTypeConverter()->convertType(type);
      if (!newType)
        return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
      newResultTypes.push_back(newType);
    }

    // Clone the loop without its regions, then move the original body into
    // the clone so the body is reused rather than rebuilt.
    auto newOp =
        cast<scf::ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
    rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
                                newOp.getLoopBody().end());

    // Now, update all the types.
    // Convert the type of the entry block of the ForOp's body.
    if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
                                           *getTypeConverter())))
      return rewriter.notifyMatchFailure(op, "could not convert body types");

    // Change the clone to use the updated operands. We could have cloned with
    // a BlockAndValueMapping, but this seems a bit more direct.
    newOp->setOperands(adaptor.getOperands());

    // Update the result types to the new converted types.
    for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
      std::get<0>(t).setType(std::get<1>(t));

    rewriter.replaceOp(op, newOp.getResults());
    return success();
  }
};
@@ -241,8 +263,11 @@ struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
/// Replaces an `scf::YieldOp` with a new `scf::YieldOp` built from the
/// adaptor's operands, and reports success.
LogicalResult matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const override {
  // Alternative that was tried and left disabled:
  // rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
  // rewriter.create<scf::YieldOp>(op.getLoc(), adaptor.getOperands());
  // op.erase();
  rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getOperands());
  return success();
}
@@ -274,6 +299,8 @@ public:
// add rules
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
populateTritonPatterns(typeConverter, patterns);
// TODO: can we use
// mlir::scf::populateSCFStructuralTypeConversionsAndLegality(...) here?
populateSCFPatterns(typeConverter, patterns);
if(failed(applyPartialConversion(mod, target,

View File

@@ -50,7 +50,9 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// materializations
// Argument materialization is not implemented yet: log the value that
// triggered the request (plus some surrounding IR for context) and abort.
addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
                               ValueRange inputs, Location loc) {
  llvm::errs() << "Trying to materialize target... " << inputs[0] << "\n"
               << "in: \n";
  // The context dump is only possible when the value is a block argument and
  // the owning block has an op two levels up; guard each step so that the
  // debug output cannot crash before the diagnostic below is reached.
  if (auto blockArg = inputs[0].dyn_cast<BlockArgument>()) {
    if (auto *parentOp = blockArg.getOwner()->getParentOp()) {
      if (auto *enclosingOp = parentOp->getParentOp())
        enclosingOp->print(llvm::errs());
    }
  }
  llvm_unreachable("Not implemented");
  return llvm::None;
});