More progress on SCF type conversion
This commit is contained in:
@@ -4,8 +4,8 @@
|
||||
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
// #include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "../PassDetail.h"
|
||||
#include <llvm-6.0/llvm/Support/ErrorHandling.h>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
@@ -222,16 +222,38 @@ void populateTritonPatterns(
|
||||
//
|
||||
struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
|
||||
using OpConversionPattern<scf::ForOp>::OpConversionPattern;
|
||||
|
||||
// Ref: ConvertForOpTypes
|
||||
LogicalResult matchAndRewrite(scf::ForOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto newFor = rewriter.replaceOpWithNewOp<scf::ForOp>(
|
||||
op, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep(),
|
||||
adaptor.getInitArgs()
|
||||
);
|
||||
// TODO: we need to copy (?) the body of ForOp
|
||||
llvm_unreachable("Not implemented");
|
||||
// newFor.getRegion().takeBody(adaptor.getRegion());
|
||||
SmallVector<Type> newResultTypes;
|
||||
for (Type type : op.getResultTypes()) {
|
||||
Type newType = typeConverter->convertType(type);
|
||||
if (!newType)
|
||||
return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
|
||||
newResultTypes.push_back(newType);
|
||||
}
|
||||
|
||||
auto newOp = cast<scf::ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
|
||||
rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
|
||||
newOp.getLoopBody().end());
|
||||
|
||||
// Now, update all the types.
|
||||
|
||||
// Convert the type of the entry block of the ForOp's body.
|
||||
if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
|
||||
*getTypeConverter()))) {
|
||||
return rewriter.notifyMatchFailure(op, "could not convert body types");
|
||||
}
|
||||
// Change the clone to use the updated operands. We could have cloned with
|
||||
// a BlockAndValueMapping, but this seems a bit more direct.
|
||||
newOp->setOperands(adaptor.getOperands());
|
||||
// Update the result types to the new converted types.
|
||||
for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
|
||||
std::get<0>(t).setType(std::get<1>(t));
|
||||
|
||||
rewriter.replaceOp(op, newOp.getResults());
|
||||
return success();
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -241,8 +263,11 @@ struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
|
||||
|
||||
LogicalResult matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
|
||||
// rewriter.create<scf::YieldOp>(op.getLoc(), adaptor.getOperands());
|
||||
// op.erase();
|
||||
rewriter.replaceOpWithNewOp<scf::YieldOp>(
|
||||
op, adaptor.getResults()
|
||||
op, adaptor.getOperands()
|
||||
);
|
||||
return success();
|
||||
}
|
||||
@@ -274,6 +299,8 @@ public:
|
||||
// add rules
|
||||
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
|
||||
populateTritonPatterns(typeConverter, patterns);
|
||||
// TODO: can we use
|
||||
// mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here?
|
||||
populateSCFPatterns(typeConverter, patterns);
|
||||
|
||||
if(failed(applyPartialConversion(mod, target,
|
||||
|
@@ -50,7 +50,9 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
// materailizations
|
||||
addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
|
||||
ValueRange inputs, Location loc) {
|
||||
// llvm::errs() << "Trying to materialize target... " << inputs[0] << "\n";
|
||||
llvm::errs() << "Trying to materialize target... " << inputs[0] << "\n"
|
||||
<< "in: \n";
|
||||
inputs[0].dyn_cast<BlockArgument>().getOwner()->getParentOp()->getParentOp()->print(llvm::errs());
|
||||
llvm_unreachable("Not implemented");
|
||||
return llvm::None;
|
||||
});
|
||||
|
Reference in New Issue
Block a user