diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index dc720a0d0..2171fe058 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -163,16 +163,16 @@ struct TritonDotPattern : public OpConversionPattern { return failure(); Value a = adaptor.a(); Value b = adaptor.b(); - if (!aEncoding.isa()) { - Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1); - auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding); - a = rewriter.create(a.getLoc(), dstType, a); - } - if (!bEncoding.isa()) { - Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1); - auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding); - b = rewriter.create(b.getLoc(), dstType, b); - } + // if (!aEncoding.isa()) { + // Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1); + // auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding); + // a = rewriter.create(a.getLoc(), dstType, a); + // } + // if (!bEncoding.isa()) { + // Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1); + // auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding); + // b = rewriter.create(b.getLoc(), dstType, b); + // } auto newDot = rewriter.replaceOpWithNewOp( op, retType, a, b, adaptor.c(), adaptor.allowTF32() ); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index 33217c200..5577b0742 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -2,6 +2,8 @@ #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "mlir/IR/BlockAndValueMapping.h" +#include +#include //===----------------------------------------------------------------------===// // @@ -19,38 +21,43 @@ using namespace mlir; namespace { class LoopPipeliner { - struct PipelineInfo { - triton::DotOp dotOp; - triton::LoadOp aLoadOp; - triton::LoadOp bLoadOp; - }; - /// comments on numStages: /// [0, numStages-1) are in the prologue /// numStages-1 is appended after the loop body int numStages; + /// cache forOp we are working on scf::ForOp forOp; - /// dot & loads - PipelineInfo info; + + /// cahce YieldOp for this forOp + scf::YieldOp yieldOp; + + /// loads to be pipelined + SetVector loads; + /// value (in loop) => value at stage N DenseMap> valueMapping; + /// Block arguments that loads depend on DenseSet depArgs; + /// Operations (inside the loop body) that loads depend on DenseSet depOps; - void setValueMapping(Value origin, Value newValue, int stage); - /// collect values that v depends on and are defined inside the loop - void collectDeps(Value v); + void collectDeps(Value v, int stages, DenseSet &deps); + + void setValueMapping(Value origin, Value newValue, int stage); public: LoopPipeliner(scf::ForOp forOp, int numStages) - : forOp(forOp), numStages(numStages) {} + : forOp(forOp), numStages(numStages) { + // cache yieldOp + yieldOp = cast(forOp.getBody()->getTerminator()); + } - /// Collect loop info. Return success if we can pipeline this loop + /// Collect loads to pipeline. Return success if we can pipeline this loop LogicalResult initialize(); - /// + /// emit pipelined loads (before loop body) void emitPrologue(); /// create the new ForOp (add new args & insert prefetched ops) @@ -66,71 +73,105 @@ void LoopPipeliner::setValueMapping(Value origin, Value newValue, int stage) { valueMapping[origin][stage] = newValue; } -void LoopPipeliner::collectDeps(Value v) { +void LoopPipeliner::collectDeps(Value v, int stages, DenseSet &deps) { + // Loop-invarant value. skip if (v.getParentRegion() != &forOp.getLoopBody()) return; + + // Since we only need to peel the loop numStages-1 times, don't worry about + // depends that are too far away + if (stages < 0) + return; + if (auto arg = v.dyn_cast()) { - if (depArgs.contains(arg)) - return; - depArgs.insert(arg); - // we also need to rematerialize this arg - auto yield = cast(forOp.getBody()->getTerminator()); + deps.insert(v); // Note: we have iv as the first arg, so the op idx is arg.getArgNumber()-1 - collectDeps(yield->getOperand(arg.getArgNumber() - 1)); + collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages-1, deps); } else { // value - Operation *defOp = v.getDefiningOp(); - if (depOps.contains(defOp)) - return; - depOps.insert(defOp); - for (Value op : defOp->getOperands()) - collectDeps(op); + // v might be in deps, but we still need to visit v. + // This is because v might depends on value in previous iterations + deps.insert(v); + for (Value op : v.getDefiningOp()->getOperands()) + collectDeps(op, stages, deps); } } /// A load instruction can be pipelined if: -/// - the pointer is a block argument (redefined inside the loop) -/// - the load has only a single use in a dot instruction +/// - the load doesn't depend on any other loads (after loop peeling) +/// - (?) this load is not a loop-invariant value (we should run LICM before +/// this pass?) LogicalResult LoopPipeliner::initialize() { Block *loop = forOp.getBody(); - // TODO: can we use forOp.walk(...) here? - SmallVector dots; - for (Operation &op : *loop) { - if (auto dotOp = dyn_cast(&op)) { - dots.push_back(dotOp); - } - } + // can we use forOp.walk(...) here? + SmallVector allLoads; + for (Operation &op : *loop) + if (auto loadOp = dyn_cast(&op)) + allLoads.push_back(loadOp); - // Don't know what to do if we have more than 1 dots inside the loop - if (dots.size() != 1) + // Early stop: no need to continue if there is no load in the loop. + if (allLoads.empty()) return failure(); - triton::DotOp dotOp = dots[0]; - // dot (cvt (load %ptr0)), (cvt (load %ptr1)) - auto getDefinintLoad = [&](Value v) -> triton::LoadOp { - auto cvt = v.getDefiningOp(); - if (cvt) { - return cvt.src().getDefiningOp(); - } - return nullptr; - }; - auto aLoad = getDefinintLoad(dotOp.a()); - auto bLoad = getDefinintLoad(dotOp.b()); - - // ptrs must be block args (phi nodes) - if (aLoad && bLoad) { - if (aLoad.ptr().isa() && bLoad.ptr().isa()) { - info.dotOp = dotOp; info.aLoadOp = aLoad; info.bLoadOp = bLoad; - collectDeps(dotOp.a()); - collectDeps(dotOp.b()); - return success(); - } + // load => values that it depends on + DenseMap> loadDeps; + for (triton::LoadOp loadOp : allLoads) { + DenseSet deps; + for (Value op : loadOp->getOperands()) + collectDeps(op, numStages - 1, deps); + loadDeps[loadOp] = deps; } + // for (triton::LoadOp loadOp : allLoads) { + // llvm::errs() << loadOp << " depends on: #" << loadDeps[loadOp].size() << " values\n"; + // for (Value dep : loadDeps[loadOp]) + // llvm::errs() << dep << "\n"; + // llvm::errs() << "\n"; + // } + + // Don't pipeline loads that depend on other loads + // (Because if a load depends on another load, this load needs to wait on the + // other load in the prologue, which is against the point of the pipeline + // pass) + for (triton::LoadOp loadOp : allLoads) { + bool isCandiate = true; + for (triton::LoadOp other : allLoads) { + if (loadDeps[loadOp].contains(other)) { + isCandiate = false; + break; + } + } + if (isCandiate) + loads.insert(loadOp); + } + + + // we have some loads to pipeline + if (!loads.empty()) { + // update depArgs & depOps + for (Value loadOp : loads) { + for (Value dep : loadDeps[loadOp]) { + // TODO: we should record the stage that the value is depended on + if (auto arg = dep.dyn_cast()) + depArgs.insert(arg); + else + depOps.insert(dep.getDefiningOp()); + } + } + return success(); + } + + // llvm::errs() << allLoads.size() << " loads inside the loop\n" + // << loads.size() << " loads to be pipelined\n"; + return failure(); } void LoopPipeliner::emitPrologue() { + // llvm::errs() << "to pipeline...\n"; + // for (Value load : loads) + // llvm::errs() << load << "\n"; + // TODO: should we use rewriter here? OpBuilder builder(forOp); for (BlockArgument &arg : forOp.getRegionIterArgs()) { @@ -139,7 +180,6 @@ void LoopPipeliner::emitPrologue() { } // prologue from [0, numStage-1) - auto yield = cast(forOp.getBody()->getTerminator()); Value iv = forOp.getLowerBound(); for (int stage = 0; stage < numStages - 1; ++stage) { // special handling for induction variable as the increment is implicit @@ -153,12 +193,30 @@ void LoopPipeliner::emitPrologue() { // rematerialize peeled values SmallVector orderedDeps; - for (Operation &op : forOp.getLoopBody().front()) + for (Operation &op : forOp.getLoopBody().front()) { if (depOps.contains(&op)) orderedDeps.push_back(&op); - assert(depOps.size() == orderedDeps.size() && "depOps contains invalid values"); + else if (loads.contains(op.getResult(0))) + orderedDeps.push_back(&op); + } + assert(depOps.size() + loads.size() == orderedDeps.size() && + "depOps contains invalid values"); for (Operation *op : orderedDeps) { - Operation *newOp = builder.clone(*op); + Operation *newOp = nullptr; + if (loads.contains(op->getResult(0))) { + // load => copy async + // TODO: check if the hardware supports copyasync + if (auto loadOp = llvm::dyn_cast(op)) { + newOp = builder.create( + op->getLoc(), op->getResult(0).getType(), + loadOp.ptr(), loadOp.mask(), loadOp.other(), + loadOp.cache(), loadOp.evict(), loadOp.isVolatile() + ); + } else + llvm_unreachable("This should be LoadOp"); + } else + newOp = builder.clone(*op); + // llvm::errs() << "cloning " << *op << "\n"; for (unsigned opIdx = 0; opIdx < op->getNumOperands(); ++opIdx) { auto it = valueMapping.find(op->getOperand(opIdx)); if (it != valueMapping.end()) { @@ -168,11 +226,13 @@ void LoopPipeliner::emitPrologue() { } // else, op at opIdx is a loop-invariant value } + // TODO: if this is a load, we need to update the mask + // update mapping of results for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) { setValueMapping(op->getResult(dstIdx), newOp->getResult(dstIdx), stage); // update mapping for loop-carried values (args) - for (OpOperand &operand : yield->getOpOperands()) { + for (OpOperand &operand : yieldOp->getOpOperands()) { if (operand.get() == op->getResult(dstIdx)) setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], newOp->getResult(dstIdx), stage + 1); @@ -187,7 +247,8 @@ scf::ForOp LoopPipeliner::createNewForOp() { // order of new args: // (original args), - // (a at stage[0, numStages-1)), (b at stage[0, numStages-1)) + // for each load result x: + // (x at stage[0, numStages-1)) // (depArgs at stage numStages-1) // (iv at stage numStages-1) SmallVector newLoopArgs; @@ -196,54 +257,64 @@ scf::ForOp LoopPipeliner::createNewForOp() { DenseMap depArgsIdx; for (auto v : forOp.getIterOperands()) newLoopArgs.push_back(v); - size_t aArgIdx = newLoopArgs.size(); - for (int i = 0; i < numStages - 1; ++i) - newLoopArgs.push_back(valueMapping[info.dotOp.a()][i]); - size_t bArgIdx = newLoopArgs.size(); - for (int i = 0; i < numStages - 1; ++i) - newLoopArgs.push_back(valueMapping[info.dotOp.b()][i]); + + size_t loadIdx = newLoopArgs.size(); + for (Value loadOp : loads) + for (int i = 0; i < numStages - 1; ++i) + newLoopArgs.push_back(valueMapping[loadOp][i]); + size_t depArgsBeginIdx = newLoopArgs.size(); for (BlockArgument depArg : depArgs) { depArgsIdx[depArg] = newLoopArgs.size(); newLoopArgs.push_back(valueMapping[depArg][numStages-1]); } + size_t nextIVIdx = newLoopArgs.size(); newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages-2]); for (size_t i = 0; i < newLoopArgs.size(); ++i) assert(newLoopArgs[i]); - // signature of the new ForOp + // llvm::errs() << "mapped load is:\n" << newLoopArgs[loadIdx] << "\n\n"; + + // 1. signature of the new ForOp auto newForOp = builder.create(forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), newLoopArgs); - // body of the new ForOp + // 2. body of the new ForOp builder.setInsertionPointToStart(newForOp.getBody()); BlockAndValueMapping mapping; for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); - // mapping.map(info.dotOp.a(), newForOp.getRegionIterArgs()[aArgIdx]); - // mapping.map(info.dotOp.b(), newForOp.getRegionIterArgs()[bArgIdx]); + for (Operation &op : forOp.getBody()->without_terminator()) { Operation *newOp = builder.clone(op, mapping); // update mapping of results for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults())) mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx)); - // TODO: why doesn't mapping work? - if (&op == info.dotOp.getOperation()) { - newOp->setOperand(0, newForOp.getRegionIterArgs()[aArgIdx]); - newOp->setOperand(1, newForOp.getRegionIterArgs()[bArgIdx]); - } } - // prefetch next iteration + + // 3. replace loads with args + for (size_t idx = 0; idx < loads.size(); ++idx) { + Value load = loads[idx]; + mapping.lookup(load).replaceAllUsesWith( + newForOp.getRegionIterArgs()[loadIdx+idx]); + } + + + // 4. prefetch the next iteration SmallVector orderedDeps; - for (Operation &op : forOp.getLoopBody().front()) + for (Operation &op : forOp.getLoopBody().front()) { if (depOps.contains(&op)) orderedDeps.push_back(&op); - assert(depOps.size() == orderedDeps.size() && "depOps contains invalid values"); + else if (loads.contains(op.getResult(0))) + orderedDeps.push_back(&op); + } + assert(depOps.size() + loads.size() == orderedDeps.size() && + "depOps contains invalid values"); BlockAndValueMapping nextMapping; DenseMap depArgsMapping; size_t argIdx = 0; @@ -259,8 +330,9 @@ scf::ForOp LoopPipeliner::createNewForOp() { nextIV.getLoc(), arith::CmpIPredicate::slt, nextIV, newForOp.getUpperBound()); for (Operation *op : orderedDeps) { + Operation *nextOp = nullptr; // update loading mask - if (op == info.aLoadOp.getOperation() || op == info.bLoadOp.getOperation()) { + if (loads.contains(op->getResult(0))) { auto loadOp = llvm::cast(op); Value mask = loadOp.mask(); Value splatCond = builder.create(mask.getLoc(), @@ -272,8 +344,18 @@ scf::ForOp LoopPipeliner::createNewForOp() { // if mask is defined outside the loop, don't update the map more than once if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask))) nextMapping.map(mask, newMask); + // TODO: more elegant way to do this? + nextOp = builder.create( + op->getLoc(), op->getResult(0).getType(), + nextMapping.lookupOrDefault(loadOp.ptr()), + nextMapping.lookupOrDefault(loadOp.mask()), + nextMapping.lookupOrDefault(loadOp.other()), + loadOp.cache(), loadOp.evict(), loadOp.isVolatile() + ); } - Operation *nextOp = builder.clone(*op, nextMapping); + else + nextOp = builder.clone(*op, nextMapping); + // llvm::errs() << "epilogue cloning...: " << *op << "\n"; // update mapping of results for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) { nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx)); @@ -294,12 +376,22 @@ scf::ForOp LoopPipeliner::createNewForOp() { SmallVector yieldValues; for (Value v : forOp.getBody()->getTerminator()->getOperands()) yieldValues.push_back(mapping.lookup(v)); - for (int i = 1; i < numStages - 1; ++i) - yieldValues.push_back(newForOp.getRegionIterArgs()[aArgIdx + i]); - yieldValues.push_back(nextMapping.lookup(info.dotOp.a())); - for (int i = 1; i < numStages - 1; ++i) - yieldValues.push_back(newForOp.getRegionIterArgs()[bArgIdx + i]); - yieldValues.push_back(nextMapping.lookup(info.dotOp.b())); + // for (int i = 1; i < numStages - 1; ++i) + // yieldValues.push_back(newForOp.getRegionIterArgs()[aArgIdx + i]); + // yieldValues.push_back(nextMapping.lookup(info.dotOp.a())); + // for (int i = 1; i < numStages - 1; ++i) + // yieldValues.push_back(newForOp.getRegionIterArgs()[bArgIdx + i]); + // yieldValues.push_back(nextMapping.lookup(info.dotOp.b())); + for (size_t idx = 0; idx < loads.size(); ++idx) { + Value load = loads[idx]; + for (int stage = 1; stage < numStages - 1; ++stage) { + yieldValues.push_back(newForOp.getRegionIterArgs()[ + loadIdx + idx*(numStages-1) + stage-1 + ]); + } + yieldValues.push_back(nextMapping.lookup(load)); + } + for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i) yieldValues.push_back(depArgsMapping.lookup(newForOp.getRegionIterArgs()[i])); yieldValues.push_back(nextIV); @@ -328,10 +420,16 @@ struct PipelinePass : public TritonGPUPipelineBase { if (pipeliner.initialize().failed()) return; + // llvm::errs() << "find a loop to pipeline...\n"; pipeliner.emitPrologue(); + // llvm::errs() << "\nprologue emitted\n" + // << *forOp->getParentOp(); scf::ForOp newForOp = pipeliner.createNewForOp(); + // llvm::errs() << "new for created:\n" << newForOp << "\n" + // << "inside:\n" << *newForOp->getParentOp() << "\n"; + // replace the original loop for (unsigned i = 0; i < forOp->getNumResults(); ++i) forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i)); diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp index b1de660f6..8ae549ca3 100644 --- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp +++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp @@ -92,17 +92,17 @@ TritonGPUConversionTarget::TritonGPUConversionTarget( }); - // We have requirements for the data layouts - addDynamicallyLegalOp([this](triton::DotOp dotOp) -> bool { - Attribute aEncoding = dotOp.a().getType().cast().getEncoding(); - Attribute bEncoding = dotOp.b().getType().cast().getEncoding(); - if (aEncoding && aEncoding.isa() && - bEncoding && bEncoding.isa()) - return true; - // TODO: we should delete this - if (this->typeConverter.isLegal(dotOp)) - return true; - return false; - }); + // // We have requirements for the data layouts + // addDynamicallyLegalOp([this](triton::DotOp dotOp) -> bool { + // Attribute aEncoding = dotOp.a().getType().cast().getEncoding(); + // Attribute bEncoding = dotOp.b().getType().cast().getEncoding(); + // if (aEncoding && aEncoding.isa() && + // bEncoding && bEncoding.isa()) + // return true; + // // TODO: we should delete this + // if (this->typeConverter.isLegal(dotOp)) + // return true; + // return false; + // }); }