A more general pipeliner

This commit is contained in:
Yan Da
2022-05-25 21:52:51 +08:00
parent 441fd7c3cc
commit 9308e9c90c
3 changed files with 210 additions and 112 deletions

View File

@@ -163,16 +163,16 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
return failure(); return failure();
Value a = adaptor.a(); Value a = adaptor.a();
Value b = adaptor.b(); Value b = adaptor.b();
if (!aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) { // if (!aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1); // Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1);
auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding); // auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding);
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a); // a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
} // }
if (!bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) { // if (!bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1); // Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1);
auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding); // auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b); // b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
} // }
auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>( auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>(
op, retType, a, b, adaptor.c(), adaptor.allowTF32() op, retType, a, b, adaptor.c(), adaptor.allowTF32()
); );

View File

@@ -2,6 +2,8 @@
#include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/BlockAndValueMapping.h"
#include <llvm-6.0/llvm/Support/ErrorHandling.h>
#include <llvm-6.0/llvm/Support/raw_ostream.h>
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// //
@@ -19,38 +21,43 @@ using namespace mlir;
namespace { namespace {
class LoopPipeliner { class LoopPipeliner {
struct PipelineInfo {
triton::DotOp dotOp;
triton::LoadOp aLoadOp;
triton::LoadOp bLoadOp;
};
/// comments on numStages: /// comments on numStages:
/// [0, numStages-1) are in the prologue /// [0, numStages-1) are in the prologue
/// numStages-1 is appended after the loop body /// numStages-1 is appended after the loop body
int numStages; int numStages;
/// cache forOp we are working on /// cache forOp we are working on
scf::ForOp forOp; scf::ForOp forOp;
/// dot & loads
PipelineInfo info; /// cache YieldOp for this forOp
scf::YieldOp yieldOp;
/// loads to be pipelined
SetVector<Value> loads;
/// value (in loop) => value at stage N /// value (in loop) => value at stage N
DenseMap<Value, SmallVector<Value>> valueMapping; DenseMap<Value, SmallVector<Value>> valueMapping;
/// Block arguments that loads depend on
DenseSet<BlockArgument> depArgs; DenseSet<BlockArgument> depArgs;
/// Operations (inside the loop body) that loads depend on
DenseSet<Operation*> depOps; DenseSet<Operation*> depOps;
void setValueMapping(Value origin, Value newValue, int stage);
/// collect values that v depends on and are defined inside the loop /// collect values that v depends on and are defined inside the loop
void collectDeps(Value v); void collectDeps(Value v, int stages, DenseSet<Value> &deps);
void setValueMapping(Value origin, Value newValue, int stage);
public: public:
LoopPipeliner(scf::ForOp forOp, int numStages) LoopPipeliner(scf::ForOp forOp, int numStages)
: forOp(forOp), numStages(numStages) {} : forOp(forOp), numStages(numStages) {
// cache yieldOp
yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
}
/// Collect loop info. Return success if we can pipeline this loop /// Collect loads to pipeline. Return success if we can pipeline this loop
LogicalResult initialize(); LogicalResult initialize();
/// /// emit pipelined loads (before loop body)
void emitPrologue(); void emitPrologue();
/// create the new ForOp (add new args & insert prefetched ops) /// create the new ForOp (add new args & insert prefetched ops)
@@ -66,71 +73,105 @@ void LoopPipeliner::setValueMapping(Value origin, Value newValue, int stage) {
valueMapping[origin][stage] = newValue; valueMapping[origin][stage] = newValue;
} }
void LoopPipeliner::collectDeps(Value v) { void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
// Loop-invariant value. skip
if (v.getParentRegion() != &forOp.getLoopBody()) if (v.getParentRegion() != &forOp.getLoopBody())
return; return;
// Since we only need to peel the loop numStages-1 times, don't worry about
// dependencies that are too far away
if (stages < 0)
return;
if (auto arg = v.dyn_cast<BlockArgument>()) { if (auto arg = v.dyn_cast<BlockArgument>()) {
if (depArgs.contains(arg)) deps.insert(v);
return;
depArgs.insert(arg);
// we also need to rematerialize this arg
auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
// Note: we have iv as the first arg, so the op idx is arg.getArgNumber()-1 // Note: we have iv as the first arg, so the op idx is arg.getArgNumber()-1
collectDeps(yield->getOperand(arg.getArgNumber() - 1)); collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages-1, deps);
} else { // value } else { // value
Operation *defOp = v.getDefiningOp(); // v might be in deps, but we still need to visit v.
if (depOps.contains(defOp)) // This is because v might depend on values from previous iterations
return; deps.insert(v);
depOps.insert(defOp); for (Value op : v.getDefiningOp()->getOperands())
for (Value op : defOp->getOperands()) collectDeps(op, stages, deps);
collectDeps(op);
} }
} }
/// A load instruction can be pipelined if: /// A load instruction can be pipelined if:
/// - the pointer is a block argument (redefined inside the loop) /// - the load doesn't depend on any other loads (after loop peeling)
/// - the load has only a single use in a dot instruction /// - (?) this load is not a loop-invariant value (we should run LICM before
/// this pass?)
LogicalResult LoopPipeliner::initialize() { LogicalResult LoopPipeliner::initialize() {
Block *loop = forOp.getBody(); Block *loop = forOp.getBody();
// TODO: can we use forOp.walk(...) here? // can we use forOp.walk(...) here?
SmallVector<triton::DotOp, 2> dots; SmallVector<triton::LoadOp, 2> allLoads;
for (Operation &op : *loop) { for (Operation &op : *loop)
if (auto dotOp = dyn_cast<triton::DotOp>(&op)) { if (auto loadOp = dyn_cast<triton::LoadOp>(&op))
dots.push_back(dotOp); allLoads.push_back(loadOp);
}
}
// Don't know what to do if we have more than 1 dots inside the loop // Early stop: no need to continue if there is no load in the loop.
if (dots.size() != 1) if (allLoads.empty())
return failure(); return failure();
triton::DotOp dotOp = dots[0]; // load => values that it depends on
// dot (cvt (load %ptr0)), (cvt (load %ptr1)) DenseMap<Value, DenseSet<Value>> loadDeps;
auto getDefinintLoad = [&](Value v) -> triton::LoadOp { for (triton::LoadOp loadOp : allLoads) {
auto cvt = v.getDefiningOp<triton::gpu::ConvertLayoutOp>(); DenseSet<Value> deps;
if (cvt) { for (Value op : loadOp->getOperands())
return cvt.src().getDefiningOp<triton::LoadOp>(); collectDeps(op, numStages - 1, deps);
} loadDeps[loadOp] = deps;
return nullptr;
};
auto aLoad = getDefinintLoad(dotOp.a());
auto bLoad = getDefinintLoad(dotOp.b());
// ptrs must be block args (phi nodes)
if (aLoad && bLoad) {
if (aLoad.ptr().isa<BlockArgument>() && bLoad.ptr().isa<BlockArgument>()) {
info.dotOp = dotOp; info.aLoadOp = aLoad; info.bLoadOp = bLoad;
collectDeps(dotOp.a());
collectDeps(dotOp.b());
return success();
}
} }
// for (triton::LoadOp loadOp : allLoads) {
// llvm::errs() << loadOp << " depends on: #" << loadDeps[loadOp].size() << " values\n";
// for (Value dep : loadDeps[loadOp])
// llvm::errs() << dep << "\n";
// llvm::errs() << "\n";
// }
// Don't pipeline loads that depend on other loads
// (Because if a load depends on another load, this load needs to wait on the
// other load in the prologue, which is against the point of the pipeline
// pass)
for (triton::LoadOp loadOp : allLoads) {
bool isCandiate = true;
for (triton::LoadOp other : allLoads) {
if (loadDeps[loadOp].contains(other)) {
isCandiate = false;
break;
}
}
if (isCandiate)
loads.insert(loadOp);
}
// we have some loads to pipeline
if (!loads.empty()) {
// update depArgs & depOps
for (Value loadOp : loads) {
for (Value dep : loadDeps[loadOp]) {
// TODO: we should record the stage that the value is depended on
if (auto arg = dep.dyn_cast<BlockArgument>())
depArgs.insert(arg);
else
depOps.insert(dep.getDefiningOp());
}
}
return success();
}
// llvm::errs() << allLoads.size() << " loads inside the loop\n"
// << loads.size() << " loads to be pipelined\n";
return failure(); return failure();
} }
void LoopPipeliner::emitPrologue() { void LoopPipeliner::emitPrologue() {
// llvm::errs() << "to pipeline...\n";
// for (Value load : loads)
// llvm::errs() << load << "\n";
// TODO: should we use rewriter here? // TODO: should we use rewriter here?
OpBuilder builder(forOp); OpBuilder builder(forOp);
for (BlockArgument &arg : forOp.getRegionIterArgs()) { for (BlockArgument &arg : forOp.getRegionIterArgs()) {
@@ -139,7 +180,6 @@ void LoopPipeliner::emitPrologue() {
} }
// prologue from [0, numStage-1) // prologue from [0, numStage-1)
auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
Value iv = forOp.getLowerBound(); Value iv = forOp.getLowerBound();
for (int stage = 0; stage < numStages - 1; ++stage) { for (int stage = 0; stage < numStages - 1; ++stage) {
// special handling for induction variable as the increment is implicit // special handling for induction variable as the increment is implicit
@@ -153,12 +193,30 @@ void LoopPipeliner::emitPrologue() {
// rematerialize peeled values // rematerialize peeled values
SmallVector<Operation*> orderedDeps; SmallVector<Operation*> orderedDeps;
for (Operation &op : forOp.getLoopBody().front()) for (Operation &op : forOp.getLoopBody().front()) {
if (depOps.contains(&op)) if (depOps.contains(&op))
orderedDeps.push_back(&op); orderedDeps.push_back(&op);
assert(depOps.size() == orderedDeps.size() && "depOps contains invalid values"); else if (loads.contains(op.getResult(0)))
orderedDeps.push_back(&op);
}
assert(depOps.size() + loads.size() == orderedDeps.size() &&
"depOps contains invalid values");
for (Operation *op : orderedDeps) { for (Operation *op : orderedDeps) {
Operation *newOp = builder.clone(*op); Operation *newOp = nullptr;
if (loads.contains(op->getResult(0))) {
// load => copy async
// TODO: check if the hardware supports copyasync
if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
newOp = builder.create<triton::gpu::CopyAsyncOp>(
op->getLoc(), op->getResult(0).getType(),
loadOp.ptr(), loadOp.mask(), loadOp.other(),
loadOp.cache(), loadOp.evict(), loadOp.isVolatile()
);
} else
llvm_unreachable("This should be LoadOp");
} else
newOp = builder.clone(*op);
// llvm::errs() << "cloning " << *op << "\n";
for (unsigned opIdx = 0; opIdx < op->getNumOperands(); ++opIdx) { for (unsigned opIdx = 0; opIdx < op->getNumOperands(); ++opIdx) {
auto it = valueMapping.find(op->getOperand(opIdx)); auto it = valueMapping.find(op->getOperand(opIdx));
if (it != valueMapping.end()) { if (it != valueMapping.end()) {
@@ -168,11 +226,13 @@ void LoopPipeliner::emitPrologue() {
} // else, op at opIdx is a loop-invariant value } // else, op at opIdx is a loop-invariant value
} }
// TODO: if this is a load, we need to update the mask
// update mapping of results // update mapping of results
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) { for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
setValueMapping(op->getResult(dstIdx), newOp->getResult(dstIdx), stage); setValueMapping(op->getResult(dstIdx), newOp->getResult(dstIdx), stage);
// update mapping for loop-carried values (args) // update mapping for loop-carried values (args)
for (OpOperand &operand : yield->getOpOperands()) { for (OpOperand &operand : yieldOp->getOpOperands()) {
if (operand.get() == op->getResult(dstIdx)) if (operand.get() == op->getResult(dstIdx))
setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
newOp->getResult(dstIdx), stage + 1); newOp->getResult(dstIdx), stage + 1);
@@ -187,7 +247,8 @@ scf::ForOp LoopPipeliner::createNewForOp() {
// order of new args: // order of new args:
// (original args), // (original args),
// (a at stage[0, numStages-1)), (b at stage[0, numStages-1)) // for each load result x:
// (x at stage[0, numStages-1))
// (depArgs at stage numStages-1) // (depArgs at stage numStages-1)
// (iv at stage numStages-1) // (iv at stage numStages-1)
SmallVector<Value> newLoopArgs; SmallVector<Value> newLoopArgs;
@@ -196,54 +257,64 @@ scf::ForOp LoopPipeliner::createNewForOp() {
DenseMap<BlockArgument, size_t> depArgsIdx; DenseMap<BlockArgument, size_t> depArgsIdx;
for (auto v : forOp.getIterOperands()) for (auto v : forOp.getIterOperands())
newLoopArgs.push_back(v); newLoopArgs.push_back(v);
size_t aArgIdx = newLoopArgs.size();
for (int i = 0; i < numStages - 1; ++i) size_t loadIdx = newLoopArgs.size();
newLoopArgs.push_back(valueMapping[info.dotOp.a()][i]); for (Value loadOp : loads)
size_t bArgIdx = newLoopArgs.size(); for (int i = 0; i < numStages - 1; ++i)
for (int i = 0; i < numStages - 1; ++i) newLoopArgs.push_back(valueMapping[loadOp][i]);
newLoopArgs.push_back(valueMapping[info.dotOp.b()][i]);
size_t depArgsBeginIdx = newLoopArgs.size(); size_t depArgsBeginIdx = newLoopArgs.size();
for (BlockArgument depArg : depArgs) { for (BlockArgument depArg : depArgs) {
depArgsIdx[depArg] = newLoopArgs.size(); depArgsIdx[depArg] = newLoopArgs.size();
newLoopArgs.push_back(valueMapping[depArg][numStages-1]); newLoopArgs.push_back(valueMapping[depArg][numStages-1]);
} }
size_t nextIVIdx = newLoopArgs.size(); size_t nextIVIdx = newLoopArgs.size();
newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages-2]); newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages-2]);
for (size_t i = 0; i < newLoopArgs.size(); ++i) for (size_t i = 0; i < newLoopArgs.size(); ++i)
assert(newLoopArgs[i]); assert(newLoopArgs[i]);
// signature of the new ForOp // llvm::errs() << "mapped load is:\n" << newLoopArgs[loadIdx] << "\n\n";
// 1. signature of the new ForOp
auto newForOp = builder.create<scf::ForOp>(forOp.getLoc(), auto newForOp = builder.create<scf::ForOp>(forOp.getLoc(),
forOp.getLowerBound(), forOp.getLowerBound(),
forOp.getUpperBound(), forOp.getUpperBound(),
forOp.getStep(), forOp.getStep(),
newLoopArgs); newLoopArgs);
// body of the new ForOp // 2. body of the new ForOp
builder.setInsertionPointToStart(newForOp.getBody()); builder.setInsertionPointToStart(newForOp.getBody());
BlockAndValueMapping mapping; BlockAndValueMapping mapping;
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
// mapping.map(info.dotOp.a(), newForOp.getRegionIterArgs()[aArgIdx]);
// mapping.map(info.dotOp.b(), newForOp.getRegionIterArgs()[bArgIdx]);
for (Operation &op : forOp.getBody()->without_terminator()) { for (Operation &op : forOp.getBody()->without_terminator()) {
Operation *newOp = builder.clone(op, mapping); Operation *newOp = builder.clone(op, mapping);
// update mapping of results // update mapping of results
for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults())) for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults()))
mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx)); mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx));
// TODO: why doesn't mapping work?
if (&op == info.dotOp.getOperation()) {
newOp->setOperand(0, newForOp.getRegionIterArgs()[aArgIdx]);
newOp->setOperand(1, newForOp.getRegionIterArgs()[bArgIdx]);
}
} }
// prefetch next iteration
// 3. replace loads with args
for (size_t idx = 0; idx < loads.size(); ++idx) {
Value load = loads[idx];
mapping.lookup(load).replaceAllUsesWith(
newForOp.getRegionIterArgs()[loadIdx+idx]);
}
// 4. prefetch the next iteration
SmallVector<Operation*> orderedDeps; SmallVector<Operation*> orderedDeps;
for (Operation &op : forOp.getLoopBody().front()) for (Operation &op : forOp.getLoopBody().front()) {
if (depOps.contains(&op)) if (depOps.contains(&op))
orderedDeps.push_back(&op); orderedDeps.push_back(&op);
assert(depOps.size() == orderedDeps.size() && "depOps contains invalid values"); else if (loads.contains(op.getResult(0)))
orderedDeps.push_back(&op);
}
assert(depOps.size() + loads.size() == orderedDeps.size() &&
"depOps contains invalid values");
BlockAndValueMapping nextMapping; BlockAndValueMapping nextMapping;
DenseMap<BlockArgument, Value> depArgsMapping; DenseMap<BlockArgument, Value> depArgsMapping;
size_t argIdx = 0; size_t argIdx = 0;
@@ -259,8 +330,9 @@ scf::ForOp LoopPipeliner::createNewForOp() {
nextIV.getLoc(), arith::CmpIPredicate::slt, nextIV.getLoc(), arith::CmpIPredicate::slt,
nextIV, newForOp.getUpperBound()); nextIV, newForOp.getUpperBound());
for (Operation *op : orderedDeps) { for (Operation *op : orderedDeps) {
Operation *nextOp = nullptr;
// update loading mask // update loading mask
if (op == info.aLoadOp.getOperation() || op == info.bLoadOp.getOperation()) { if (loads.contains(op->getResult(0))) {
auto loadOp = llvm::cast<triton::LoadOp>(op); auto loadOp = llvm::cast<triton::LoadOp>(op);
Value mask = loadOp.mask(); Value mask = loadOp.mask();
Value splatCond = builder.create<triton::BroadcastOp>(mask.getLoc(), Value splatCond = builder.create<triton::BroadcastOp>(mask.getLoc(),
@@ -272,8 +344,18 @@ scf::ForOp LoopPipeliner::createNewForOp() {
// if mask is defined outside the loop, don't update the map more than once // if mask is defined outside the loop, don't update the map more than once
if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask))) if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
nextMapping.map(mask, newMask); nextMapping.map(mask, newMask);
// TODO: more elegant way to do this?
nextOp = builder.create<triton::gpu::CopyAsyncOp>(
op->getLoc(), op->getResult(0).getType(),
nextMapping.lookupOrDefault(loadOp.ptr()),
nextMapping.lookupOrDefault(loadOp.mask()),
nextMapping.lookupOrDefault(loadOp.other()),
loadOp.cache(), loadOp.evict(), loadOp.isVolatile()
);
} }
Operation *nextOp = builder.clone(*op, nextMapping); else
nextOp = builder.clone(*op, nextMapping);
// llvm::errs() << "epilogue cloning...: " << *op << "\n";
// update mapping of results // update mapping of results
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) { for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx)); nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx));
@@ -294,12 +376,22 @@ scf::ForOp LoopPipeliner::createNewForOp() {
SmallVector<Value> yieldValues; SmallVector<Value> yieldValues;
for (Value v : forOp.getBody()->getTerminator()->getOperands()) for (Value v : forOp.getBody()->getTerminator()->getOperands())
yieldValues.push_back(mapping.lookup(v)); yieldValues.push_back(mapping.lookup(v));
for (int i = 1; i < numStages - 1; ++i) // for (int i = 1; i < numStages - 1; ++i)
yieldValues.push_back(newForOp.getRegionIterArgs()[aArgIdx + i]); // yieldValues.push_back(newForOp.getRegionIterArgs()[aArgIdx + i]);
yieldValues.push_back(nextMapping.lookup(info.dotOp.a())); // yieldValues.push_back(nextMapping.lookup(info.dotOp.a()));
for (int i = 1; i < numStages - 1; ++i) // for (int i = 1; i < numStages - 1; ++i)
yieldValues.push_back(newForOp.getRegionIterArgs()[bArgIdx + i]); // yieldValues.push_back(newForOp.getRegionIterArgs()[bArgIdx + i]);
yieldValues.push_back(nextMapping.lookup(info.dotOp.b())); // yieldValues.push_back(nextMapping.lookup(info.dotOp.b()));
for (size_t idx = 0; idx < loads.size(); ++idx) {
Value load = loads[idx];
for (int stage = 1; stage < numStages - 1; ++stage) {
yieldValues.push_back(newForOp.getRegionIterArgs()[
loadIdx + idx*(numStages-1) + stage-1
]);
}
yieldValues.push_back(nextMapping.lookup(load));
}
for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i) for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i)
yieldValues.push_back(depArgsMapping.lookup(newForOp.getRegionIterArgs()[i])); yieldValues.push_back(depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
yieldValues.push_back(nextIV); yieldValues.push_back(nextIV);
@@ -328,10 +420,16 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
if (pipeliner.initialize().failed()) if (pipeliner.initialize().failed())
return; return;
// llvm::errs() << "find a loop to pipeline...\n";
pipeliner.emitPrologue(); pipeliner.emitPrologue();
// llvm::errs() << "\nprologue emitted\n"
// << *forOp->getParentOp();
scf::ForOp newForOp = pipeliner.createNewForOp(); scf::ForOp newForOp = pipeliner.createNewForOp();
// llvm::errs() << "new for created:\n" << newForOp << "\n"
// << "inside:\n" << *newForOp->getParentOp() << "\n";
// replace the original loop // replace the original loop
for (unsigned i = 0; i < forOp->getNumResults(); ++i) for (unsigned i = 0; i < forOp->getNumResults(); ++i)
forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i)); forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i));

View File

@@ -92,17 +92,17 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
}); });
// We have requirements for the data layouts // // We have requirements for the data layouts
addDynamicallyLegalOp<triton::DotOp>([this](triton::DotOp dotOp) -> bool { // addDynamicallyLegalOp<triton::DotOp>([this](triton::DotOp dotOp) -> bool {
Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding(); // Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding();
Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding(); // Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding();
if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() && // if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) // bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
return true; // return true;
// TODO: we should delete this // // TODO: we should delete this
if (this->typeConverter.isLegal(dotOp)) // if (this->typeConverter.isLegal(dotOp))
return true; // return true;
return false; // return false;
}); // });
} }