More progress on SCF type conversion
This commit is contained in:
@@ -4,8 +4,8 @@
|
|||||||
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||||
|
// #include "mlir/IR/BlockAndValueMapping.h"
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
#include <llvm-6.0/llvm/Support/ErrorHandling.h>
|
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::triton;
|
using namespace mlir::triton;
|
||||||
@@ -222,16 +222,38 @@ void populateTritonPatterns(
|
|||||||
//
|
//
|
||||||
struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
|
struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
|
||||||
using OpConversionPattern<scf::ForOp>::OpConversionPattern;
|
using OpConversionPattern<scf::ForOp>::OpConversionPattern;
|
||||||
|
// Ref: ConvertForOpTypes
|
||||||
LogicalResult matchAndRewrite(scf::ForOp op, OpAdaptor adaptor,
|
LogicalResult matchAndRewrite(scf::ForOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
auto newFor = rewriter.replaceOpWithNewOp<scf::ForOp>(
|
SmallVector<Type> newResultTypes;
|
||||||
op, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep(),
|
for (Type type : op.getResultTypes()) {
|
||||||
adaptor.getInitArgs()
|
Type newType = typeConverter->convertType(type);
|
||||||
);
|
if (!newType)
|
||||||
// TODO: we need to copy (?) the body of ForOp
|
return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
|
||||||
llvm_unreachable("Not implemented");
|
newResultTypes.push_back(newType);
|
||||||
// newFor.getRegion().takeBody(adaptor.getRegion());
|
}
|
||||||
|
|
||||||
|
auto newOp = cast<scf::ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
|
||||||
|
rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
|
||||||
|
newOp.getLoopBody().end());
|
||||||
|
|
||||||
|
// Now, update all the types.
|
||||||
|
|
||||||
|
// Convert the type of the entry block of the ForOp's body.
|
||||||
|
if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
|
||||||
|
*getTypeConverter()))) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "could not convert body types");
|
||||||
|
}
|
||||||
|
// Change the clone to use the updated operands. We could have cloned with
|
||||||
|
// a BlockAndValueMapping, but this seems a bit more direct.
|
||||||
|
newOp->setOperands(adaptor.getOperands());
|
||||||
|
// Update the result types to the new converted types.
|
||||||
|
for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
|
||||||
|
std::get<0>(t).setType(std::get<1>(t));
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, newOp.getResults());
|
||||||
|
return success();
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -241,8 +263,11 @@ struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
|
|||||||
|
|
||||||
LogicalResult matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
|
LogicalResult matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
// rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
|
||||||
|
// rewriter.create<scf::YieldOp>(op.getLoc(), adaptor.getOperands());
|
||||||
|
// op.erase();
|
||||||
rewriter.replaceOpWithNewOp<scf::YieldOp>(
|
rewriter.replaceOpWithNewOp<scf::YieldOp>(
|
||||||
op, adaptor.getResults()
|
op, adaptor.getOperands()
|
||||||
);
|
);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -274,6 +299,8 @@ public:
|
|||||||
// add rules
|
// add rules
|
||||||
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
|
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
|
||||||
populateTritonPatterns(typeConverter, patterns);
|
populateTritonPatterns(typeConverter, patterns);
|
||||||
|
// TODO: can we use
|
||||||
|
// mlir::scf::populateSCFStructuralTypeConversionsAndLegality(...) here?
|
||||||
populateSCFPatterns(typeConverter, patterns);
|
populateSCFPatterns(typeConverter, patterns);
|
||||||
|
|
||||||
if(failed(applyPartialConversion(mod, target,
|
if(failed(applyPartialConversion(mod, target,
|
||||||
|
@@ -50,7 +50,9 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
|||||||
// materializations
|
// materializations
|
||||||
addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
|
addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
|
||||||
ValueRange inputs, Location loc) {
|
ValueRange inputs, Location loc) {
|
||||||
// llvm::errs() << "Trying to materialize target... " << inputs[0] << "\n";
|
llvm::errs() << "Trying to materialize target... " << inputs[0] << "\n"
|
||||||
|
<< "in: \n";
|
||||||
|
inputs[0].dyn_cast<BlockArgument>().getOwner()->getParentOp()->getParentOp()->print(llvm::errs());
|
||||||
llvm_unreachable("Not implemented");
|
llvm_unreachable("Not implemented");
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
});
|
});
|
||||||
|
Reference in New Issue
Block a user