From 7b09b5f9e95e4f43e42fd167bd492e04736f2653 Mon Sep 17 00:00:00 2001
From: Yan Da <dyanab@connect.ust.hk>
Date: Tue, 7 Jun 2022 19:34:59 +0800
Subject: [PATCH] the pipeline pass now generates and accepts valid IR

---
 .../TritonToTritonGPU/TritonToTritonGPU.cpp   | 34 +++++++++--------
 lib/Dialect/TritonGPU/IR/Dialect.cpp          | 11 ++++++
 lib/Dialect/TritonGPU/Transforms/Pipeline.cpp | 38 ++++++++++++++++---
 .../Transforms/TritonGPUConversion.cpp        | 36 ++++++++++--------
 4 files changed, 82 insertions(+), 37 deletions(-)

diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
index 2171fe058..ee35a6eb8 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
@@ -163,16 +163,17 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
       return failure();
     Value a = adaptor.a();
     Value b = adaptor.b();
-    // if (!aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
-    //   Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1);
-    //   auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding);
-    //   a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
-    // }
-    // if (!bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
-    //   Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1);
-    //   auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding);
-    //   b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
-    // }
+    SmallVector<unsigned> order{1, 0};
+    if (!aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
+      Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1, order);
+      auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding);
+      a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
+    }
+    if (!bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
+      Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1, order);
+      auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding);
+      b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
+    }
     auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>(
       op, retType, a, b, adaptor.c(), adaptor.allowTF32()
     );
@@ -323,14 +324,17 @@ void populateSCFPatterns(
 
 class ConvertTritonToTritonGPU :
     public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
-
 public:
+  ConvertTritonToTritonGPU(int numWarps) {
+    this->numWarps = numWarps;
+  }
+
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ModuleOp mod = getOperation();
-    // int numThreads = mod.getAttr();
+    int numThreads = numWarps * 32;
     // type converter
-    TritonGPUTypeConverter typeConverter(context, /*numThreads*/32);
+    TritonGPUTypeConverter typeConverter(context, numThreads);
    TritonGPUConversionTarget target(*context, typeConverter);
     // rewrite patterns
     RewritePatternSet patterns(context);
@@ -350,6 +354,6 @@ public:
 }
 
 std::unique_ptr<OperationPass<ModuleOp>>
-mlir::triton::createConvertTritonToTritonGPUPass() {
-  return std::make_unique<::ConvertTritonToTritonGPU>();
+mlir::triton::createConvertTritonToTritonGPUPass(int numWarps) {
+  return std::make_unique<::ConvertTritonToTritonGPU>(numWarps);
 }
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 733e54e86..4380524fe 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -276,6 +276,17 @@ static Type getPointeeType(Type type) {
   }
 }
 
+static LogicalResult verify(CopyAsyncOp op) {
+  Type resType = op.getResult().getType();
+  if (auto tensorType = resType.dyn_cast<RankedTensorType>()) {
+    Attribute encoding = tensorType.getEncoding();
+    if (!encoding.isa<TritonGPUSharedEncodingAttr>())
+      return op.emitOpError("copy_async should return a shared memory tensor");
+  } else
+    return op.emitOpError("copy_async should return a tensor");
+  return success();
+}
+
 #define GET_OP_CLASSES
 #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
index ceb7bf4d6..b68276678 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
@@ -32,6 +32,8 @@ class LoopPipeliner {
   /// loads to be pipelined
   SetVector<Value> loads;
+  /// the value that each load will be mapped to (after layout conversion)
+  DenseMap<Value, Value> loadsMapping;
   /// value (in loop) => value at stage N
   DenseMap<Value, SmallVector<Value>> valueMapping;
@@ -139,6 +141,23 @@ LogicalResult LoopPipeliner::initialize() {
         break;
       }
     }
+
+    // For now, we only pipeline loads that have one convert_layout (to smem) use
+    // TODO: lift this constraint in the future
+    if (isCandiate && loadOp.getResult().hasOneUse()) {
+      isCandiate = false;
+      Operation *use = *loadOp.getResult().getUsers().begin();
+      if (auto convertLayout = llvm::dyn_cast<triton::gpu::ConvertLayoutOp>(use)) {
+        if (auto tensorType = convertLayout.getResult().getType().dyn_cast<RankedTensorType>()) {
+          if (tensorType.getEncoding().isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
+            isCandiate = true;
+            loadsMapping[loadOp] = convertLayout;
+          }
+        }
+      }
+    } else
+      isCandiate = false;
+
     if (isCandiate)
       loads.insert(loadOp);
   }
@@ -202,7 +221,7 @@ void LoopPipeliner::emitPrologue() {
       // TODO: check if the hardware supports copyasync
       if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
         newOp = builder.create<triton::gpu::CopyAsyncOp>(
-          op->getLoc(), op->getResult(0).getType(),
+          op->getLoc(), loadsMapping[loadOp].getType(),
           loadOp.ptr(), loadOp.mask(), loadOp.other(),
           loadOp.cache(), loadOp.evict(), loadOp.isVolatile()
         );
@@ -237,7 +256,11 @@ void LoopPipeliner::emitPrologue() {
 
       // update mapping of results
       for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
-        setValueMapping(op->getResult(dstIdx), newOp->getResult(dstIdx), stage);
+        Value originalResult = op->getResult(dstIdx);
+        // copy_async will update the value of its only use
+        if (loads.contains(originalResult))
+          originalResult = loadsMapping[originalResult];
+        setValueMapping(originalResult, newOp->getResult(dstIdx), stage);
         // update mapping for loop-carried values (args)
         for (OpOperand &operand : yieldOp->getOpOperands()) {
           if (operand.get() == op->getResult(dstIdx))
@@ -254,7 +277,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
 
   // order of new args:
   //   (original args),
-  //   for each load result x:
+  //   for each load result (after layout conversion) x:
   //     (x at stage[0, numStages-1))
   //   (depArgs at stage numStages-1)
   //   (iv at stage numStages-1)
@@ -268,7 +291,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   size_t loadIdx = newLoopArgs.size();
   for (Value loadOp : loads)
     for (int i = 0; i < numStages - 1; ++i)
-      newLoopArgs.push_back(valueMapping[loadOp][i]);
+      newLoopArgs.push_back(valueMapping[loadsMapping[loadOp]][i]);
 
   size_t depArgsBeginIdx = newLoopArgs.size();
   for (BlockArgument depArg : depArgs) {
@@ -295,6 +318,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
     mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
 
+  // 2.1 clone the loop body, replace original args with args of the new ForOp
   for (Operation &op : forOp.getBody()->without_terminator()) {
     Operation *newOp = builder.clone(op, mapping);
     // update mapping of results
@@ -305,7 +329,9 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   // 3. replace loads with block args (from prologue)
   for (size_t idx = 0; idx < loads.size(); ++idx) {
     Value load = loads[idx];
-    mapping.lookup(load).replaceAllUsesWith(
+    assert(load.hasOneUse() && "we assume that this load has one use (ConvertLayout)");
+    Value loadUse = load.getUsers().begin()->getResult(0);
+    mapping.lookup(loadUse).replaceAllUsesWith(
       newForOp.getRegionIterArgs()[loadIdx + idx*(numStages-1)]);
   }
@@ -351,7 +377,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
         nextMapping.map(mask, newMask);
         // TODO: more elegant way to do this?
         nextOp = builder.create<triton::gpu::CopyAsyncOp>(
-          op->getLoc(), op->getResult(0).getType(),
+          op->getLoc(), loadsMapping[op->getResult(0)].getType(),
           nextMapping.lookupOrDefault(loadOp.ptr()),
           nextMapping.lookupOrDefault(loadOp.mask()),
           nextMapping.lookupOrDefault(loadOp.other()),
diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
index d6eb0c329..e2d6a6687 100644
--- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
@@ -2,6 +2,7 @@
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include <algorithm>
+#include <llvm/Support/ErrorHandling.h>
 
 using namespace mlir;
 
@@ -22,11 +23,13 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
     int64_t rank = tensorType.getRank();
     int64_t numElements = tensorType.getNumElements();
 
-    // TODO: we should raise exception here.
     if (!(numElements >= numThreads)) {
-      llvm::errs() << tensorType << " has " << numElements << " numElements "
-                   << " smaller than numThreads (" << numThreads << ")";
-      assert(false);
+      SmallVector<char> buffer;
+      llvm::raw_svector_ostream os(buffer);
+      os << tensorType << " has " << numElements << " numElements "
+         << " smaller than numThreads (" << numThreads << ")\n"
+         << "consider using smaller num-warps\n";
+      llvm::report_fatal_error(os.str());
     }
     assert(numElements % numThreads == 0);
 
@@ -35,6 +38,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
     // Now we assume:
     //   contiguous = 1, order = 0, 1, 2, ...,
     llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
+    // TODO: compute warpTileSize.
     llvm::SmallVector<unsigned> warpTileSize(rank, 1);
     llvm::SmallVector<unsigned> blockTileSize(rank);
     llvm::SmallVector<unsigned> order(rank);
@@ -93,17 +97,17 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
       });
 
-  // // We have requirements for the data layouts
-  // addDynamicallyLegalOp<triton::DotOp>([this](triton::DotOp dotOp) -> bool {
-  //   Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding();
-  //   Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding();
-  //   if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
-  //       bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
-  //     return true;
-  //   // TODO: we should delete this
-  //   if (this->typeConverter.isLegal(dotOp))
-  //     return true;
-  //   return false;
-  // });
+  // We have requirements for the data layouts
+  addDynamicallyLegalOp<triton::DotOp>([this](triton::DotOp dotOp) -> bool {
+    Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding();
+    Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding();
+    if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
+        bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
+      return true;
+    // TODO: we should delete this
+    if (this->typeConverter.isLegal(dotOp))
+      return true;
+    return false;
+  });
 }
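
A minimal usage sketch of the reworked pass-construction API (for illustration only; the buildTritonGPUPipeline helper and the TritonToTritonGPUPass.h header path are assumptions, not code from this patch):

    // Hypothetical driver snippet: register the conversion pass with an
    // explicit warp count. The type converter derives numThreads as
    // numWarps * 32, so tensors with fewer elements than that now fail
    // with "consider using smaller num-warps" instead of a bare assert.
    #include "mlir/Pass/PassManager.h"
    #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"

    void buildTritonGPUPipeline(mlir::PassManager &pm, int numWarps) {
      pm.addPass(mlir::triton::createConvertTritonToTritonGPUPass(numWarps));
    }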