the pipeline pass now generates and accepts valid IR
@@ -163,16 +163,17 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
       return failure();
     Value a = adaptor.a();
     Value b = adaptor.b();
-    // if (!aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
-    //   Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1);
-    //   auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding);
-    //   a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
-    // }
-    // if (!bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
-    //   Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1);
-    //   auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding);
-    //   b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
-    // }
+    SmallVector<unsigned, 2> order{1, 0};
+    if (!aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
+      Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1, order);
+      auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding);
+      a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
+    }
+    if (!bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
+      Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1, order);
+      auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding);
+      b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
+    }
     auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>(
       op, retType, a, b, adaptor.c(), adaptor.allowTF32()
     );
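Note: the two new if-blocks are identical up to the operand they convert. A hypothetical helper (not part of this commit; the name toSharedEncoding is illustrative) could express the insertion once:

    // Sketch: convert `v` to a shared-memory layout with the given order,
    // unless it already carries a shared encoding.
    static Value toSharedEncoding(ConversionPatternRewriter &rewriter, Value v,
                                  ArrayRef<unsigned> order) {
      auto type = v.getType().cast<RankedTensorType>();
      Attribute enc = type.getEncoding();
      if (enc && enc.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
        return v;
      Attribute shared = triton::gpu::TritonGPUSharedEncodingAttr::get(
          type.getContext(), 1, 1, 1, order);
      auto dstType =
          RankedTensorType::get(type.getShape(), type.getElementType(), shared);
      return rewriter.create<triton::gpu::ConvertLayoutOp>(v.getLoc(), dstType, v);
    }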
@@ -323,14 +324,17 @@ void populateSCFPatterns(

 class ConvertTritonToTritonGPU :
     public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {

 public:
+  ConvertTritonToTritonGPU(int numWarps) {
+    this->numWarps = numWarps;
+  }
+
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ModuleOp mod = getOperation();
+    // int numThreads = mod.getAttr();
+    int numThreads = numWarps * 32;
     // type converter
-    TritonGPUTypeConverter typeConverter(context, /*numThreads*/32);
+    TritonGPUTypeConverter typeConverter(context, numThreads);
     TritonGPUConversionTarget target(*context, typeConverter);
     // rewrite patterns
     RewritePatternSet patterns(context);
@@ -350,6 +354,6 @@ public:
 }

 std::unique_ptr<OperationPass<ModuleOp>>
-mlir::triton::createConvertTritonToTritonGPUPass() {
-  return std::make_unique<::ConvertTritonToTritonGPU>();
+mlir::triton::createConvertTritonToTritonGPUPass(int numWarps) {
+  return std::make_unique<::ConvertTritonToTritonGPU>(numWarps);
 }
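Usage sketch (illustrative, not part of the commit): assuming `context` is the MLIRContext and `module` the parsed ModuleOp, the new overload lets a driver pick the warp count when the pipeline is built:

    mlir::PassManager pm(&context);
    pm.addPass(mlir::triton::createConvertTritonToTritonGPUPass(/*numWarps=*/4));
    if (failed(pm.run(module)))
      llvm::report_fatal_error("ConvertTritonToTritonGPU failed");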
@@ -276,6 +276,17 @@ static Type getPointeeType(Type type) {
   }
 }

+static LogicalResult verify(CopyAsyncOp op) {
+  Type resType = op.getResult().getType();
+  if (auto tensorType = resType.dyn_cast<RankedTensorType>()) {
+    Attribute encoding = tensorType.getEncoding();
+    if (!encoding.isa<TritonGPUSharedEncodingAttr>())
+      return op.emitOpError("copy_async should return a shared memory tensor");
+  } else
+    return op.emitOpError("copy_async should return a tensor");
+  return success();
+}
+
 #define GET_OP_CLASSES
 #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
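The new verifier pins down the contract of copy_async: the result must be a ranked tensor whose encoding is a shared-memory layout. The same shape of check could be written once for any encoding class (a sketch with a hypothetical helper name, not part of the commit):

    // Sketch: verify that an op's first result is a ranked tensor
    // carrying the encoding attribute EncodingT.
    template <typename EncodingT>
    static LogicalResult verifyResultEncoding(Operation *op) {
      auto tensorType = op->getResult(0).getType().dyn_cast<RankedTensorType>();
      if (!tensorType)
        return op->emitOpError("should return a tensor");
      Attribute encoding = tensorType.getEncoding();
      if (!encoding || !encoding.isa<EncodingT>())
        return op->emitOpError("result does not carry the expected encoding");
      return success();
    }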
@@ -32,6 +32,8 @@ class LoopPipeliner {

   /// loads to be pipelined
   SetVector<Value> loads;
+  /// the value that each load will be mapped to (after layout conversion)
+  DenseMap<Value, Value> loadsMapping;

   /// value (in loop) => value at stage N
   DenseMap<Value, SmallVector<Value>> valueMapping;
@@ -139,6 +141,23 @@ LogicalResult LoopPipeliner::initialize() {
         break;
       }
     }
+
+    // For now, we only pipeline loads that have one covert_layout (to smem) use
+    // TODO: lift this constraint in the future
+    if (isCandiate && loadOp.getResult().hasOneUse()) {
+      isCandiate = false;
+      Operation *use = *loadOp.getResult().getUsers().begin();
+      if (auto convertLayout = llvm::dyn_cast<triton::gpu::ConvertLayoutOp>(use)) {
+        if (auto tensorType = convertLayout.getResult().getType().dyn_cast<RankedTensorType>()) {
+          if (tensorType.getEncoding().isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
+            isCandiate = true;
+            loadsMapping[loadOp] = convertLayout;
+          }
+        }
+      }
+    } else
+      isCandiate = false;

     if (isCandiate)
       loads.insert(loadOp);
   }
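The constraint enforced here: a load is pipelined only when its single user is a convert_layout into shared memory, and loadsMapping records that conversion. Restated as a standalone predicate (a sketch; the commit keeps this logic inline in initialize()):

    // Sketch: return the unique shared-memory ConvertLayoutOp user of a
    // load, or a null op if the load is not a pipelining candidate.
    static triton::gpu::ConvertLayoutOp getSharedMemUse(triton::LoadOp loadOp) {
      if (!loadOp.getResult().hasOneUse())
        return nullptr;
      Operation *use = *loadOp.getResult().getUsers().begin();
      auto convert = llvm::dyn_cast<triton::gpu::ConvertLayoutOp>(use);
      if (!convert)
        return nullptr;
      auto tensorType =
          convert.getResult().getType().dyn_cast<RankedTensorType>();
      if (tensorType && tensorType.getEncoding() &&
          tensorType.getEncoding()
              .isa<triton::gpu::TritonGPUSharedEncodingAttr>())
        return convert;
      return nullptr;
    }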
@@ -202,7 +221,7 @@ void LoopPipeliner::emitPrologue() {
       // TODO: check if the hardware supports copyasync
       if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
         newOp = builder.create<triton::gpu::CopyAsyncOp>(
-          op->getLoc(), op->getResult(0).getType(),
+          op->getLoc(), loadsMapping[loadOp].getType(),
           loadOp.ptr(), loadOp.mask(), loadOp.other(),
           loadOp.cache(), loadOp.evict(), loadOp.isVolatile()
         );
@@ -237,7 +256,11 @@ void LoopPipeliner::emitPrologue() {

       // update mapping of results
       for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
-        setValueMapping(op->getResult(dstIdx), newOp->getResult(dstIdx), stage);
+        Value originalResult = op->getResult(dstIdx);
+        // copy_async will update the value of its only use
+        if (loads.contains(originalResult))
+          originalResult = loadsMapping[originalResult];
+        setValueMapping(originalResult, newOp->getResult(dstIdx), stage);
         // update mapping for loop-carried values (args)
         for (OpOperand &operand : yieldOp->getOpOperands()) {
           if (operand.get() == op->getResult(dstIdx))
@@ -254,7 +277,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {

   // order of new args:
   //   (original args),
-  //   for each load result x:
+  //   for each load result (after layout conversion) x:
   //     (x at stage[0, numStages-1))
   //   (depArgs at stage numStages-1)
   //   (iv at stage numStages-1)
@@ -268,7 +291,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   size_t loadIdx = newLoopArgs.size();
   for (Value loadOp : loads)
     for (int i = 0; i < numStages - 1; ++i)
-      newLoopArgs.push_back(valueMapping[loadOp][i]);
+      newLoopArgs.push_back(valueMapping[loadsMapping[loadOp]][i]);

   size_t depArgsBeginIdx = newLoopArgs.size();
   for (BlockArgument depArg : depArgs) {
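To make the new argument layout concrete, a worked example under assumed values (illustrative; the numbers are not from the commit):

    // Suppose numStages = 3, loads = {A, B}, and 4 original loop args,
    // so loadIdx == 4. The stage copies are laid out per load:
    //   A at stage 0 -> newForOp.getRegionIterArgs()[4 + 0*(3-1) + 0]  == args[4]
    //   A at stage 1 -> args[5]
    //   B at stage 0 -> args[4 + 1*(3-1) + 0]                          == args[6]
    //   B at stage 1 -> args[7]
    // This is the indexing loadIdx + idx*(numStages-1) that step 3 below
    // uses to wire the converted loads back to block args.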
@@ -295,6 +318,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
     mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);

+  // 2.1 clone the loop body, replace original args with args of the new ForOp
   for (Operation &op : forOp.getBody()->without_terminator()) {
     Operation *newOp = builder.clone(op, mapping);
     // update mapping of results
@@ -305,7 +329,9 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   // 3. replace loads with block args (from prologue)
   for (size_t idx = 0; idx < loads.size(); ++idx) {
     Value load = loads[idx];
-    mapping.lookup(load).replaceAllUsesWith(
+    assert(load.hasOneUse() && "we assume that this load has one use (ConvertLayout)");
+    Value loadUse = load.getUsers().begin()->getResult(0);
+    mapping.lookup(loadUse).replaceAllUsesWith(
       newForOp.getRegionIterArgs()[loadIdx + idx*(numStages-1)]);
   }
@@ -351,7 +377,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
         nextMapping.map(mask, newMask);
         // TODO: more elegant way to do this?
         nextOp = builder.create<triton::gpu::CopyAsyncOp>(
-          op->getLoc(), op->getResult(0).getType(),
+          op->getLoc(), loadsMapping[op->getResult(0)].getType(),
          nextMapping.lookupOrDefault(loadOp.ptr()),
          nextMapping.lookupOrDefault(loadOp.mask()),
          nextMapping.lookupOrDefault(loadOp.other()),
@@ -2,6 +2,7 @@
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include <algorithm>
+#include <llvm-6.0/llvm/Support/ErrorHandling.h>

 using namespace mlir;
@@ -22,11 +23,13 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
     int64_t rank = tensorType.getRank();
     int64_t numElements = tensorType.getNumElements();

-    // TODO: we should raise exception here.
     if (!(numElements >= numThreads)) {
-      llvm::errs() << tensorType << " has " << numElements << " numElements "
-                   << " smaller than numThreads (" << numThreads << ")";
-      assert(false);
+      SmallVector<char> buffer;
+      llvm::raw_svector_ostream os(buffer);
+      os << tensorType << " has " << numElements << " numElements "
+         << " smaller than numThreads (" << numThreads << ")\n"
+         << "consider using smaller num-warps\n";
+      llvm::report_fatal_error(os.str());
     }
     assert(numElements % numThreads == 0);
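Rather than printing to llvm::errs() and tripping an assert (a no-op in release builds), the converter now formats the diagnostic into a buffer and aborts through llvm::report_fatal_error. The idiom in isolation (a minimal sketch, not part of the commit):

    #include "llvm/ADT/SmallVector.h"
    #include "llvm/Support/ErrorHandling.h"
    #include "llvm/Support/raw_ostream.h"

    // Sketch: format a message into an in-memory buffer, then abort.
    static void fatalTooFewElements(int numElements, int numThreads) {
      llvm::SmallVector<char> buffer;
      llvm::raw_svector_ostream os(buffer);
      os << numElements << " numElements is smaller than numThreads ("
         << numThreads << "); consider using smaller num-warps";
      llvm::report_fatal_error(os.str());
    }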
@@ -35,6 +38,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
     // Now we assume:
     //   contiguous = 1, order = 0, 1, 2, ...,
     llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
+    // TODO: compute warpTileSize.
     llvm::SmallVector<unsigned> warpTileSize(rank, 1);
     llvm::SmallVector<unsigned> blockTileSize(rank);
     llvm::SmallVector<unsigned> order(rank);
@@ -93,17 +97,17 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
   });


-  // // We have requirements for the data layouts
-  // addDynamicallyLegalOp<triton::DotOp>([this](triton::DotOp dotOp) -> bool {
-  //   Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding();
-  //   Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding();
-  //   if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
-  //       bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
-  //     return true;
-  //   // TODO: we should delete this
-  //   if (this->typeConverter.isLegal(dotOp))
-  //     return true;
-  //   return false;
-  // });
+  // We have requirements for the data layouts
+  addDynamicallyLegalOp<triton::DotOp>([this](triton::DotOp dotOp) -> bool {
+    Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding();
+    Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding();
+    if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
+        bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
+      return true;
+    // TODO: we should delete this
+    if (this->typeConverter.isLegal(dotOp))
+      return true;
+    return false;
+  });

 }
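These legality rules work together with TritonDotPattern above: a dot op stays illegal until both operands carry shared encodings, so the pattern keeps rewriting until the required convert_layout ops are in place. The driver that ties target and patterns together is the usual MLIR partial conversion (sketch; the actual call sits in runOnOperation):

    if (failed(applyPartialConversion(mod, target, std::move(patterns))))
      return signalPassFailure();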