Merge triton-mlir branch - Complete rewrite of the backend from scratch (#1004)

This PR merges the `triton-mlir` branch, in which we have been quietly
rewriting the Triton backend from scratch to increase maintainability,
stability and ultimately performance. Changes to the runtime are
minimal, and this new version aims to remain backward-compatible with
the previous commit. The legacy backend is now officially deprecated,
but can still be accessed via the `legacy-backend` tag.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com>
Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com>
Co-authored-by: Yan Da <dyanab@connect.ust.hk>
Co-authored-by: Jun Yang <yangjunpro@gmail.com>
Co-authored-by: Ian Bearman <ianb@microsoft.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Co-authored-by: Qingyi Liu <qingyil@nvidia.com>
Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com>
Co-authored-by: Chenggang Zhao <lyricz@yeah.net>
Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
Philippe Tillet
2022-12-21 01:30:50 -08:00
committed by GitHub
parent 8650b4d1cb
commit 20100a7254
285 changed files with 26312 additions and 50143 deletions

View File

@@ -0,0 +1,21 @@
set(LLVM_TARGET_DEFINITIONS Combine.td)
mlir_tablegen(TritonGPUCombine.inc -gen-rewriters)
add_public_tablegen_target(TritonGPUCombineIncGen)
add_mlir_dialect_library(TritonGPUTransforms
Coalesce.cpp
CanonicalizeLoops.cpp
Combine.cpp
Pipeline.cpp
Prefetch.cpp
TritonGPUConversion.cpp
DEPENDS
TritonGPUTransformsIncGen
TritonGPUCombineIncGen
LINK_LIBS PUBLIC
TritonIR
TritonGPUIR
MLIRTransformUtils
)

View File

@@ -0,0 +1,55 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::triton;
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
namespace {
struct CanonicalizePass
: public TritonGPUCanonicalizeLoopsBase<CanonicalizePass> {
CanonicalizePass() = default;
void runOnOperation() override {
// Canonicalize pass may have created dead code that
// standard scf.for canonicalization cannot handle
// as of LLVM 14. For example, the iteration arguments
// for the pointer of the synchronous loads that are
// discarded.
// The following piece of code is a workaround to
// very crudely remove dead code, by making an iteration
// argument yield itself if it is not used to create
// side effects anywhere.
getOperation()->walk([&](scf::ForOp forOp) -> void {
for (size_t i = 0; i < forOp.getNumResults(); ++i) {
// condition 1: no other iter arguments depend on it
SetVector<Operation *> fwdSlice;
mlir::getForwardSlice(forOp.getRegionIterArgs()[i], &fwdSlice);
Operation *yieldOp = forOp.getBody()->getTerminator();
bool noOtherDependency = std::all_of(
yieldOp->operand_begin(), yieldOp->operand_end(), [&](Value arg) {
return arg == yieldOp->getOperand(i) ||
!fwdSlice.contains(arg.getDefiningOp());
});
// condition 2: final value is not used after the loop
auto retVal = forOp.getResult(i);
bool noUserAfterLoop = retVal.getUsers().empty();
// yielding the region iter arg will cause loop canonicalization
// to clean up the dead code
if (noOtherDependency && noUserAfterLoop) {
yieldOp->setOperand(i, forOp.getRegionIterArgs()[i]);
}
}
});
}
};
} // anonymous namespace
std::unique_ptr<Pass> mlir::createTritonGPUCanonicalizeLoopsPass() {
return std::make_unique<CanonicalizePass>();
}

View File

@@ -0,0 +1,139 @@
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include <numeric>
using namespace mlir;
using namespace mlir::triton;
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
Attribute getCoalescedEncoding(AxisInfoAnalysis &axisInfo, Value ptr,
int numWarps) {
auto origType = ptr.getType().cast<RankedTensorType>();
// Get the shape of the tensor.
size_t rank = origType.getRank();
AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue();
// Layout order in decreasing order of contiguity
SmallVector<unsigned, 4> order(rank);
std::iota(order.begin(), order.end(), 0);
auto contiguity = info.getContiguity();
std::sort(order.begin(), order.end(), [&](unsigned x, unsigned y) {
return contiguity[x] > contiguity[y];
});
int numElems = product(origType.getShape());
int numThreads = numWarps * 32;
int numElemsPerThread = std::max(numElems / numThreads, 1);
// Thread tile size depends on memory alignment
SmallVector<unsigned, 4> sizePerThread(rank, 1);
PointerType ptrType = origType.getElementType().cast<PointerType>();
auto pointeeType = ptrType.getPointeeType();
unsigned numBits = pointeeType.isa<triton::Float8Type>()
? 8
: pointeeType.getIntOrFloatBitWidth();
unsigned maxMultiple = info.getDivisibility(order[0]);
unsigned maxContig = info.getContiguity(order[0]);
unsigned alignment = std::min(maxMultiple, maxContig);
unsigned perThread = std::min(alignment, 128 / numBits);
sizePerThread[order[0]] = std::min<int>(perThread, numElemsPerThread);
SmallVector<unsigned> dims(rank);
std::iota(dims.begin(), dims.end(), 0);
// create encoding
Attribute encoding = triton::gpu::BlockedEncodingAttr::get(
&getContext(), origType.getShape(), sizePerThread, order, numWarps);
return encoding;
}
std::function<Type(Type)> getTypeConverter(AxisInfoAnalysis &axisInfo,
Value ptr, int numWarps) {
Attribute encoding = getCoalescedEncoding(axisInfo, ptr, numWarps);
return [encoding](Type _type) {
RankedTensorType type = _type.cast<RankedTensorType>();
return RankedTensorType::get(type.getShape(), type.getElementType(),
encoding);
};
}
template <class T>
void coalesceOp(AxisInfoAnalysis &axisInfo, Operation *op, Value ptr,
OpBuilder builder) {
RankedTensorType ty = ptr.getType().template dyn_cast<RankedTensorType>();
if (!ty)
return;
auto mod = op->getParentOfType<ModuleOp>();
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue();
auto convertType = getTypeConverter(axisInfo, ptr, numWarps);
// convert operands
SmallVector<Value, 4> newArgs;
for (auto v : op->getOperands()) {
auto vTy = v.getType().dyn_cast<RankedTensorType>();
if (vTy && !vTy.getEncoding().isa<triton::gpu::SharedEncodingAttr>())
newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), convertType(v.getType()), v));
else
newArgs.push_back(v);
}
// convert output types
SmallVector<Type, 4> newTypes;
for (auto t : op->getResultTypes()) {
bool is_async = std::is_same<T, triton::gpu::InsertSliceAsyncOp>::value;
newTypes.push_back(is_async ? t : convertType(t));
}
// construct new op with the new encoding
Operation *newOp =
builder.create<T>(op->getLoc(), newTypes, newArgs, op->getAttrs());
// cast the results back to the original layout
for (size_t i = 0; i < op->getNumResults(); i++) {
Value newResult = newOp->getResult(i);
if (newTypes[i] != op->getResultTypes()[i]) {
newResult = builder.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), op->getResult(i).getType(), newResult);
}
op->getResult(i).replaceAllUsesWith(newResult);
}
op->erase();
}
void runOnOperation() override {
Operation *op = getOperation();
// Run axis info analysis
AxisInfoAnalysis axisInfo(&getContext());
axisInfo.run(op);
OpBuilder builder(op);
// For each memory op that has a layout L1:
// 1. Create a coalesced memory layout L2 of the pointer operands
// 2. Convert all operands from layout L1 to layout L2
// 3. Create a new memory op that consumes these operands and
// produces a tensor with layout L2
// 4. Convert the output of this new memory op back to L1
// 5. Replace all the uses of the original memory op by the new one
op->walk([&](Operation *curr) {
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPoint(curr);
if (auto load = dyn_cast<triton::LoadOp>(curr))
coalesceOp<triton::LoadOp>(axisInfo, curr, load.ptr(), builder);
if (auto op = dyn_cast<triton::AtomicRMWOp>(curr))
coalesceOp<triton::AtomicRMWOp>(axisInfo, curr, op.ptr(), builder);
if (auto op = dyn_cast<triton::AtomicCASOp>(curr))
coalesceOp<triton::AtomicCASOp>(axisInfo, curr, op.ptr(), builder);
if (auto load = dyn_cast<triton::gpu::InsertSliceAsyncOp>(curr))
coalesceOp<triton::gpu::InsertSliceAsyncOp>(axisInfo, curr, load.src(),
builder);
if (auto store = dyn_cast<triton::StoreOp>(curr))
coalesceOp<triton::StoreOp>(axisInfo, curr, store.ptr(), builder);
});
}
};
std::unique_ptr<Pass> mlir::createTritonGPUCoalescePass() {
return std::make_unique<CoalescePass>();
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,7 @@
#ifndef TRITONGPU_PATTERNS
#define TRITONGPU_PATTERNS
include "triton/Dialect/TritonGPU/IR/TritonGPUOps.td"
include "triton/Dialect/Triton/IR/TritonOps.td"
#endif

View File

@@ -0,0 +1,656 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
//===----------------------------------------------------------------------===//
//
// This file implements loop software pipelining
// The implementation here is inspired by the pipeline pass in Triton (-v2.0)
// and SCF's LoopPipelining.
//
//===----------------------------------------------------------------------===//
using namespace mlir;
namespace ttg = triton::gpu;
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
static Type getI1SameShape(Value v) {
Type vType = v.getType();
auto i1Type = IntegerType::get(vType.getContext(), 1);
auto tensorType = vType.cast<RankedTensorType>();
return RankedTensorType::get(tensorType.getShape(), i1Type,
tensorType.getEncoding());
}
#define int_attr(num) builder.getI64IntegerAttr(num)
namespace {
class LoopPipeliner {
/// Cache forOp we are working on
scf::ForOp forOp;
/// Cache YieldOp for this forOp
scf::YieldOp yieldOp;
/// Loads to be pipelined
SetVector<Value> loads;
/// The value that each load will be mapped to (after layout conversion)
DenseMap<Value, Value> loadsMapping;
/// load => buffer
DenseMap<Value, Value> loadsBuffer;
/// load => buffer type (with shared layout after swizzling)
DenseMap<Value, RankedTensorType> loadsBufferType;
/// load => buffer at stage N
DenseMap<Value, SmallVector<Value>> loadStageBuffer;
/// load => after extract
DenseMap<Value, Value> loadsExtract;
///
Value pipelineIterIdx;
///
Value loopIterIdx;
/// Comments on numStages:
/// [0, numStages-1) are in the prologue
/// numStages-1 is appended after the loop body
int numStages;
/// value (in loop) => value at stage N
DenseMap<Value, SmallVector<Value>> valueMapping;
/// Block arguments that loads depend on
DenseSet<BlockArgument> depArgs;
/// Operations (inside the loop body) that loads depend on
DenseSet<Operation *> depOps;
/// collect values that v depends on and are defined inside the loop
void collectDeps(Value v, int stages, DenseSet<Value> &deps);
void setValueMapping(Value origin, Value newValue, int stage);
Value lookupOrDefault(Value origin, int stage);
/// Returns a empty buffer of size <numStages, ...>
ttg::AllocTensorOp allocateEmptyBuffer(Operation *op, OpBuilder &builder);
public:
LoopPipeliner(scf::ForOp forOp, int numStages)
: forOp(forOp), numStages(numStages) {
// cache yieldOp
yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
}
/// Collect loads to pipeline. Return success if we can pipeline this loop
LogicalResult initialize();
/// Emit pipelined loads (before loop body)
void emitPrologue();
/// emit pipelined loads (after loop body)
void emitEpilogue();
/// create the new ForOp (add new args & insert prefetched ops)
scf::ForOp createNewForOp();
friend struct PipelinePass;
};
// helpers
void LoopPipeliner::setValueMapping(Value origin, Value newValue, int stage) {
if (valueMapping.find(origin) == valueMapping.end())
valueMapping[origin] = SmallVector<Value>(numStages);
valueMapping[origin][stage] = newValue;
}
Value LoopPipeliner::lookupOrDefault(Value origin, int stage) {
if (valueMapping.find(origin) == valueMapping.end())
return origin;
return valueMapping[origin][stage];
}
void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
// Loop-invariant value, skip
if (v.getParentRegion() != &forOp.getLoopBody())
return;
// Since we only need to peel the loop numStages-1 times, don't worry about
// depends that are too far away
if (stages < 0)
return;
if (auto arg = v.dyn_cast<BlockArgument>()) {
if (arg.getArgNumber() > 0) {
// Skip the first arg (loop induction variable)
// Otherwise the op idx is arg.getArgNumber()-1
deps.insert(v);
collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1,
deps);
}
} else { // value
// v might be in deps, but we still need to visit v.
// This is because v might depend on value in previous iterations
deps.insert(v);
for (Value op : v.getDefiningOp()->getOperands())
collectDeps(op, stages, deps);
}
}
ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op,
OpBuilder &builder) {
// Allocate a buffer for each pipelined tensor
// shape: e.g. (numStages==4), <32x64xbf16> -> <4x32x64xbf16>
Value convertLayout = loadsMapping[op->getResult(0)];
if (auto tensorType = convertLayout.getType().dyn_cast<RankedTensorType>()) {
return builder.create<ttg::AllocTensorOp>(
convertLayout.getLoc(), loadsBufferType[op->getResult(0)]);
}
llvm_unreachable("Async copy's return should be of RankedTensorType");
}
/// A load instruction can be pipelined if:
/// - the load doesn't depend on any other loads (after loop peeling)
/// - (?) this load is not a loop-invariant value (we should run LICM before
/// this pass?)
LogicalResult LoopPipeliner::initialize() {
Block *loop = forOp.getBody();
// can we use forOp.walk(...) here?
SmallVector<triton::LoadOp, 2> allLoads;
for (Operation &op : *loop)
if (auto loadOp = dyn_cast<triton::LoadOp>(&op))
allLoads.push_back(loadOp);
// Early stop: no need to continue if there is no load in the loop.
if (allLoads.empty())
return failure();
// load => values that it depends on
DenseMap<Value, DenseSet<Value>> loadDeps;
for (triton::LoadOp loadOp : allLoads) {
DenseSet<Value> deps;
for (Value op : loadOp->getOperands())
collectDeps(op, numStages - 1, deps);
loadDeps[loadOp] = deps;
}
// Don't pipeline loads that depend on other loads
// (Because if a load depends on another load, this load needs to wait on the
// other load in the prologue, which is against the point of the pipeline
// pass)
for (triton::LoadOp loadOp : allLoads) {
bool isCandidate = true;
for (triton::LoadOp other : allLoads) {
if (loadDeps[loadOp].contains(other)) {
isCandidate = false;
break;
}
}
// We only pipeline loads that have one covert_layout (to dot_op) use
// TODO: lift this constraint in the future
if (isCandidate && loadOp.getResult().hasOneUse()) {
isCandidate = false;
Operation *use = *loadOp.getResult().getUsers().begin();
if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use)) {
if (auto tensorType = convertLayout.getResult()
.getType()
.dyn_cast<RankedTensorType>()) {
if (auto dotOpEnc = tensorType.getEncoding()
.dyn_cast<ttg::DotOperandEncodingAttr>()) {
isCandidate = true;
loadsMapping[loadOp] = convertLayout;
auto ty = loadOp.getType().cast<RankedTensorType>();
SmallVector<int64_t> bufferShape(ty.getShape().begin(),
ty.getShape().end());
bufferShape.insert(bufferShape.begin(), numStages);
auto sharedEnc = ttg::SharedEncodingAttr::get(
ty.getContext(), dotOpEnc, ty.getShape(),
triton::gpu::getOrder(ty.getEncoding()), ty.getElementType());
loadsBufferType[loadOp] = RankedTensorType::get(
bufferShape, ty.getElementType(), sharedEnc);
}
}
}
} else
isCandidate = false;
if (isCandidate)
loads.insert(loadOp);
}
// We have some loads to pipeline
if (!loads.empty()) {
// Update depArgs & depOps
for (Value loadOp : loads) {
for (Value dep : loadDeps[loadOp]) {
// TODO: we should record the stage that the value is depended on
if (auto arg = dep.dyn_cast<BlockArgument>())
depArgs.insert(arg);
else
depOps.insert(dep.getDefiningOp());
}
}
return success();
}
return failure();
}
void LoopPipeliner::emitPrologue() {
// llvm::errs() << "loads to pipeline...:\n";
// for (Value load : loads)
// llvm::errs() << load << "\n";
OpBuilder builder(forOp);
for (BlockArgument &arg : forOp.getRegionIterArgs()) {
OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg);
setValueMapping(arg, operand.get(), 0);
}
// prologue from [0, numStage-1)
Value iv = forOp.getLowerBound();
pipelineIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
for (int stage = 0; stage < numStages - 1; ++stage) {
// Special handling for induction variable as the increment is implicit
if (stage != 0)
iv = builder.create<arith::AddIOp>(iv.getLoc(), iv, forOp.getStep());
setValueMapping(forOp.getInductionVar(), iv, stage);
// Special handling for loop condition as there is no condition in ForOp
Value loopCond = builder.create<arith::CmpIOp>(
iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound());
// Rematerialize peeled values
SmallVector<Operation *> orderedDeps;
for (Operation &op : forOp.getLoopBody().front()) {
if (depOps.contains(&op))
orderedDeps.push_back(&op);
else if (loads.contains(op.getResult(0)))
orderedDeps.push_back(&op);
}
assert(depOps.size() + loads.size() == orderedDeps.size() &&
"depOps contains invalid values");
for (Operation *op : orderedDeps) {
Operation *newOp = nullptr;
if (loads.contains(op->getResult(0))) {
// Allocate empty buffer
if (stage == 0) {
loadsBuffer[op->getResult(0)] = allocateEmptyBuffer(op, builder);
loadStageBuffer[op->getResult(0)] = {loadsBuffer[op->getResult(0)]};
}
// load => copy async
if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
Value mask = lookupOrDefault(loadOp.mask(), stage);
Value newMask;
if (mask) {
Value splatCond = builder.create<triton::SplatOp>(
mask.getLoc(), mask.getType(), loopCond);
newMask =
builder.create<arith::AndIOp>(mask.getLoc(), mask, splatCond);
} else {
newMask = builder.create<triton::SplatOp>(
loopCond.getLoc(), getI1SameShape(loadOp), loopCond);
}
// TODO: check if the hardware supports async copy
newOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
op->getLoc(), loadsBuffer[loadOp].getType(),
lookupOrDefault(loadOp.ptr(), stage),
loadStageBuffer[loadOp][stage], pipelineIterIdx, newMask,
lookupOrDefault(loadOp.other(), stage), loadOp.cache(),
loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
loadStageBuffer[loadOp].push_back(newOp->getResult(0));
} else
llvm_unreachable("This should be LoadOp");
} else {
newOp = builder.clone(*op);
// Update loop-carried uses
for (unsigned opIdx = 0; opIdx < op->getNumOperands(); ++opIdx) {
auto it = valueMapping.find(op->getOperand(opIdx));
if (it != valueMapping.end()) {
Value v = it->second[stage];
assert(v);
newOp->setOperand(opIdx, v);
} // else, op at opIdx is a loop-invariant value
}
}
// Update mapping of results
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
Value originalResult = op->getResult(dstIdx);
// copy_async will update the value of its only use
// TODO: load should not be used in the preheader?
if (loads.contains(originalResult)) {
break;
// originalResult = loadsMapping[originalResult];
}
setValueMapping(originalResult, newOp->getResult(dstIdx), stage);
// update mapping for loop-carried values (args)
for (OpOperand &operand : yieldOp->getOpOperands()) {
if (operand.get() == op->getResult(dstIdx))
setValueMapping(
forOp.getRegionIterArgs()[operand.getOperandNumber()],
newOp->getResult(dstIdx), stage + 1);
}
}
} // for (Operation *op : orderedDeps)
pipelineIterIdx = builder.create<arith::AddIOp>(
iv.getLoc(), pipelineIterIdx,
builder.create<arith::ConstantIntOp>(iv.getLoc(), 1, 32));
} // for (int stage = 0; stage < numStages - 1; ++stage)
// async.wait & extract_slice
builder.create<ttg::AsyncWaitOp>(loads[0].getLoc(),
loads.size() * (numStages - 2));
loopIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
for (Value loadOp : loads) {
auto sliceType = loadsMapping[loadOp].getType().cast<RankedTensorType>();
sliceType =
RankedTensorType::get(sliceType.getShape(), sliceType.getElementType(),
loadsBufferType[loadOp].getEncoding());
Value extractSlice = builder.create<tensor::ExtractSliceOp>(
loadOp.getLoc(), sliceType, loadStageBuffer[loadOp][numStages - 1],
SmallVector<OpFoldResult>{int_attr(0), int_attr(0), int_attr(0)},
SmallVector<OpFoldResult>{int_attr(1),
int_attr(sliceType.getShape()[0]),
int_attr(sliceType.getShape()[1])},
SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
loadsExtract[loadOp] = extractSlice;
}
// Bump up loopIterIdx, this is used for getting the correct slice for the
// *next* iteration
loopIterIdx = builder.create<arith::AddIOp>(
loopIterIdx.getLoc(), loopIterIdx,
builder.create<arith::ConstantIntOp>(loopIterIdx.getLoc(), 1, 32));
}
void LoopPipeliner::emitEpilogue() {
// If there's any outstanding async copies, we need to wait for them.
OpBuilder builder(forOp);
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPointAfter(forOp);
builder.create<triton::gpu::AsyncWaitOp>(forOp.getLoc(), 0);
}
scf::ForOp LoopPipeliner::createNewForOp() {
OpBuilder builder(forOp);
// Order of new args:
// (original args)
// (insertSliceAsync buffer at stage numStages - 1) for each load
// (extracted tensor) for each load
// (depArgs at stage numStages - 1)
// (iv at stage numStages - 2)
// (pipeline iteration index)
// (loop iteration index)
SmallVector<Value> newLoopArgs;
// We need this to update operands for yield
// original block arg => new arg's idx
DenseMap<BlockArgument, size_t> depArgsIdx;
for (auto v : forOp.getIterOperands())
newLoopArgs.push_back(v);
size_t bufferIdx = newLoopArgs.size();
for (Value loadOp : loads)
newLoopArgs.push_back(loadStageBuffer[loadOp].back());
size_t loadIdx = newLoopArgs.size();
for (Value loadOp : loads)
newLoopArgs.push_back(loadsExtract[loadOp]);
size_t depArgsBeginIdx = newLoopArgs.size();
for (BlockArgument depArg : depArgs) {
depArgsIdx[depArg] = newLoopArgs.size();
newLoopArgs.push_back(valueMapping[depArg][numStages - 1]);
}
size_t nextIVIdx = newLoopArgs.size();
newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages - 2]);
newLoopArgs.push_back(pipelineIterIdx);
newLoopArgs.push_back(loopIterIdx);
for (size_t i = 0; i < newLoopArgs.size(); ++i)
assert(newLoopArgs[i]);
// 1. signature of the new ForOp
auto newForOp = builder.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newLoopArgs);
// 2. body of the new ForOp
builder.setInsertionPointToStart(newForOp.getBody());
BlockAndValueMapping mapping;
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
// 2.1 clone the loop body, replace original args with args of the new ForOp
// Insert async wait if necessary.
for (Operation &op : forOp.getBody()->without_terminator()) {
Operation *newOp = builder.clone(op, mapping);
// update mapping of results
for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults()))
mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx));
}
// 3. replace loads with block args (from prologue)
for (size_t idx = 0; idx < loads.size(); ++idx) {
Value load = loads[idx];
assert(load.hasOneUse() &&
"we assume that this load has one use (ConvertLayout)");
Value loadUse = load.getUsers().begin()->getResult(0);
mapping.lookup(loadUse).replaceAllUsesWith(
newForOp.getRegionIterArgs()[loadIdx + idx]);
// delete old load and layout conversion
mapping.lookup(loadUse).getDefiningOp()->erase();
mapping.lookup(load).getDefiningOp()->erase();
}
// 4. prefetch the next iteration
SmallVector<Operation *> orderedDeps;
for (Operation &op : forOp.getLoopBody().front()) {
if (depOps.contains(&op))
orderedDeps.push_back(&op);
else if (loads.contains(op.getResult(0)))
orderedDeps.push_back(&op);
}
assert(depOps.size() + loads.size() == orderedDeps.size() &&
"depOps contains invalid values");
BlockAndValueMapping nextMapping;
DenseMap<BlockArgument, Value> depArgsMapping;
size_t argIdx = 0;
for (BlockArgument arg : depArgs) {
nextMapping.map(arg,
newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]);
++argIdx;
}
// Special handling for iv & loop condition
Value nextIV = builder.create<arith::AddIOp>(
newForOp.getInductionVar().getLoc(),
newForOp.getRegionIterArgs()[nextIVIdx], newForOp.getStep());
Value nextLoopCond =
builder.create<arith::CmpIOp>(nextIV.getLoc(), arith::CmpIPredicate::slt,
nextIV, newForOp.getUpperBound());
nextMapping.map(forOp.getInductionVar(), nextIV);
// Slice index
SmallVector<Value> nextBuffers;
SmallVector<Value> extractSlices;
pipelineIterIdx = newForOp.getRegionIterArgs()[nextIVIdx + 1];
Value insertSliceIndex = builder.create<arith::RemSIOp>(
nextIV.getLoc(), pipelineIterIdx,
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
loopIterIdx = newForOp.getRegionIterArgs()[nextIVIdx + 2];
Value extractSliceIndex = builder.create<arith::RemSIOp>(
nextIV.getLoc(), loopIterIdx,
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
extractSliceIndex = builder.create<arith::IndexCastOp>(
extractSliceIndex.getLoc(), builder.getIndexType(), extractSliceIndex);
for (Operation *op : orderedDeps) {
Operation *nextOp = nullptr;
// Update loading mask
if (loads.contains(op->getResult(0))) {
auto loadOp = llvm::cast<triton::LoadOp>(op);
Value mask = loadOp.mask();
Value newMask;
if (mask) {
Value splatCond = builder.create<triton::SplatOp>(
mask.getLoc(), mask.getType(), nextLoopCond);
newMask = builder.create<arith::AndIOp>(
mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
// If mask is defined outside the loop, don't update the map more than
// once
if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
nextMapping.map(mask, newMask);
newMask = nextMapping.lookupOrDefault(loadOp.mask());
} else
newMask = builder.create<triton::SplatOp>(
loadOp.getLoc(), getI1SameShape(loadOp), nextLoopCond);
Value insertAsyncOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
op->getLoc(), loadsBuffer[loadOp].getType(),
nextMapping.lookupOrDefault(loadOp.ptr()),
newForOp.getRegionIterArgs()[bufferIdx + nextBuffers.size()],
insertSliceIndex, newMask,
nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
nextBuffers.push_back(insertAsyncOp);
auto sliceType = loadsMapping[loadOp].getType().cast<RankedTensorType>();
sliceType = RankedTensorType::get(sliceType.getShape(),
sliceType.getElementType(),
loadsBufferType[loadOp].getEncoding());
nextOp = builder.create<tensor::ExtractSliceOp>(
op->getLoc(), sliceType, insertAsyncOp,
SmallVector<OpFoldResult>{extractSliceIndex, int_attr(0),
int_attr(0)},
SmallVector<OpFoldResult>{int_attr(1),
int_attr(sliceType.getShape()[0]),
int_attr(sliceType.getShape()[1])},
SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
extractSlices.push_back(nextOp->getResult(0));
} else
nextOp = builder.clone(*op, nextMapping);
// Update mapping of results
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx));
// If this is a loop-carried value, update the mapping for yield
auto originYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
for (OpOperand &operand : originYield->getOpOperands()) {
if (operand.get() == op->getResult(dstIdx)) {
size_t originIdx = operand.getOperandNumber();
size_t newArgIdx = depArgsIdx[forOp.getRegionIterArgs()[originIdx]];
BlockArgument newArg = newForOp.getRegionIterArgs()[newArgIdx];
depArgsMapping[newArg] = nextOp->getResult(dstIdx);
}
}
}
}
{
OpBuilder::InsertionGuard guard(builder);
for (Operation &op : *newForOp.getBody()) {
if (auto dotOp = llvm::dyn_cast<triton::DotOp>(&op)) {
builder.setInsertionPoint(&op);
auto dotType = dotOp.getType().cast<RankedTensorType>();
Value a = dotOp.a();
Value b = dotOp.b();
auto layoutCast = [&](Value dotOperand, int opIdx) -> Value {
auto tensorType = dotOperand.getType().cast<RankedTensorType>();
if (!tensorType.getEncoding().isa<ttg::DotOperandEncodingAttr>()) {
auto newEncoding = ttg::DotOperandEncodingAttr::get(
tensorType.getContext(), opIdx, dotType.getEncoding());
auto newType =
RankedTensorType::get(tensorType.getShape(),
tensorType.getElementType(), newEncoding);
return builder.create<ttg::ConvertLayoutOp>(dotOperand.getLoc(),
newType, dotOperand);
}
return dotOperand;
};
a = layoutCast(a, 0);
b = layoutCast(b, 1);
dotOp->setOperand(0, a);
dotOp->setOperand(1, b);
}
}
}
// async.wait & extract_slice
Operation *asyncWait = builder.create<ttg::AsyncWaitOp>(
loads[0].getLoc(), loads.size() * (numStages - 2));
for (auto it = extractSlices.rbegin(); it != extractSlices.rend(); ++it) {
// move extract_slice after asyncWait
it->getDefiningOp()->moveAfter(asyncWait);
}
// Bump iteration count
pipelineIterIdx = builder.create<arith::AddIOp>(
nextIV.getLoc(), pipelineIterIdx,
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 32));
loopIterIdx = builder.create<arith::AddIOp>(
nextIV.getLoc(), loopIterIdx,
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 32));
// Finally, the YieldOp, need to sync with the order of newLoopArgs
SmallVector<Value> yieldValues;
for (Value v : forOp.getBody()->getTerminator()->getOperands())
yieldValues.push_back(mapping.lookup(v));
for (Value nextBuffer : nextBuffers)
yieldValues.push_back(nextBuffer);
for (Value nextSlice : extractSlices)
yieldValues.push_back(nextSlice);
for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i) {
auto arg = newForOp.getRegionIterArgs()[i];
assert(depArgsMapping.count(arg) && "Missing loop-carried value");
yieldValues.push_back(depArgsMapping[arg]);
}
yieldValues.push_back(nextIV);
yieldValues.push_back(pipelineIterIdx);
yieldValues.push_back(loopIterIdx);
builder.setInsertionPointToEnd(newForOp.getBody());
builder.create<scf::YieldOp>(forOp.getBody()->getTerminator()->getLoc(),
yieldValues);
return newForOp;
}
// ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
PipelinePass() = default;
PipelinePass(int numStages) { this->numStages = numStages; }
void runOnOperation() override {
int numStages = this->numStages;
if (numStages <= 1)
return;
getOperation()->walk([&](scf::ForOp forOp) -> void {
LoopPipeliner pipeliner(forOp, numStages);
if (pipeliner.initialize().failed())
return;
pipeliner.emitPrologue();
scf::ForOp newForOp = pipeliner.createNewForOp();
pipeliner.emitEpilogue();
// replace the original loop
for (unsigned i = 0; i < forOp->getNumResults(); ++i)
forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i));
forOp->erase();
});
}
};
} // anonymous namespace
std::unique_ptr<Pass> mlir::createTritonGPUPipelinePass(int numStages) {
return std::make_unique<PipelinePass>(numStages);
}

View File

@@ -0,0 +1,313 @@
//===----------------------------------------------------------------------===//
//
// This pass tries to prefetch operands (a and b) of tt.dot.
// Those ConvertLayoutOps will be lowered to shared memory loads.
//
// For example:
// %a: tensor<128x32xf16, #enc>
// scf.for %iv = ... iter_args(%a_arg = %a, ...) {
// %d = tt.dot %a_arg, %b, %c
// ...
// scf.yield %a_next, ...
// }
//
// will be translated to
//
// %a: tensor<128x32xf16, #enc>
// %a_tmp = tensor.extract_slice %a[0, 0] [128, 16]
// %a_prefetch = triton_gpu.convert_layout %a_tmp
// scf.for %iv = ... iter_args(%a_buf = %a, ..., %a_prefetch_arg = %a_prefetch)
// {
// %x = tt.dot %a_arg, %b, %c
// %a_tmp_rem = tensor.extract_slice %a_buf[0, 16] [128, 16]
// %a_prefetch_next = triton_gpu.convert_layout %a_tmp_rem
// ...
// scf.yield %next_a, ..., %a_prefetch_next
// }
//===----------------------------------------------------------------------===//
#include "mlir/IR/BlockAndValueMapping.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
using namespace mlir;
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
namespace {
class Prefetcher {
/// cache the ForOp we are working on
scf::ForOp forOp;
/// cache the YieldOp of this ForOp
scf::YieldOp yieldOp;
///
// TODO: add a hook to infer prefetchWidth
unsigned prefetchWidth = 16;
/// dots to be prefetched
SetVector<Value> dots;
/// dot => dot operand
DenseMap<Value, Value> dot2aLoopArg;
DenseMap<Value, Value> dot2aHeaderDef;
DenseMap<Value, Value> dot2bLoopArg;
DenseMap<Value, Value> dot2bHeaderDef;
DenseMap<Value, Value> dot2aYield;
DenseMap<Value, Value> dot2bYield;
/// operand => defining
DenseMap<Value, Value> operand2headPrefetch;
LogicalResult isForOpOperand(Value v);
Value generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
Attribute dotEncoding, OpBuilder &builder,
llvm::Optional<int64_t> offsetK = llvm::None,
llvm::Optional<int64_t> shapeK = llvm::None);
public:
Prefetcher() = delete;
Prefetcher(scf::ForOp forOp) : forOp(forOp) {
yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
}
LogicalResult initialize();
void emitPrologue();
scf::ForOp createNewForOp();
};
Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
Attribute dotEncoding, OpBuilder &builder,
llvm::Optional<int64_t> offsetK,
llvm::Optional<int64_t> shapeK) {
// opIdx: 0 => a, 1 => b
auto type = v.getType().cast<RankedTensorType>();
SmallVector<int64_t> shape{type.getShape().begin(), type.getShape().end()};
SmallVector<int64_t> offset{0, 0};
Type elementType = type.getElementType();
auto intAttr = [&](int64_t val) { return builder.getI64IntegerAttr(val); };
// k => (prefetchWidth, k - prefetchWidth)
int64_t kIdx = opIdx == 0 ? 1 : 0;
offset[kIdx] = isPrologue ? 0 : prefetchWidth;
shape[kIdx] = isPrologue ? prefetchWidth : (shape[kIdx] - prefetchWidth);
if (shapeK)
shape[kIdx] = *shapeK;
if (offsetK)
offset[kIdx] = *offsetK;
Value newSmem = builder.create<tensor::ExtractSliceOp>(
v.getLoc(),
// TODO: encoding?
RankedTensorType::get(shape, elementType, type.getEncoding()), v,
SmallVector<OpFoldResult>{intAttr(offset[0]), intAttr(offset[1])},
SmallVector<OpFoldResult>{intAttr(shape[0]), intAttr(shape[1])},
SmallVector<OpFoldResult>{intAttr(1), intAttr(1)});
auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
builder.getContext(), opIdx, dotEncoding);
Value prefetchSlice = builder.create<triton::gpu::ConvertLayoutOp>(
v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
newSmem);
return prefetchSlice;
}
LogicalResult Prefetcher::initialize() {
Block *loop = forOp.getBody();
SmallVector<triton::DotOp> dotsInFor;
for (Operation &op : *loop)
if (auto dotOp = dyn_cast<triton::DotOp>(op))
dotsInFor.push_back(dotOp);
if (dotsInFor.empty())
return failure();
// TODO: segfault (original for still has uses)
// when used in flash attention that has 2 dots in the loop
if (dotsInFor.size() > 1)
return failure();
// returns source of cvt
auto getPrefetchSrc = [](Value v) -> Value {
if (auto cvt = v.getDefiningOp<triton::gpu::ConvertLayoutOp>())
if (isSharedEncoding(cvt.getOperand()))
return cvt.src();
return Value();
};
auto getIncomingOp = [this](Value v) -> Value {
if (auto arg = v.dyn_cast<BlockArgument>())
if (arg.getOwner()->getParentOp() == forOp.getOperation())
return forOp.getOpOperandForRegionIterArg(arg).get();
return Value();
};
auto getYieldOp = [this](Value v) -> Value {
auto arg = v.cast<BlockArgument>();
unsigned yieldIdx = arg.getArgNumber() - forOp.getNumInductionVars();
return yieldOp.getOperand(yieldIdx);
};
for (triton::DotOp dot : dotsInFor) {
auto kSize = dot.a().getType().cast<RankedTensorType>().getShape()[1];
// Skip prefetching if kSize is less than prefetchWidth
if (kSize < prefetchWidth)
continue;
Value aSmem = getPrefetchSrc(dot.a());
Value bSmem = getPrefetchSrc(dot.b());
if (aSmem && bSmem) {
Value aHeaderDef = getIncomingOp(aSmem);
Value bHeaderDef = getIncomingOp(bSmem);
// Only prefetch loop arg
if (aHeaderDef && bHeaderDef) {
dots.insert(dot);
dot2aHeaderDef[dot] = aHeaderDef;
dot2bHeaderDef[dot] = bHeaderDef;
dot2aLoopArg[dot] = aSmem;
dot2bLoopArg[dot] = bSmem;
dot2aYield[dot] = getYieldOp(aSmem);
dot2bYield[dot] = getYieldOp(bSmem);
}
}
}
return success();
}
void Prefetcher::emitPrologue() {
OpBuilder builder(forOp);
for (Value dot : dots) {
Attribute dotEncoding =
dot.getType().cast<RankedTensorType>().getEncoding();
Value aPrefetched =
generatePrefetch(dot2aHeaderDef[dot], 0, true, dotEncoding, builder);
operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().a()] = aPrefetched;
Value bPrefetched =
generatePrefetch(dot2bHeaderDef[dot], 1, true, dotEncoding, builder);
operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().b()] = bPrefetched;
}
}
scf::ForOp Prefetcher::createNewForOp() {
OpBuilder builder(forOp);
SmallVector<Value> loopArgs;
for (auto v : forOp.getIterOperands())
loopArgs.push_back(v);
for (Value dot : dots) {
loopArgs.push_back(
operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().a()]);
loopArgs.push_back(
operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().b()]);
}
auto newForOp = builder.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), loopArgs);
auto largestPow2 = [](int64_t n) -> int64_t {
while ((n & (n - 1)) != 0)
n = n & (n - 1);
return n;
};
builder.setInsertionPointToStart(newForOp.getBody());
BlockAndValueMapping mapping;
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
for (Operation &op : forOp.getBody()->without_terminator()) {
Operation *newOp = builder.clone(op, mapping);
auto dot = dyn_cast<triton::DotOp>(&op);
if (dots.contains(dot)) {
Attribute dotEncoding =
dot.getType().cast<RankedTensorType>().getEncoding();
// prefetched dot
Operation *firstDot = builder.clone(*dot, mapping);
if (Value a = operand2headPrefetch.lookup(dot.a()))
firstDot->setOperand(
0, newForOp.getRegionIterArgForOpOperand(*a.use_begin()));
if (Value b = operand2headPrefetch.lookup(dot.b()))
firstDot->setOperand(
1, newForOp.getRegionIterArgForOpOperand(*b.use_begin()));
// remaining part
int64_t kOff = prefetchWidth;
int64_t kRem = dot.a().getType().cast<RankedTensorType>().getShape()[1] -
prefetchWidth;
Operation *prevDot = firstDot;
while (kRem != 0) {
int64_t kShape = largestPow2(kRem);
Value aRem =
generatePrefetch(mapping.lookup(dot2aLoopArg[dot]), 0, false,
dotEncoding, builder, kOff, kShape);
Value bRem =
generatePrefetch(mapping.lookup(dot2bLoopArg[dot]), 1, false,
dotEncoding, builder, kOff, kShape);
newOp = builder.clone(*dot, mapping);
newOp->setOperand(0, aRem);
newOp->setOperand(1, bRem);
newOp->setOperand(2, prevDot->getResult(0));
prevDot = newOp;
kOff += kShape;
kRem -= kShape;
}
}
// update mapping of results
for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults()))
mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx));
}
// prefetch next iteration
SmallVector<Value> yieldValues;
for (Value v : forOp.getBody()->getTerminator()->getOperands())
yieldValues.push_back(mapping.lookup(v));
for (Value dot : dots) {
Attribute dotEncoding =
dot.getType().cast<RankedTensorType>().getEncoding();
yieldValues.push_back(generatePrefetch(mapping.lookup(dot2aYield[dot]), 0,
true, dotEncoding, builder));
yieldValues.push_back(generatePrefetch(mapping.lookup(dot2bYield[dot]), 1,
true, dotEncoding, builder));
}
// Update ops of yield
builder.create<scf::YieldOp>(yieldOp.getLoc(), yieldValues);
return newForOp;
}
struct PrefetchPass : public TritonGPUPrefetchBase<PrefetchPass> {
void runOnOperation() override {
getOperation()->walk([&](scf::ForOp forOp) {
Prefetcher prefetcher(forOp);
if (prefetcher.initialize().failed())
return;
prefetcher.emitPrologue();
scf::ForOp newForOp = prefetcher.createNewForOp();
// replace the original loop
for (unsigned i = 0; i < forOp->getNumResults(); ++i)
forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i));
forOp->erase();
});
}
};
} // anonymous namespace
std::unique_ptr<Pass> mlir::createTritonGPUPrefetchPass() {
return std::make_unique<PrefetchPass>();
}

View File

@@ -0,0 +1,103 @@
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <algorithm>
#include <numeric>
using namespace mlir;
using namespace mlir::triton::gpu;
//
// TypeConverter
//
TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
int numWarps)
: context(context), numWarps(numWarps) {
// TODO: how does MLIR pick the right conversion?
addConversion([](Type type) { return type; });
addConversion([this](RankedTensorType tensorType) -> RankedTensorType {
// types with encoding are already in the right format
// TODO: check for layout encodings specifically
if (tensorType.getEncoding())
return tensorType;
// pessimistic values for attributes:
// - 1 element per thread
// - order = arange(rank)
ArrayRef<int64_t> shape = tensorType.getShape();
int rank = shape.size();
llvm::SmallVector<unsigned> order(rank);
std::iota(order.begin(), order.end(), 0);
llvm::SmallVector<unsigned> sizePerThread(rank, 1);
Attribute encoding = triton::gpu::BlockedEncodingAttr::get(
this->context, shape, sizePerThread, order, this->numWarps);
return RankedTensorType::get(shape, tensorType.getElementType(), encoding);
});
//
// Materializations
//
// This will be called when (newArgType != origArgType)
// This will create newArg, and map(origArg, newArg)
addArgumentMaterialization([&](OpBuilder &builder,
RankedTensorType tensorType, ValueRange inputs,
Location loc) {
llvm_unreachable("Argument rematerialization not implemented");
return llvm::None;
});
// If the origValue still has live user(s), use this to
// convert origValue to newValue
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
llvm_unreachable("Source rematerialization not implemented");
return llvm::None;
});
// This will be called when (desiredType != newOperandType)
// where, desiredType = typeConverter->convertType(origType)
// NOTE: only for remapped values.
addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
auto cast =
builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType, inputs);
return Optional<Value>(cast.getResult());
// return Optional<Value>(cast.getResult(0));
// llvm_unreachable("Not implemented");
// return llvm::None;
});
}
//
// TritonGPUConversion
//
TritonGPUConversionTarget::TritonGPUConversionTarget(
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
: ConversionTarget(context) {
// TODO: we should also verify ops of TritonGPUDialect
addLegalDialect<triton::gpu::TritonGPUDialect>();
// Some ops from SCF are illegal
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp, scf::ReduceOp,
scf::ReduceReturnOp>();
addDynamicallyLegalDialect<arith::ArithmeticDialect, math::MathDialect,
triton::TritonDialect, StandardOpsDialect,
scf::SCFDialect>([&](Operation *op) {
if (typeConverter.isLegal(op))
return true;
return false;
});
// We have requirements for the data layouts
addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
Attribute aEncoding =
dotOp.a().getType().cast<RankedTensorType>().getEncoding();
Attribute bEncoding =
dotOp.b().getType().cast<RankedTensorType>().getEncoding();
if (aEncoding && aEncoding.isa<triton::gpu::DotOperandEncodingAttr>() &&
bEncoding && bEncoding.isa<triton::gpu::DotOperandEncodingAttr>())
return true;
return false;
});
}