the pipeline pass now generates and accepts valid IR

Author: Yan Da
Date:   2022-06-07 19:34:59 +08:00
parent 560e29229b
commit 7b09b5f9e9
4 changed files with 82 additions and 37 deletions

File 1 of 4

@@ -163,16 +163,17 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
       return failure();
     Value a = adaptor.a();
     Value b = adaptor.b();
-    // if (!aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
-    //   Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1);
-    //   auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding);
-    //   a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
-    // }
-    // if (!bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
-    //   Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1);
-    //   auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding);
-    //   b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
-    // }
+    SmallVector<unsigned, 2> order{1, 0};
+    if (!aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
+      Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1, order);
+      auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding);
+      a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
+    }
+    if (!bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
+      Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1, order);
+      auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding);
+      b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
+    }
     auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>(
       op, retType, a, b, adaptor.c(), adaptor.allowTF32()
     );
@@ -323,14 +324,17 @@ void populateSCFPatterns(
 class ConvertTritonToTritonGPU :
     public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
 public:
+  ConvertTritonToTritonGPU(int numWarps) {
+    this->numWarps = numWarps;
+  }
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ModuleOp mod = getOperation();
-    // int numThreads = mod.getAttr();
+    int numThreads = numWarps * 32;
     // type converter
-    TritonGPUTypeConverter typeConverter(context, /*numThreads*/32);
+    TritonGPUTypeConverter typeConverter(context, numThreads);
     TritonGPUConversionTarget target(*context, typeConverter);
     // rewrite patterns
     RewritePatternSet patterns(context);
@@ -350,6 +354,6 @@ public:
 }

 std::unique_ptr<OperationPass<ModuleOp>>
-mlir::triton::createConvertTritonToTritonGPUPass() {
-  return std::make_unique<::ConvertTritonToTritonGPU>();
+mlir::triton::createConvertTritonToTritonGPUPass(int numWarps) {
+  return std::make_unique<::ConvertTritonToTritonGPU>(numWarps);
 }
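
With the new constructor, num-warps flows from the pass option into the type converter instead of the hard-coded 32 threads. A minimal sketch of how a driver could run the updated pass; the PassManager setup, function name, and include path are illustrative assumptions, not part of this commit:

// Illustrative driver, not part of this commit: run the conversion with an
// explicit warp count so numThreads = numWarps * 32 inside the pass.
#include "mlir/Pass/PassManager.h"
// Assumed header path for the pass declaration; adjust to the repo layout.
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"

mlir::LogicalResult lowerToTritonGPU(mlir::ModuleOp module, int numWarps) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::triton::createConvertTritonToTritonGPUPass(numWarps));
  return pm.run(module);  // fails if the converted IR does not verify
}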

File 2 of 4

@@ -276,6 +276,17 @@ static Type getPointeeType(Type type) {
   }
 }

+static LogicalResult verify(CopyAsyncOp op) {
+  Type resType = op.getResult().getType();
+  if (auto tensorType = resType.dyn_cast<RankedTensorType>()) {
+    Attribute encoding = tensorType.getEncoding();
+    if (!encoding || !encoding.isa<TritonGPUSharedEncodingAttr>())
+      return op.emitOpError("copy_async should return a shared memory tensor");
+  } else
+    return op.emitOpError("copy_async should return a tensor");
+  return success();
+}
+
 #define GET_OP_CLASSES
 #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"

File 3 of 4

@@ -32,6 +32,8 @@ class LoopPipeliner {
   /// loads to be pipelined
   SetVector<Value> loads;
+  /// the value that each load will be mapped to (after layout conversion)
+  DenseMap<Value, Value> loadsMapping;
   /// value (in loop) => value at stage N
   DenseMap<Value, SmallVector<Value>> valueMapping;
@@ -139,6 +141,23 @@ LogicalResult LoopPipeliner::initialize() {
         break;
       }
     }
+    // For now, we only pipeline loads that have one convert_layout (to smem) use
+    // TODO: lift this constraint in the future
+    if (isCandiate && loadOp.getResult().hasOneUse()) {
+      isCandiate = false;
+      Operation *use = *loadOp.getResult().getUsers().begin();
+      if (auto convertLayout = llvm::dyn_cast<triton::gpu::ConvertLayoutOp>(use)) {
+        if (auto tensorType = convertLayout.getResult().getType().dyn_cast<RankedTensorType>()) {
+          if (tensorType.getEncoding().isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
+            isCandiate = true;
+            loadsMapping[loadOp] = convertLayout;
+          }
+        }
+      }
+    } else
+      isCandiate = false;
     if (isCandiate)
       loads.insert(loadOp);
   }
@@ -202,7 +221,7 @@ void LoopPipeliner::emitPrologue() {
       // TODO: check if the hardware supports copyasync
       if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
         newOp = builder.create<triton::gpu::CopyAsyncOp>(
-          op->getLoc(), op->getResult(0).getType(),
+          op->getLoc(), loadsMapping[loadOp].getType(),
           loadOp.ptr(), loadOp.mask(), loadOp.other(),
           loadOp.cache(), loadOp.evict(), loadOp.isVolatile()
         );
@@ -237,7 +256,11 @@ void LoopPipeliner::emitPrologue() {
     // update mapping of results
     for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
-      setValueMapping(op->getResult(dstIdx), newOp->getResult(dstIdx), stage);
+      Value originalResult = op->getResult(dstIdx);
+      // copy_async will update the value of its only use
+      if (loads.contains(originalResult))
+        originalResult = loadsMapping[originalResult];
+      setValueMapping(originalResult, newOp->getResult(dstIdx), stage);
       // update mapping for loop-carried values (args)
       for (OpOperand &operand : yieldOp->getOpOperands()) {
         if (operand.get() == op->getResult(dstIdx))
@@ -254,7 +277,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   // order of new args:
   //   (original args),
-  //   for each load result x:
+  //   for each load result (after layout conversion) x:
   //     (x at stage[0, numStages-1))
   //   (depArgs at stage numStages-1)
   //   (iv at stage numStages-1)
@@ -268,7 +291,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   size_t loadIdx = newLoopArgs.size();
   for (Value loadOp : loads)
     for (int i = 0; i < numStages - 1; ++i)
-      newLoopArgs.push_back(valueMapping[loadOp][i]);
+      newLoopArgs.push_back(valueMapping[loadsMapping[loadOp]][i]);
   size_t depArgsBeginIdx = newLoopArgs.size();
   for (BlockArgument depArg : depArgs) {
@@ -295,6 +318,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
     mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
+  // 2.1 clone the loop body, replace original args with args of the new ForOp
   for (Operation &op : forOp.getBody()->without_terminator()) {
     Operation *newOp = builder.clone(op, mapping);
     // update mapping of results
@@ -305,7 +329,9 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   // 3. replace loads with block args (from prologue)
   for (size_t idx = 0; idx < loads.size(); ++idx) {
     Value load = loads[idx];
-    mapping.lookup(load).replaceAllUsesWith(
+    assert(load.hasOneUse() && "we assume that this load has one use (ConvertLayout)");
+    Value loadUse = load.getUsers().begin()->getResult(0);
+    mapping.lookup(loadUse).replaceAllUsesWith(
       newForOp.getRegionIterArgs()[loadIdx + idx*(numStages-1)]);
   }
@@ -351,7 +377,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
         nextMapping.map(mask, newMask);
         // TODO: more elegant way to do this?
         nextOp = builder.create<triton::gpu::CopyAsyncOp>(
-          op->getLoc(), op->getResult(0).getType(),
+          op->getLoc(), loadsMapping[op->getResult(0)].getType(),
           nextMapping.lookupOrDefault(loadOp.ptr()),
          nextMapping.lookupOrDefault(loadOp.mask()),
          nextMapping.lookupOrDefault(loadOp.other()),
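
The iter-arg layout documented in createNewForOp (original args, then numStages-1 pipelined copies per load, then dep args and the induction variable) drives the indexing used in step 3. An exposition-only restatement of that arithmetic; the helper name is not from the commit:

// For exposition only: given the iter-arg layout described above, the copy
// of load `idx` at pipeline stage `stage` lives at this position among the
// new ForOp's region iter args. Step 3 reads stage 0, which is exactly
// loadIdx + idx*(numStages-1) in the code above.
static size_t pipelinedArgIndex(size_t loadIdx, size_t idx,
                                int numStages, int stage) {
  return loadIdx + idx * (numStages - 1) + stage;
}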

File 4 of 4

@@ -2,6 +2,7 @@
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include <algorithm>
+#include "llvm/Support/ErrorHandling.h"

 using namespace mlir;
@@ -22,11 +23,13 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
     int64_t rank = tensorType.getRank();
     int64_t numElements = tensorType.getNumElements();
+    // TODO: we should raise an exception here.
     if (!(numElements >= numThreads)) {
-      llvm::errs() << tensorType << " has " << numElements << " numElements "
-                   << " smaller than numThreads (" << numThreads << ")";
-      assert(false);
+      SmallVector<char> buffer;
+      llvm::raw_svector_ostream os(buffer);
+      os << tensorType << " has " << numElements << " numElements, "
+         << "smaller than numThreads (" << numThreads << ")\n"
+         << "consider using a smaller num-warps\n";
+      llvm::report_fatal_error(os.str());
     }
     assert(numElements % numThreads == 0);
@@ -35,6 +38,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
     // Now we assume:
     //   contiguous = 1, order = 0, 1, 2, ...,
     llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
+    // TODO: compute warpTileSize.
     llvm::SmallVector<unsigned> warpTileSize(rank, 1);
     llvm::SmallVector<unsigned> blockTileSize(rank);
     llvm::SmallVector<unsigned> order(rank);
@@ -93,17 +97,17 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
   });

-  // // We have requirements for the data layouts
-  // addDynamicallyLegalOp<triton::DotOp>([this](triton::DotOp dotOp) -> bool {
-  //   Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding();
-  //   Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding();
-  //   if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
-  //       bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
-  //     return true;
-  //   // TODO: we should delete this
-  //   if (this->typeConverter.isLegal(dotOp))
-  //     return true;
-  //   return false;
-  // });
+  // We have requirements for the data layouts
+  addDynamicallyLegalOp<triton::DotOp>([this](triton::DotOp dotOp) -> bool {
+    Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding();
+    Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding();
+    if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
+        bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
+      return true;
+    // TODO: we should delete this
+    if (this->typeConverter.isLegal(dotOp))
+      return true;
+    return false;
+  });
 }
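
The converter now reports a readable fatal error instead of tripping a bare assert when a tensor has fewer elements than there are threads. An illustrative restatement of the combined feasibility condition; the helper name is hypothetical, and numThreads = numWarps * 32 is taken from ConvertTritonToTritonGPU above:

// Illustrative only, mirroring the two checks above: every thread needs at
// least one element, and the element count must divide evenly across threads.
static bool layoutIsFeasible(int64_t numElements, int numWarps) {
  int64_t numThreads = int64_t(numWarps) * 32;
  return numElements >= numThreads && numElements % numThreads == 0;
}
// e.g. a 64-element tensor fails for num-warps=4 (128 threads) but passes
// for num-warps=2 (64 threads), hence the new "smaller num-warps" hint.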