the pipeline pass now generates and accepts valid IR

This commit is contained in:
Yan Da
2022-06-07 19:34:59 +08:00
parent 560e29229b
commit 7b09b5f9e9
4 changed files with 82 additions and 37 deletions

View File

@@ -163,16 +163,17 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
return failure();
Value a = adaptor.a();
Value b = adaptor.b();
// if (!aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
// Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1);
// auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding);
// a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
// }
// if (!bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
// Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1);
// auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding);
// b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
// }
SmallVector<unsigned, 2> order{1, 0};
if (!aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1, order);
auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding);
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
}
if (!bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1, order);
auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
}
auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>(
op, retType, a, b, adaptor.c(), adaptor.allowTF32()
);
@@ -323,14 +324,17 @@ void populateSCFPatterns(
class ConvertTritonToTritonGPU :
public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
public:
ConvertTritonToTritonGPU(int numWarps) {
this->numWarps = numWarps;
}
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp mod = getOperation();
// int numThreads = mod.getAttr();
int numThreads = numWarps * 32;
// type converter
TritonGPUTypeConverter typeConverter(context, /*numThreads*/32);
TritonGPUTypeConverter typeConverter(context, numThreads);
TritonGPUConversionTarget target(*context, typeConverter);
// rewrite patterns
RewritePatternSet patterns(context);
@@ -350,6 +354,6 @@ public:
}
std::unique_ptr<OperationPass<ModuleOp>>
mlir::triton::createConvertTritonToTritonGPUPass() {
return std::make_unique<::ConvertTritonToTritonGPU>();
mlir::triton::createConvertTritonToTritonGPUPass(int numWarps) {
return std::make_unique<::ConvertTritonToTritonGPU>(numWarps);
}

View File

@@ -276,6 +276,17 @@ static Type getPointeeType(Type type) {
}
}
static LogicalResult verify(CopyAsyncOp op) {
Type resType = op.getResult().getType();
if (auto tensorType = resType.dyn_cast<RankedTensorType>()) {
Attribute encoding = tensorType.getEncoding();
if (!encoding.isa<TritonGPUSharedEncodingAttr>())
return op.emitOpError("copy_async should return a shared memory tensor");
} else
return op.emitOpError("copy_async should return a tensor");
return success();
}
#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"

View File

@@ -32,6 +32,8 @@ class LoopPipeliner {
/// loads to be pipelined
SetVector<Value> loads;
/// the value that each load will be mapped to (after layout conversion)
DenseMap<Value, Value> loadsMapping;
/// value (in loop) => value at stage N
DenseMap<Value, SmallVector<Value>> valueMapping;
@@ -139,6 +141,23 @@ LogicalResult LoopPipeliner::initialize() {
break;
}
}
// For now, we only pipeline loads that have one covert_layout (to smem) use
// TODO: lift this constraint in the future
if (isCandiate && loadOp.getResult().hasOneUse()) {
isCandiate = false;
Operation *use = *loadOp.getResult().getUsers().begin();
if (auto convertLayout = llvm::dyn_cast<triton::gpu::ConvertLayoutOp>(use)) {
if (auto tensorType = convertLayout.getResult().getType().dyn_cast<RankedTensorType>()) {
if (tensorType.getEncoding().isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
isCandiate = true;
loadsMapping[loadOp] = convertLayout;
}
}
}
} else
isCandiate = false;
if (isCandiate)
loads.insert(loadOp);
}
@@ -202,7 +221,7 @@ void LoopPipeliner::emitPrologue() {
// TODO: check if the hardware supports copyasync
if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
newOp = builder.create<triton::gpu::CopyAsyncOp>(
op->getLoc(), op->getResult(0).getType(),
op->getLoc(), loadsMapping[loadOp].getType(),
loadOp.ptr(), loadOp.mask(), loadOp.other(),
loadOp.cache(), loadOp.evict(), loadOp.isVolatile()
);
@@ -237,7 +256,11 @@ void LoopPipeliner::emitPrologue() {
// update mapping of results
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
setValueMapping(op->getResult(dstIdx), newOp->getResult(dstIdx), stage);
Value originalResult = op->getResult(dstIdx);
// copy_async will update the value of its only use
if (loads.contains(originalResult))
originalResult = loadsMapping[originalResult];
setValueMapping(originalResult, newOp->getResult(dstIdx), stage);
// update mapping for loop-carried values (args)
for (OpOperand &operand : yieldOp->getOpOperands()) {
if (operand.get() == op->getResult(dstIdx))
@@ -254,7 +277,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
// order of new args:
// (original args),
// for each load result x:
// for each load result (after layout conversion) x:
// (x at stage[0, numStages-1))
// (depArgs at stage numStages-1)
// (iv at stage numStages-1)
@@ -268,7 +291,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
size_t loadIdx = newLoopArgs.size();
for (Value loadOp : loads)
for (int i = 0; i < numStages - 1; ++i)
newLoopArgs.push_back(valueMapping[loadOp][i]);
newLoopArgs.push_back(valueMapping[loadsMapping[loadOp]][i]);
size_t depArgsBeginIdx = newLoopArgs.size();
for (BlockArgument depArg : depArgs) {
@@ -295,6 +318,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
// 2.1 clone the loop body, replace original args with args of the new ForOp
for (Operation &op : forOp.getBody()->without_terminator()) {
Operation *newOp = builder.clone(op, mapping);
// update mapping of results
@@ -305,7 +329,9 @@ scf::ForOp LoopPipeliner::createNewForOp() {
// 3. replace loads with block args (from prologue)
for (size_t idx = 0; idx < loads.size(); ++idx) {
Value load = loads[idx];
mapping.lookup(load).replaceAllUsesWith(
assert(load.hasOneUse() && "we assume that this load has one use (ConvertLayout)");
Value loadUse = load.getUsers().begin()->getResult(0);
mapping.lookup(loadUse).replaceAllUsesWith(
newForOp.getRegionIterArgs()[loadIdx + idx*(numStages-1)]);
}
@@ -351,7 +377,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
nextMapping.map(mask, newMask);
// TODO: more elegant way to do this?
nextOp = builder.create<triton::gpu::CopyAsyncOp>(
op->getLoc(), op->getResult(0).getType(),
op->getLoc(), loadsMapping[op->getResult(0)].getType(),
nextMapping.lookupOrDefault(loadOp.ptr()),
nextMapping.lookupOrDefault(loadOp.mask()),
nextMapping.lookupOrDefault(loadOp.other()),

View File

@@ -2,6 +2,7 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <algorithm>
#include <llvm-6.0/llvm/Support/ErrorHandling.h>
using namespace mlir;
@@ -22,11 +23,13 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
int64_t rank = tensorType.getRank();
int64_t numElements = tensorType.getNumElements();
// TODO: we should raise exception here.
if (!(numElements >= numThreads)) {
llvm::errs() << tensorType << " has " << numElements << " numElements "
<< " smaller than numThreads (" << numThreads << ")";
assert(false);
SmallVector<char> buffer;
llvm::raw_svector_ostream os(buffer);
os << tensorType << " has " << numElements << " numElements "
<< " smaller than numThreads (" << numThreads << ")\n"
<< "consider using smaller num-warps\n";
llvm::report_fatal_error(os.str());
}
assert(numElements % numThreads == 0);
@@ -35,6 +38,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// Now we assume:
// contiguous = 1, order = 0, 1, 2, ...,
llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
// TODO: compute warpTileSize.
llvm::SmallVector<unsigned> warpTileSize(rank, 1);
llvm::SmallVector<unsigned> blockTileSize(rank);
llvm::SmallVector<unsigned> order(rank);
@@ -93,17 +97,17 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
});
// // We have requirements for the data layouts
// addDynamicallyLegalOp<triton::DotOp>([this](triton::DotOp dotOp) -> bool {
// Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding();
// Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding();
// if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
// bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
// return true;
// // TODO: we should delete this
// if (this->typeConverter.isLegal(dotOp))
// return true;
// return false;
// });
// We have requirements for the data layouts
addDynamicallyLegalOp<triton::DotOp>([this](triton::DotOp dotOp) -> bool {
Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding();
Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding();
if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
return true;
// TODO: we should delete this
if (this->typeConverter.isLegal(dotOp))
return true;
return false;
});
}