diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index 1a6512648..46ed700ba 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -5,6 +5,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "../PassDetail.h" +#include using namespace mlir; using namespace mlir::triton; @@ -12,7 +13,7 @@ using namespace mlir::triton; namespace { template -class ArithBinaryPattern : public OpConversionPattern { +class ArithGenericPattern : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -48,8 +49,10 @@ public: LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type retType = getTypeConverter()->convertType(op.getType()); + auto value = adaptor.getValue().dyn_cast(); + assert(value); rewriter.replaceOpWithNewOp( - op, retType, adaptor.getValue() + op, retType, value.reshape(retType) // This is a hack. 
We just want to add encoding ); return success(); } @@ -90,40 +93,42 @@ void populateArithmeticPatternsAndLegality( // Rewrite rule // patterns.add(typeConverter, context); patterns.add, - ArithBinaryPattern, - ArithBinaryPattern, - ArithBinaryPattern, - ArithBinaryPattern, - ArithBinaryPattern, - ArithBinaryPattern, - ArithBinaryPattern, - ArithBinaryPattern, - ArithBinaryPattern, - ArithBinaryPattern, - ArithBinaryPattern, - ArithBinaryPattern, - ArithBinaryPattern, - ArithBinaryPattern, - ArithBinaryPattern, // NegFOp + ArithGenericPattern, + ArithGenericPattern, + ArithGenericPattern, + ArithGenericPattern, + ArithGenericPattern, + ArithGenericPattern, + ArithGenericPattern, + ArithGenericPattern, + ArithGenericPattern, + ArithGenericPattern, + ArithGenericPattern, + ArithGenericPattern, + ArithGenericPattern, + ArithGenericPattern, + ArithGenericPattern, + ArithGenericPattern, // NegFOp // Floating point - ArithBinaryPattern, - ArithBinaryPattern, + ArithGenericPattern, + ArithGenericPattern, // MaxMin - ArithBinaryPattern, - ArithBinaryPattern, - ArithBinaryPattern, - ArithBinaryPattern, - ArithBinaryPattern, - ArithBinaryPattern, + ArithGenericPattern, + ArithGenericPattern, + ArithGenericPattern, + ArithGenericPattern, + ArithGenericPattern, + ArithGenericPattern, // Floating point - ArithBinaryPattern, - ArithBinaryPattern, - ArithBinaryPattern, + ArithGenericPattern, + ArithGenericPattern, + ArithGenericPattern, // Cmp ArithCmpPattern, - ArithCmpPattern + ArithCmpPattern, // Cast Ops + ArithGenericPattern, + ArithGenericPattern >(typeConverter, context); } @@ -212,6 +217,46 @@ void populateTritonPatterns( >(typeConverter, context); } +// +// SCF patterns +// +struct SCFForPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newFor = rewriter.replaceOpWithNewOp( + op, adaptor.getLowerBound(), 
adaptor.getUpperBound(), adaptor.getStep(), + adaptor.getInitArgs() + ); + // TODO: we need to copy (?) the body of ForOp + llvm_unreachable("Not implemented"); + // newFor.getRegion().takeBody(adaptor.getRegion()); + return success(); + } +}; + +struct SCFYieldPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, adaptor.getResults() + ); + return success(); + } +}; + +void populateSCFPatterns( + TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns +) { + MLIRContext *context = patterns.getContext(); + patterns.add(typeConverter, context); +} + class ConvertTritonToTritonGPU : public ConvertTritonToTritonGPUBase { @@ -229,7 +274,7 @@ public: // add rules populateArithmeticPatternsAndLegality(typeConverter, patterns, target); populateTritonPatterns(typeConverter, patterns); - + populateSCFPatterns(typeConverter, patterns); if(failed(applyPartialConversion(mod, target, std::move(patterns)))) diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp index c4fc36861..4067156dc 100644 --- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp +++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp @@ -2,6 +2,7 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include +#include using namespace mlir; @@ -45,6 +46,25 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, context, threadTileSize, blockTileSize, order); return RankedTensorType::get(shape, elementType, encoding); }); + + // materializations + addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, + ValueRange inputs, Location loc) { + // llvm::errs() << "Trying to materialize target... 
" << inputs[0] << "\n"; + llvm_unreachable("Not implemented"); + return llvm::None; + }); + addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, + ValueRange inputs, Location loc) { + llvm_unreachable("Not implemented"); + return llvm::None; + }); + addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, + ValueRange inputs, Location loc) { + llvm_unreachable("Not implemented"); + // llvm::errs() << "Trying to materialize target... " << inputs[0] << "\n"; + return llvm::None; + }); } // @@ -53,25 +73,15 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, TritonGPUConversionTarget::TritonGPUConversionTarget( MLIRContext &context, TritonGPUTypeConverter &typeConverter) : ConversionTarget(context), typeConverter(typeConverter) { - addLegalDialect(); - // Some ops from SCF are illegal addIllegalOp(); - addDynamicallyLegalDialect([&](Operation *op) { - if (typeConverter.isLegal(op)) - return true; - return false; - }); - - addDynamicallyLegalDialect([&](Operation *op) { - if (typeConverter.isLegal(op)) - return true; - return false; - }); - - addDynamicallyLegalDialect([&](Operation *op) { + addDynamicallyLegalDialect([&](Operation *op) { if (typeConverter.isLegal(op)) return true; return false;