more pass template

This commit is contained in:
Philippe Tillet
2023-01-06 14:26:06 -08:00
parent b16aeb6541
commit 18c7a72973
9 changed files with 92 additions and 34 deletions

View File

@@ -10,6 +10,7 @@ add_mlir_dialect_library(TritonGPUTransforms
Prefetch.cpp
OptimizeLoadConvert.cpp
SinkConversionsFromShared.cpp
DecomposeConversionsToDotOperand.cpp
TritonGPUConversion.cpp
DEPENDS

View File

@@ -19,7 +19,6 @@
#include <memory>
using namespace mlir;
namespace {
#include "TritonGPUCombine.inc"
@@ -483,8 +482,7 @@ public:
return op->getBlock() == cvt->getBlock() &&
!(isa<triton::ReduceOp>(op) &&
!op->getResult(0).getType().isa<RankedTensorType>()) &&
!isa<triton::gpu::ConvertLayoutOp>(op) &&
!isa<scf::YieldOp>(op);
!isa<triton::gpu::ConvertLayoutOp>(op) && !isa<scf::YieldOp>(op);
};
mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter);
if (cvtSlices.empty())

View File

@@ -0,0 +1,36 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
using namespace mlir;
class TritonGPUDecomposeConversionsToDotOperandPass
: public TritonGPUDecomposeConversionsToDotOperandBase<
TritonGPUDecomposeConversionsToDotOperandPass> {
public:
TritonGPUDecomposeConversionsToDotOperandPass() = default;
void runOnOperation() override { return; }
};
std::unique_ptr<Pass>
mlir::createTritonGPUDecomposeConversionsToDotOperandPass() {
return std::make_unique<TritonGPUDecomposeConversionsToDotOperandPass>();
}