More on the pipeline pass
This commit is contained in:
@@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
include "mlir/Pass/PassBase.td"
|
include "mlir/Pass/PassBase.td"
|
||||||
|
|
||||||
def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::FuncOp"> {
|
def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
|
||||||
let summary = "pipeline";
|
let summary = "pipeline";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@@ -7,10 +7,102 @@ using namespace mlir;
|
|||||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
class LoopPipeliner {
|
||||||
|
struct PipelineInfo {
|
||||||
|
triton::DotOp dotOp;
|
||||||
|
triton::LoadOp aLoadOp;
|
||||||
|
triton::LoadOp bLoadOp;
|
||||||
|
};
|
||||||
|
|
||||||
|
int numStages;
|
||||||
|
/// cache forOp we are working on
|
||||||
|
scf::ForOp forOp;
|
||||||
|
/// dot & loads
|
||||||
|
PipelineInfo info;
|
||||||
|
/// value (in loop) => value at stage N
|
||||||
|
DenseMap<Value, SmallVector<Value>> valueMapping;
|
||||||
|
|
||||||
|
void setStageValueMapping(Value origin, Value prefetched, int idx);
|
||||||
|
public:
|
||||||
|
LoopPipeliner(scf::ForOp forOp, int numStages)
|
||||||
|
: forOp(forOp), numStages(numStages) {}
|
||||||
|
|
||||||
|
/// Collect loop info. Return success if we can pipeline this loop
|
||||||
|
LogicalResult initialize();
|
||||||
|
|
||||||
|
///
|
||||||
|
void emitPrologue();
|
||||||
|
|
||||||
|
friend class PipelinePass;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// A load instruction can be pipelined if:
|
||||||
|
/// - the pointer is a block argument (redefined inside the loop)
|
||||||
|
/// - the load has only a single use in a dot instruction
|
||||||
|
LogicalResult LoopPipeliner::initialize() {
|
||||||
|
Region &bodyRegion = forOp.getLoopBody();
|
||||||
|
assert(bodyRegion.hasOneBlock());
|
||||||
|
Block &loop = bodyRegion.front();
|
||||||
|
|
||||||
|
// TODO: can we use forOp.walk(...) here?
|
||||||
|
SmallVector<triton::DotOp, 2> dots;
|
||||||
|
for (Operation &op : loop) {
|
||||||
|
if (auto dotOp = dyn_cast<triton::DotOp>(&op)) {
|
||||||
|
dots.push_back(dotOp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Don't know what to do if we have more than 1 dots inside the loop
|
||||||
|
if (dots.size() != 1)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
triton::DotOp dotOp = dots[0];
|
||||||
|
// dot (cvt (load %ptr0)), (cvt (load %ptr1))
|
||||||
|
auto getDefinintLoad = [&](Value v) -> triton::LoadOp {
|
||||||
|
auto cvt = v.getDefiningOp<triton::gpu::ConvertLayoutOp>();
|
||||||
|
if (cvt) {
|
||||||
|
return cvt.src().getDefiningOp<triton::LoadOp>();
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
};
|
||||||
|
auto aLoad = getDefinintLoad(dotOp.a());
|
||||||
|
auto bLoad = getDefinintLoad(dotOp.b());
|
||||||
|
|
||||||
|
// ptrs must be block args (phi nodes)
|
||||||
|
if (aLoad && bLoad) {
|
||||||
|
if (aLoad.ptr().isa<BlockArgument>() && bLoad.ptr().isa<BlockArgument>()) {
|
||||||
|
info.dotOp = dotOp; info.aLoadOp = aLoad; info.bLoadOp = bLoad;
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
void LoopPipeliner::emitPrologue() {
|
||||||
|
OpBuilder builder(forOp);
|
||||||
|
//
|
||||||
|
}
|
||||||
|
|
||||||
|
// ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
|
||||||
struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
|
struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
getOperation()->walk([&](scf::ForOp forOp) {
|
// TODO: collect numStages from ModuleOp
|
||||||
|
int numStages = 2;
|
||||||
|
|
||||||
|
if (numStages <= 1)
|
||||||
|
return;
|
||||||
|
|
||||||
|
getOperation()->walk([&](scf::ForOp forOp) -> void {
|
||||||
|
LoopPipeliner pipeliner(forOp, numStages);
|
||||||
|
|
||||||
|
if (pipeliner.initialize().failed())
|
||||||
|
return;
|
||||||
|
|
||||||
|
llvm::errs() << "candidate for pipelining: " << pipeliner.info.dotOp
|
||||||
|
<< "\n";
|
||||||
|
|
||||||
|
// pipeliner.emitPrologue();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@@ -16,6 +16,7 @@
|
|||||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||||
#include "triton/Dialect/Triton/IR/Types.h"
|
#include "triton/Dialect/Triton/IR/Types.h"
|
||||||
#include "triton/Dialect/Triton/Transforms/Passes.h"
|
#include "triton/Dialect/Triton/Transforms/Passes.h"
|
||||||
|
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||||
|
|
||||||
#include "llvm/IR/Module.h"
|
#include "llvm/IR/Module.h"
|
||||||
#include "llvm/IR/LegacyPassManager.h"
|
#include "llvm/IR/LegacyPassManager.h"
|
||||||
@@ -1340,6 +1341,9 @@ void init_triton_ir(py::module &&m) {
|
|||||||
.def("add_convert_triton_to_tritongpu_pass", [](mlir::PassManager &self) {
|
.def("add_convert_triton_to_tritongpu_pass", [](mlir::PassManager &self) {
|
||||||
self.addPass(mlir::triton::createConvertTritonToTritonGPUPass());
|
self.addPass(mlir::triton::createConvertTritonToTritonGPUPass());
|
||||||
})
|
})
|
||||||
|
.def("add_tritongpu_pipeline_pass", [](mlir::PassManager &self) {
|
||||||
|
self.addPass(mlir::createTritonGPUPipelinePass());
|
||||||
|
})
|
||||||
;
|
;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -1314,6 +1314,7 @@ class JITFunction:
|
|||||||
pm.add_triton_combine_pass()
|
pm.add_triton_combine_pass()
|
||||||
pm.add_canonicalizer_pass()
|
pm.add_canonicalizer_pass()
|
||||||
pm.add_convert_triton_to_tritongpu_pass()
|
pm.add_convert_triton_to_tritongpu_pass()
|
||||||
|
pm.add_tritongpu_pipeline_pass()
|
||||||
pm.run(mod)
|
pm.run(mod)
|
||||||
return mod
|
return mod
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user