Merge triton-mlir branch - Complete rewrite of the backend from scratch (#1004)

This PR merges the `triton-mlir` branch, in which we have been quietly
rewriting the Triton backend from scratch to increase maintainability,
stability and ultimately performance. Changes to the runtime are
minimal, and this new version aims to remain backward-compatible with
the previous commit. The legacy backend is now officially deprecated,
but can still be accessed via the `legacy-backend` tag.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com>
Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com>
Co-authored-by: Yan Da <dyanab@connect.ust.hk>
Co-authored-by: Jun Yang <yangjunpro@gmail.com>
Co-authored-by: Ian Bearman <ianb@microsoft.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Co-authored-by: Qingyi Liu <qingyil@nvidia.com>
Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com>
Co-authored-by: Chenggang Zhao <lyricz@yeah.net>
Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
Philippe Tillet
2022-12-21 01:30:50 -08:00
committed by GitHub
parent 8650b4d1cb
commit 20100a7254
285 changed files with 26312 additions and 50143 deletions

67
lib/Analysis/Alias.cpp Normal file
View File

@@ -0,0 +1,67 @@
#include "triton/Analysis/Alias.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
namespace mlir {
AliasInfo AliasInfo::join(const AliasInfo &lhs, const AliasInfo &rhs) {
if (lhs == rhs)
return lhs;
AliasInfo ret;
for (auto value : lhs.allocs) {
ret.insert(value);
}
for (auto value : rhs.allocs) {
ret.insert(value);
}
return ret;
}
ChangeResult SharedMemoryAliasAnalysis::visitOperation(
Operation *op, ArrayRef<LatticeElement<AliasInfo> *> operands) {
AliasInfo aliasInfo;
bool pessimistic = true;
if (maybeSharedAllocationOp(op)) {
// These ops may allocate a new shared memory buffer.
auto result = op->getResult(0);
// FIXME(Keren): extract and insert are always alias for now
if (isa<tensor::ExtractSliceOp, triton::TransOp>(op)) {
// extract_slice %src
aliasInfo = AliasInfo(operands[0]->getValue());
pessimistic = false;
} else if (isa<tensor::InsertSliceOp>(op) ||
isa<triton::gpu::InsertSliceAsyncOp>(op)) {
// insert_slice_async %src, %dst, %index
// insert_slice %src into %dst[%offsets]
aliasInfo = AliasInfo(operands[1]->getValue());
pessimistic = false;
} else if (isSharedEncoding(result)) {
aliasInfo.insert(result);
pessimistic = false;
}
}
if (pessimistic) {
return markAllPessimisticFixpoint(op->getResults());
}
// Join all lattice elements
ChangeResult result = ChangeResult::NoChange;
for (Value value : op->getResults()) {
result |= getLatticeElement(value).join(aliasInfo);
}
return result;
}
AliasResult SharedMemoryAliasAnalysis::alias(Value lhs, Value rhs) {
// TODO: implement
return AliasResult::MayAlias;
}
ModRefResult SharedMemoryAliasAnalysis::getModRef(Operation *op,
Value location) {
// TODO: implement
return ModRefResult::getModAndRef();
}
} // namespace mlir

476
lib/Analysis/Allocation.cpp Normal file
View File

@@ -0,0 +1,476 @@
#include "triton/Analysis/Allocation.h"
#include "mlir/Analysis/Liveness.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "triton/Analysis/Alias.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <limits>
#include <numeric>
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getContigPerThread;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getSizePerThread;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
namespace mlir {
//===----------------------------------------------------------------------===//
// Shared Memory Allocation Analysis
//===----------------------------------------------------------------------===//
namespace triton {
// Bitwidth of pointers
constexpr int kPtrBitWidth = 64;
static std::pair<SmallVector<unsigned>, SmallVector<unsigned>>
getCvtOrder(const Attribute &srcLayout, const Attribute &dstLayout) {
auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>();
auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
auto srcDotLayout = srcLayout.dyn_cast<DotOperandEncodingAttr>();
auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>();
auto dstMmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>();
auto dstDotLayout = dstLayout.dyn_cast<DotOperandEncodingAttr>();
assert(!(srcMmaLayout && dstMmaLayout) &&
"Unexpected mma -> mma layout conversion");
// mma or dot layout does not have an order, so the order depends on the
// layout of the other operand.
auto inOrd = (srcMmaLayout || srcDotLayout) ? getOrder(dstLayout)
: getOrder(srcLayout);
auto outOrd = (dstMmaLayout || dstDotLayout) ? getOrder(srcLayout)
: getOrder(dstLayout);
return {inOrd, outOrd};
}
SmallVector<unsigned>
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
unsigned &outVec) {
auto srcTy = op.src().getType().cast<RankedTensorType>();
auto dstTy = op.result().getType().cast<RankedTensorType>();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
assert(srcLayout && dstLayout &&
"Unexpect layout in getScratchConfigForCvtLayout()");
auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
unsigned srcContigPerThread = getContigPerThread(srcLayout)[inOrd[0]];
unsigned dstContigPerThread = getContigPerThread(dstLayout)[outOrd[0]];
// TODO: Fix the legacy issue that ourOrd[0] == 0 always means
// that we cannot do vectorization.
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
outVec = outOrd[0] == 0 ? 1 : dstContigPerThread;
auto srcShapePerCTA = getShapePerCTA(srcLayout);
auto dstShapePerCTA = getShapePerCTA(dstLayout);
unsigned rank = dstTy.getRank();
SmallVector<unsigned> paddedRepShape(rank);
unsigned pad = std::max(inVec, outVec);
for (unsigned d = 0; d < rank; ++d) {
paddedRepShape[d] =
std::max(std::min<unsigned>(srcTy.getShape()[d], srcShapePerCTA[d]),
std::min<unsigned>(dstTy.getShape()[d], dstShapePerCTA[d]));
}
if (rank == 1)
return paddedRepShape;
unsigned paddedDim = 1;
if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
paddedDim = dstBlockedLayout.getOrder()[0];
}
paddedRepShape[paddedDim] += pad;
return paddedRepShape;
}
// TODO: extend beyond scalars
SmallVector<unsigned> getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) {
SmallVector<unsigned> smemShape;
if (op.ptr().getType().isa<RankedTensorType>()) {
// do nothing or just assert because shared memory is not used in tensor up
// to now
} else {
// need only bytes for scalar
// always vec = 1 and elemsPerThread = 1 for scalar?
smemShape.push_back(1);
}
return smemShape;
}
SmallVector<unsigned> getScratchConfigForAtomicCAS(triton::AtomicCASOp op) {
return SmallVector<unsigned>{1};
}
class AllocationAnalysis {
public:
AllocationAnalysis(Operation *operation, Allocation *allocation)
: operation(operation), allocation(allocation) {
run();
}
private:
using BufferT = Allocation::BufferT;
/// Value -> Liveness Range
/// Use MapVector to ensure determinism.
using BufferRangeMapT = llvm::MapVector<BufferT *, Interval<size_t>>;
/// Nodes -> Nodes
using GraphT = DenseMap<BufferT *, DenseSet<BufferT *>>;
void run() {
getValuesAndSizes();
resolveLiveness();
computeOffsets();
}
/// Initializes explicitly defined shared memory values for a given operation.
void getExplicitValueSize(Operation *op) {
// Values returned from scf.yield will not be allocated even though they
// have the shared encoding.
// For example: %a = scf.if -> yield
// %a must be allocated elsewhere by other operations.
// FIXME(Keren): extract and insert are always alias for now
if (!maybeSharedAllocationOp(op) || maybeAliasOp(op)) {
return;
}
for (Value result : op->getResults()) {
if (isSharedEncoding(result)) {
// Bytes could be a different value once we support padding or other
// allocation policies.
auto tensorType = result.getType().dyn_cast<RankedTensorType>();
auto bytes = tensorType.getNumElements() *
tensorType.getElementTypeBitWidth() / 8;
allocation->addBuffer<BufferT::BufferKind::Explicit>(result, bytes);
}
}
}
/// Initializes temporary shared memory for a given operation.
void getScratchValueSize(Operation *op) {
if (auto reduceOp = dyn_cast<triton::ReduceOp>(op)) {
ReduceOpHelper helper(reduceOp);
unsigned bytes = helper.getScratchSizeInBytes();
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
} else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
auto srcTy = cvtLayout.src().getType().cast<RankedTensorType>();
auto dstTy = cvtLayout.result().getType().cast<RankedTensorType>();
auto srcEncoding = srcTy.getEncoding();
auto dstEncoding = dstTy.getEncoding();
if (srcEncoding.isa<SharedEncodingAttr>() ||
dstEncoding.isa<SharedEncodingAttr>()) {
// Conversions from/to shared memory do not need scratch memory.
return;
}
// ConvertLayoutOp with both input/output non-shared_layout
// TODO: Besides of implementing ConvertLayoutOp via shared memory, it's
// also possible to realize it with other approaches in restricted
// conditions, such as warp-shuffle
unsigned inVec = 0;
unsigned outVec = 0;
auto smemShape = getScratchConfigForCvtLayout(cvtLayout, inVec, outVec);
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
std::multiplies{});
auto bytes =
srcTy.getElementType().isa<triton::PointerType>()
? elems * kPtrBitWidth / 8
: elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
} else if (auto atomicRMWOp = dyn_cast<triton::AtomicRMWOp>(op)) {
auto value = op->getOperand(0);
// only scalar requires scratch memory
// make it explicit for readability
if (value.getType().dyn_cast<RankedTensorType>()) {
// nothing to do
} else {
auto smemShape = getScratchConfigForAtomicRMW(atomicRMWOp);
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
std::multiplies{});
auto elemTy =
value.getType().cast<triton::PointerType>().getPointeeType();
auto bytes =
elemTy.isa<triton::PointerType>()
? elems * kPtrBitWidth / 8
: elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
}
} else if (auto atomicCASOp = dyn_cast<triton::AtomicCASOp>(op)) {
auto value = op->getOperand(0);
auto smemShape = getScratchConfigForAtomicCAS(atomicCASOp);
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
std::multiplies{});
auto elemTy =
value.getType().cast<triton::PointerType>().getPointeeType();
auto bytes = elemTy.isa<triton::PointerType>()
? elems * kPtrBitWidth / 8
: elems * elemTy.getIntOrFloatBitWidth() / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
}
}
void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) {
LatticeElement<AliasInfo> *latticeElement =
analysis.lookupLatticeElement(value);
if (latticeElement) {
auto &info = latticeElement->getValue();
if (!info.getAllocs().empty()) {
for (auto alloc : info.getAllocs()) {
allocation->addAlias(value, alloc);
}
}
}
}
/// Extract all shared memory values and their sizes
void getValuesAndSizes() {
// Get the alloc values
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
getExplicitValueSize(op);
getScratchValueSize(op);
});
// Get the alias values
SharedMemoryAliasAnalysis aliasAnalysis(operation->getContext());
aliasAnalysis.run(operation);
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
for (auto operand : op->getOperands()) {
getValueAlias(operand, aliasAnalysis);
}
for (auto value : op->getResults()) {
getValueAlias(value, aliasAnalysis);
}
});
}
/// Computes the liveness range of the allocated value.
/// Each buffer is allocated only once.
void resolveExplicitBufferLiveness(
function_ref<Interval<size_t>(Value value)> getLiveness) {
for (auto valueBufferIter : allocation->valueBuffer) {
auto value = valueBufferIter.first;
auto *buffer = valueBufferIter.second;
bufferRange[buffer] = getLiveness(value);
}
}
/// Extends the liveness range by unionizing the liveness range of the aliased
/// values because each allocated buffer could be an alias of others, if block
/// arguments are involved.
void resolveAliasBufferLiveness(
function_ref<Interval<size_t>(Value value)> getLiveness) {
for (auto aliasBufferIter : allocation->aliasBuffer) {
auto value = aliasBufferIter.first;
auto buffers = aliasBufferIter.second;
auto range = getLiveness(value);
for (auto *buffer : buffers) {
auto minId = range.start();
auto maxId = range.end();
if (bufferRange.count(buffer)) {
// Extend the allocated buffer's range
minId = std::min(minId, bufferRange[buffer].start());
maxId = std::max(maxId, bufferRange[buffer].end());
}
bufferRange[buffer] = Interval(minId, maxId);
}
}
}
/// Computes the liveness range of scratched buffers.
/// Some operations may have a temporary buffer that is not explicitly
/// allocated, but is used to store intermediate results.
void resolveScratchBufferLiveness(
const DenseMap<Operation *, size_t> &operationId) {
// Analyze liveness of scratch buffers
for (auto opScratchIter : allocation->opScratch) {
// Any scratch memory's live range is the current operation's live
// range.
auto *op = opScratchIter.first;
auto *buffer = opScratchIter.second;
bufferRange.insert({buffer, Interval(operationId.lookup(op),
operationId.lookup(op) + 1)});
}
}
/// Resolves liveness of all values involved under the root operation.
void resolveLiveness() {
// In the SCF dialect, we always have a sequentially nested structure of
// blocks
DenseMap<Operation *, size_t> operationId;
operation->walk<WalkOrder::PreOrder>(
[&](Operation *op) { operationId[op] = operationId.size(); });
// Analyze liveness of explicit buffers
Liveness liveness(operation);
auto getValueLivenessRange = [&](Value value) {
auto liveOperations = liveness.resolveLiveness(value);
auto minId = std::numeric_limits<size_t>::max();
auto maxId = std::numeric_limits<size_t>::min();
std::for_each(liveOperations.begin(), liveOperations.end(),
[&](Operation *liveOp) {
if (operationId[liveOp] < minId) {
minId = operationId[liveOp];
}
if ((operationId[liveOp] + 1) > maxId) {
maxId = operationId[liveOp] + 1;
}
});
return Interval(minId, maxId);
};
resolveExplicitBufferLiveness(getValueLivenessRange);
resolveAliasBufferLiveness(getValueLivenessRange);
resolveScratchBufferLiveness(operationId);
}
/// Computes the shared memory offsets for all related values.
/// Paper: Algorithms for Compile-Time Memory Optimization
/// (https://www.cs.utexas.edu/users/harrison/papers/compile-time.pdf)
void computeOffsets() {
SmallVector<BufferT *> buffers;
for (auto bufferIter : bufferRange) {
buffers.emplace_back(bufferIter.first);
}
DenseMap<BufferT *, size_t> bufferStart;
calculateStarts(buffers, bufferStart);
GraphT interference;
buildInterferenceGraph(buffers, bufferStart, interference);
allocate(buffers, bufferStart, interference);
}
/// Computes the initial shared memory offsets.
void calculateStarts(const SmallVector<BufferT *> &buffers,
DenseMap<BufferT *, size_t> &bufferStart) {
// v = values in shared memory
// t = triplet of (size, start, end)
// shared memory space
// -
// | *******t4
// | /|\ v2 inserts t4, t5, and t6
// | |
// | ******t5 ************t6
// | ^^^^^v2^^^^^^
// | | *********************t2
// | \|/ v2 erases t1
// | ******t1 ^^^^^^^^^v1^^^^^^^^^ ************t3
// |---------------------------------------------| liveness range
// 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 ...
/// Start -> Liveness Range
using TripleMapT = std::multimap<size_t, Interval<size_t>>;
TripleMapT tripleMap;
tripleMap.insert(std::make_pair(0, Interval<size_t>()));
SmallVector<BufferT *> xBuffers = buffers;
while (!xBuffers.empty()) {
auto tripleIt = tripleMap.begin();
auto size = tripleIt->first;
auto range = tripleIt->second;
tripleMap.erase(tripleIt);
auto bufferIt =
std::find_if(xBuffers.begin(), xBuffers.end(), [&](auto *buffer) {
auto xRange = bufferRange[buffer];
bool res = xRange.intersects(range);
for (auto val : tripleMap)
res = res && !val.second.intersects(xRange);
return res;
});
if (bufferIt != xBuffers.end()) {
auto buffer = *bufferIt;
auto xSize = buffer->size;
auto xRange = bufferRange.lookup(buffer);
bufferStart[buffer] = size;
tripleMap.insert(
{size + xSize, Interval{std::max(range.start(), xRange.start()),
std::min(range.end(), xRange.end())}});
if (range.start() < xRange.start())
tripleMap.insert({size, Interval{range.start(), xRange.end()}});
if (xRange.end() < range.end())
tripleMap.insert({size, Interval{xRange.start(), range.end()}});
xBuffers.erase(bufferIt);
}
}
}
/// Builds a graph of all shared memory values. Edges are created between
/// shared memory values that are overlapping.
void buildInterferenceGraph(const SmallVector<BufferT *> &buffers,
const DenseMap<BufferT *, size_t> &bufferStart,
GraphT &interference) {
for (auto x : buffers) {
for (auto y : buffers) {
if (x == y)
continue;
auto xStart = bufferStart.lookup(x);
auto yStart = bufferStart.lookup(y);
auto xSize = x->size;
auto ySize = y->size;
Interval xSizeRange = {xStart, xStart + xSize};
Interval ySizeRange = {yStart, yStart + ySize};
auto xOpRange = bufferRange.lookup(x);
auto yOpRange = bufferRange.lookup(y);
if (xOpRange.intersects(yOpRange) &&
xSizeRange.intersects(ySizeRange)) {
interference[x].insert(y);
}
}
}
}
/// Finalizes shared memory offsets considering interference.
void allocate(const SmallVector<BufferT *> &buffers,
const DenseMap<BufferT *, size_t> &bufferStart,
const GraphT &interference) {
// First-fit graph coloring
// Neighbors are nodes that interfere with each other.
// We color a node by finding the index of the first available
// non-neighboring node or the first neighboring node without any color.
// Nodes with the same color do not interfere with each other.
DenseMap<BufferT *, int> colors;
for (auto value : buffers) {
colors[value] = (value == buffers[0]) ? 0 : -1;
}
SmallVector<bool> available(buffers.size());
for (auto x : buffers) {
std::fill(available.begin(), available.end(), true);
for (auto y : interference.lookup(x)) {
int color = colors[y];
if (color >= 0) {
available[color] = false;
}
}
auto it = std::find(available.begin(), available.end(), true);
colors[x] = std::distance(available.begin(), it);
}
// Finalize allocation
// color0: [0, 7), [0, 8), [0, 15) -> [0, 7), [0, 8), [0, 15)
// color1: [7, 9) -> [0 + 1 * 15, 9 + 1 * 15) -> [15, 24)
// color2: [8, 12) -> [8 + 2 * 15, 12 + 2 * 15) -> [38, 42)
// TODO(Keren): We are wasting memory here.
// Nodes with color2 can actually start with 24.
for (auto x : buffers) {
size_t adj = 0;
for (auto y : interference.lookup(x)) {
adj = std::max(adj, bufferStart.lookup(y) + y->size);
}
x->offset = bufferStart.lookup(x) + colors.lookup(x) * adj;
allocation->sharedMemorySize =
std::max(allocation->sharedMemorySize, x->offset + x->size);
}
}
private:
Operation *operation;
Allocation *allocation;
BufferRangeMapT bufferRange;
};
} // namespace triton
void Allocation::run() { triton::AllocationAnalysis(getOperation(), this); }
} // namespace mlir

321
lib/Analysis/AxisInfo.cpp Normal file
View File

@@ -0,0 +1,321 @@
#include "mlir/Analysis/DataFlowAnalysis.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "llvm/Support/raw_ostream.h"
#include <iostream>
#include "triton/Analysis/AxisInfo.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
namespace mlir {
//===----------------------------------------------------------------------===//
// AxisInfo
//===----------------------------------------------------------------------===//
// Function for extended Euclidean Algorithm
static int gcd_impl(int a, int b, int *x, int *y) {
// Base Case
if (a == 0) {
*x = 0;
*y = 1;
return b;
}
int x1, y1; // To store results of recursive call
int gcd = gcd_impl(b % a, a, &x1, &y1);
// Update x and y using results of
// recursive call
*x = y1 - (b / a) * x1;
*y = x1;
return gcd;
}
static int gcd(int a, int b) {
int x, y;
return gcd_impl(a, b, &x, &y);
}
AxisInfo AxisInfo::getPessimisticValueState(Value value) {
size_t rank = 1;
if (TensorType ty = value.getType().dyn_cast<TensorType>())
rank = ty.getRank();
int divHint = 1;
BlockArgument blockArg = value.dyn_cast<BlockArgument>();
if (blockArg && blockArg.getOwner()->isEntryBlock()) {
Operation *op = blockArg.getOwner()->getParentOp();
if (FuncOp fun = dyn_cast<FuncOp>(op)) {
Attribute attr =
fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
if (attr)
divHint = attr.cast<IntegerAttr>().getValue().getZExtValue();
} else if (auto fun = dyn_cast<LLVM::LLVMFuncOp>(op)) {
Attribute attr =
fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
if (attr)
divHint = attr.cast<IntegerAttr>().getValue().getZExtValue();
}
}
DimVectorT contiguity(rank, 1);
DimVectorT divisibility(rank, divHint);
DimVectorT constancy(rank, 1);
return AxisInfo(contiguity, divisibility, constancy);
}
// The gcd of both arguments for each dimension
AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) {
DimVectorT retContiguity;
DimVectorT retDivisibility;
DimVectorT retConstancy;
for (int d = 0; d < lhs.getRank(); ++d) {
retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
retDivisibility.push_back(
gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
retConstancy.push_back(gcd(lhs.getConstancy(d), rhs.getConstancy(d)));
}
return AxisInfo(retContiguity, retDivisibility, retConstancy);
}
//===----------------------------------------------------------------------===//
// AxisInfoAnalysis
//===----------------------------------------------------------------------===//
AxisInfo AxisInfoAnalysis::visitBinaryOp(
Operation *op, AxisInfo lhsInfo, AxisInfo rhsInfo,
const std::function<int(AxisInfo, AxisInfo, int)> &getContiguity,
const std::function<int(AxisInfo, AxisInfo, int)> &getDivisibility,
const std::function<int(AxisInfo, AxisInfo, int)> &getConstancy) {
int rank = lhsInfo.getRank();
AxisInfo::DimVectorT newContiguity;
AxisInfo::DimVectorT newDivisibility;
AxisInfo::DimVectorT newConstancy;
for (int d = 0; d < rank; ++d) {
newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d));
newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d));
newConstancy.push_back(getConstancy(lhsInfo, rhsInfo, d));
}
return AxisInfo(newContiguity, newDivisibility, newConstancy);
}
ChangeResult AxisInfoAnalysis::visitOperation(
Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) {
AxisInfo curr;
// This preserves the input axes (e.g., cast):
if (llvm::isa<arith::ExtSIOp, arith::ExtUIOp, arith::TruncIOp,
triton::PtrToIntOp, triton::IntToPtrOp,
triton::gpu::ConvertLayoutOp>(op))
curr = operands[0]->getValue();
// Constant ranges
if (triton::MakeRangeOp make_range =
llvm::dyn_cast<triton::MakeRangeOp>(op)) {
int start = make_range.start();
int end = make_range.end();
AxisInfo::DimVectorT contiguity = {end - start};
AxisInfo::DimVectorT divisibility = {highestPowOf2Divisor(start)};
AxisInfo::DimVectorT constancy = {1};
curr = AxisInfo(contiguity, divisibility, constancy);
}
// Constant
if (arith::ConstantOp constant = llvm::dyn_cast<arith::ConstantOp>(op)) {
auto intAttr = constant.getValue().dyn_cast<IntegerAttr>();
if (intAttr) {
size_t val = intAttr.getValue().getZExtValue();
curr = AxisInfo({1}, {highestPowOf2Divisor(val)}, {1});
}
// TODO: generalize to dense attr
auto splatAttr = constant.getValue().dyn_cast<SplatElementsAttr>();
if (splatAttr && splatAttr.getElementType().isInteger(32)) {
auto value = splatAttr.getSplatValue<int>();
TensorType ty = splatAttr.getType().cast<TensorType>();
curr = AxisInfo(
AxisInfo::DimVectorT(ty.getRank(), 1),
AxisInfo::DimVectorT(ty.getRank(), highestPowOf2Divisor(value)),
AxisInfo::DimVectorT(ty.getShape().begin(), ty.getShape().end()));
}
}
// TODO: refactor & complete binary ops
// Addition
if (llvm::isa<arith::AddIOp, triton::AddPtrOp>(op)) {
auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d) {
return std::max(gcd(lhs.getContiguity(d), rhs.getConstancy(d)),
gcd(lhs.getConstancy(d), rhs.getContiguity(d)));
};
auto newConstancy = [&](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
};
auto newDivisibility = [&](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getDivisibility(d), rhs.getDivisibility(d));
};
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
newContiguity, newDivisibility, newConstancy);
}
// Multiplication
if (llvm::isa<arith::MulIOp>(op)) {
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
};
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) {
return lhs.getDivisibility(d) * rhs.getDivisibility(d);
};
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
newContiguity, newDivisibility, newConstancy);
}
// Remainder
if (llvm::isa<arith::RemSIOp, arith::RemUIOp>(op)) {
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getContiguity(d), rhs.getDivisibility(d));
};
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getDivisibility(d), rhs.getDivisibility(d));
};
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
};
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
newContiguity, newDivisibility, newConstancy);
}
// TODO: All other binary ops
if (llvm::isa<arith::AndIOp, arith::OrIOp>(op)) {
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
};
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
newContiguity, newDivisibility, newConstancy);
}
// Splat
if (llvm::isa<triton::SplatOp>(op)) {
Type _retTy = *op->result_type_begin();
TensorType retTy = _retTy.cast<TensorType>();
AxisInfo opInfo = operands[0]->getValue();
AxisInfo::DimVectorT contiguity;
AxisInfo::DimVectorT divisibility;
AxisInfo::DimVectorT constancy;
for (int d = 0; d < retTy.getRank(); ++d) {
contiguity.push_back(1);
divisibility.push_back(opInfo.getDivisibility(0));
constancy.push_back(retTy.getShape()[d]);
}
curr = AxisInfo(contiguity, divisibility, constancy);
}
// expandDims
if (auto expandDims = llvm::dyn_cast<triton::ExpandDimsOp>(op)) {
AxisInfo opInfo = operands[0]->getValue();
AxisInfo::DimVectorT contiguity = opInfo.getContiguity();
AxisInfo::DimVectorT divisibility = opInfo.getDivisibility();
AxisInfo::DimVectorT constancy = opInfo.getConstancy();
contiguity.insert(contiguity.begin() + expandDims.axis(), 1);
divisibility.insert(divisibility.begin() + expandDims.axis(), 1);
constancy.insert(constancy.begin() + expandDims.axis(), 1);
curr = AxisInfo(contiguity, divisibility, constancy);
}
// Broadcast
if (llvm::isa<triton::BroadcastOp>(op)) {
Type _retTy = *op->result_type_begin();
Type _opTy = *op->operand_type_begin();
TensorType retTy = _retTy.cast<TensorType>();
TensorType opTy = _opTy.cast<TensorType>();
ArrayRef<int64_t> retShape = retTy.getShape();
ArrayRef<int64_t> opShape = opTy.getShape();
AxisInfo opInfo = operands[0]->getValue();
AxisInfo::DimVectorT contiguity;
AxisInfo::DimVectorT divisibility;
AxisInfo::DimVectorT constancy;
for (int d = 0; d < retTy.getRank(); ++d) {
contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
divisibility.push_back(opInfo.getDivisibility(d));
constancy.push_back(opShape[d] == 1 ? retShape[d]
: opInfo.getConstancy(d));
}
curr = AxisInfo(contiguity, divisibility, constancy);
}
// CmpI
if ((llvm::dyn_cast<arith::CmpIOp>(op) ||
llvm::dyn_cast<triton::gpu::CmpIOp>(op)) &&
op->getResult(0).getType().dyn_cast<TensorType>()) {
auto resTy = op->getResult(0).getType().cast<TensorType>();
short rank = resTy.getRank();
auto lhsInfo = operands[0]->getValue();
auto rhsInfo = operands[1]->getValue();
auto shape = resTy.getShape();
AxisInfo::DimVectorT contiguity, divisibility, constancy;
for (short d = 0; d < rank; ++d) {
if (rhsInfo.getConstancy(d) % lhsInfo.getContiguity(d) == 0 ||
rhsInfo.getConstancy(d) % lhsInfo.getConstancy(d))
constancy.push_back(
gcd(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d)));
else
constancy.push_back(1);
divisibility.push_back(shape[d]);
contiguity.push_back(1);
}
curr = AxisInfo(contiguity, divisibility, constancy);
}
// UnrealizedConversionCast
// This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is
// in the process of a PartialConversion, where UnrealizedConversionCast
// may exist
if (llvm::isa<mlir::UnrealizedConversionCastOp>(op)) {
curr = operands[0]->getValue();
}
if (curr.getRank() == 0) {
return markAllPessimisticFixpoint(op->getResults());
}
// join all lattice elements
ChangeResult result = ChangeResult::NoChange;
for (Value value : op->getResults()) {
result |= getLatticeElement(value).join(curr);
}
return result;
}
unsigned AxisInfoAnalysis::getPtrVectorSize(Value ptr) {
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
return 1;
auto layout = tensorTy.getEncoding();
auto shape = tensorTy.getShape();
// Here order should be ordered by contiguous first, so the first element
// should have the largest contiguous.
auto order = triton::gpu::getOrder(layout);
unsigned align = getPtrAlignment(ptr);
unsigned contigPerThread = triton::gpu::getSizePerThread(layout)[order[0]];
unsigned vec = std::min(align, contigPerThread);
vec = std::min<unsigned>(shape[order[0]], vec);
return vec;
}
unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) {
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
return 1;
auto axisInfo = lookupLatticeElement(ptr)->getValue();
auto layout = tensorTy.getEncoding();
auto order = triton::gpu::getOrder(layout);
unsigned maxMultiple = axisInfo.getDivisibility(order[0]);
unsigned maxContig = axisInfo.getContiguity(order[0]);
unsigned alignment = std::min(maxMultiple, maxContig);
return alignment;
}
unsigned AxisInfoAnalysis::getMaskAlignment(Value mask) {
auto tensorTy = mask.getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
return 1;
auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
auto maskAxis = lookupLatticeElement(mask)->getValue();
auto alignment = std::max<unsigned>(maskAxis.getConstancy(maskOrder[0]), 1);
return alignment;
}
} // namespace mlir

View File

@@ -0,0 +1,10 @@
add_mlir_library(TritonAnalysis
AxisInfo.cpp
Allocation.cpp
Membar.cpp
Alias.cpp
Utility.cpp
DEPENDS
TritonGPUAttrDefsIncGen
)

137
lib/Analysis/Membar.cpp Normal file
View File

@@ -0,0 +1,137 @@
#include "triton/Analysis/Membar.h"
#include "triton/Analysis/Alias.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
namespace mlir {
void MembarAnalysis::run() {
auto *operation = allocation->getOperation();
RegionInfo regionInfo;
OpBuilder builder(operation);
dfsOperation(operation, &regionInfo, &builder);
}
void MembarAnalysis::dfsOperation(Operation *operation,
RegionInfo *parentRegionInfo,
OpBuilder *builder) {
transfer(operation, parentRegionInfo, builder);
if (operation->getNumRegions()) {
// If there's any nested regions, we need to visit them.
// scf.if and scf.else: two regions
// scf.if only: two regions
// scf.for: one region
RegionInfo curRegionInfo;
auto traverseRegions = [&]() -> auto{
for (auto &region : operation->getRegions()) {
// Copy the parent info as the current info.
RegionInfo regionInfo = *parentRegionInfo;
for (auto &block : region.getBlocks()) {
assert(region.getBlocks().size() == 1 &&
"Multiple blocks in a region is not supported");
for (auto &op : block.getOperations()) {
// Traverse the nested operation.
dfsOperation(&op, &regionInfo, builder);
}
}
curRegionInfo.join(regionInfo);
}
// Set the parent region info as the union of the nested region info.
*parentRegionInfo = curRegionInfo;
};
traverseRegions();
if (isa<scf::ForOp>(operation)) {
// scf.for can have two possible inputs: the init value and the
// previous iteration's result. Although we've applied alias analysis,
// there could be unsynced memory accesses on reused memories.
// For example, consider the following code:
// %1 = convert_layout %0: blocked -> shared
// ...
// gpu.barrier
// ...
// %5 = convert_layout %4 : shared -> dot
// %6 = tt.dot %2, %5
// scf.yield
//
// Though %5 could be released before scf.yield, it may shared the same
// memory with %1. So we actually have to insert a barrier before %1 to
// make sure the memory is synced.
traverseRegions();
}
}
}
void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
OpBuilder *builder) {
if (isa<scf::ForOp>(op) || isa<scf::IfOp>(op) || isa<scf::YieldOp>(op) ||
isa<tensor::ExtractSliceOp>(op) || isa<triton::gpu::AllocTensorOp>(op)) {
// Do not insert barriers before control flow operations and
// alloc/extract/insert
// alloc is an allocation op without memory write.
// FIXME(Keren): extract_slice is always alias for now
return;
}
if (isa<gpu::BarrierOp>(op)) {
// If the current op is a barrier, we sync previous reads and writes
regionInfo->sync();
return;
}
if (isa<triton::gpu::AsyncWaitOp>(op) &&
!isa<gpu::BarrierOp>(op->getNextNode())) {
// If the current op is an async wait and the next op is not a barrier we
// insert a barrier op and sync
regionInfo->sync();
OpBuilder::InsertionGuard g(*builder);
builder->setInsertionPointAfter(op);
builder->create<gpu::BarrierOp>(op->getLoc());
regionInfo->sync();
return;
}
RegionInfo curRegionInfo;
for (Value value : op->getOperands()) {
for (auto bufferId : allocation->getBufferIds(value)) {
if (bufferId != Allocation::InvalidBufferId) {
if (isa<triton::gpu::InsertSliceAsyncOp>(op) ||
isa<tensor::InsertSliceOp>(op)) {
// FIXME(Keren): insert_slice and insert_slice_async are always alias
// for now
curRegionInfo.syncWriteBuffers.insert(bufferId);
} else {
// ConvertLayoutOp: shared memory -> registers
curRegionInfo.syncReadBuffers.insert(bufferId);
}
}
}
}
for (Value value : op->getResults()) {
// ConvertLayoutOp: registers -> shared memory
auto bufferId = allocation->getBufferId(value);
if (bufferId != Allocation::InvalidBufferId) {
curRegionInfo.syncWriteBuffers.insert(bufferId);
}
}
// Scratch buffer is considered as both shared memory write & read
auto bufferId = allocation->getBufferId(op);
if (bufferId != Allocation::InvalidBufferId) {
curRegionInfo.syncWriteBuffers.insert(bufferId);
curRegionInfo.syncReadBuffers.insert(bufferId);
}
if (regionInfo->isIntersected(curRegionInfo, allocation)) {
OpBuilder::InsertionGuard g(*builder);
builder->setInsertionPoint(op);
builder->create<gpu::BarrierOp>(op->getLoc());
regionInfo->sync();
}
// Update the region info, even if barrier is inserted, we have to maintain
// the current op's read/write buffers.
regionInfo->join(curRegionInfo);
}
} // namespace mlir

151
lib/Analysis/Utility.cpp Normal file
View File

@@ -0,0 +1,151 @@
#include "triton/Analysis/Utility.h"
#include "mlir/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
namespace mlir {
bool ReduceOpHelper::isFastReduction() {
auto srcLayout = srcTy.getEncoding();
auto axis = op.axis();
return axis == triton::gpu::getOrder(srcLayout)[0];
}
unsigned ReduceOpHelper::getInterWarpSize() {
auto srcLayout = srcTy.getEncoding();
auto srcShape = srcTy.getShape();
auto axis = op.axis();
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
unsigned sizeIntraWarps = getIntraWarpSize();
return std::min(srcReduceDimSize / sizeIntraWarps,
triton::gpu::getWarpsPerCTA(srcLayout)[axis]);
}
unsigned ReduceOpHelper::getIntraWarpSize() {
auto srcLayout = srcTy.getEncoding();
auto srcShape = srcTy.getShape();
auto axis = op.axis();
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
return std::min(srcReduceDimSize,
triton::gpu::getThreadsPerWarp(srcLayout)[axis]);
}
unsigned ReduceOpHelper::getThreadsReductionAxis() {
auto srcLayout = srcTy.getEncoding();
auto axis = op.axis();
return triton::gpu::getThreadsPerWarp(srcLayout)[axis] *
triton::gpu::getWarpsPerCTA(srcLayout)[axis];
}
SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
auto axis = op.axis();
auto smemShape = convertType<unsigned>(getSrcShape());
smemShape[axis] = std::min(smemShape[axis], getThreadsReductionAxis());
return smemShape;
}
SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
auto axis = op.axis();
SmallVector<SmallVector<unsigned>> smemShapes(3);
/// shared memory block0
smemShapes[0] = convertType<unsigned>(getSrcShape());
smemShapes[0][axis] = getInterWarpSize();
/// FIXME(Qingyi): This size is actually larger than required.
/// shared memory block1:
auto mod = op.getOperation()->getParentOfType<ModuleOp>();
unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
smemShapes[1].push_back(numWarps * 32);
return smemShapes;
}
unsigned ReduceOpHelper::getScratchSizeInBytes() {
unsigned elems = 0;
if (isFastReduction()) {
auto smemShapes = getScratchConfigsFast();
for (const auto &smemShape : smemShapes)
elems = std::max(elems, product<unsigned>(smemShape));
} else {
auto smemShape = getScratchConfigBasic();
elems = product<unsigned>(smemShape);
}
auto tensorType = op.operand().getType().cast<RankedTensorType>();
unsigned bytes = elems * tensorType.getElementTypeBitWidth() / 8;
if (triton::ReduceOp::withIndex(op.redOp()))
bytes += elems * sizeof(int32_t);
return bytes;
}
bool isSharedEncoding(Value value) {
auto type = value.getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
auto encoding = tensorType.getEncoding();
return encoding && encoding.isa<triton::gpu::SharedEncodingAttr>();
}
return false;
}
bool maybeSharedAllocationOp(Operation *op) {
// TODO(Keren): This function can be replaced by adding
// MemoryEffectOpInterface. We can then use the MemoryEffectOpInterface to
// query the memory effects of the op.
auto *dialect = op->getDialect();
return dialect &&
(dialect->getTypeID() ==
mlir::TypeID::get<triton::gpu::TritonGPUDialect>() ||
dialect->getTypeID() == mlir::TypeID::get<triton::TritonDialect>() ||
dialect->getTypeID() ==
mlir::TypeID::get<arith::ArithmeticDialect>() ||
dialect->getTypeID() == mlir::TypeID::get<tensor::TensorDialect>());
}
bool maybeAliasOp(Operation *op) {
return isa<tensor::ExtractSliceOp>(op) || isa<triton::TransOp>(op) ||
isa<triton::gpu::InsertSliceAsyncOp>(op) ||
isa<tensor::InsertSliceOp>(op);
}
bool supportMMA(triton::DotOp op, int version) {
// Refer to mma section for the data type supported by Volta and Hopper
// Tensor Core in
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
auto aElemTy = op.a().getType().cast<RankedTensorType>().getElementType();
auto bElemTy = op.b().getType().cast<RankedTensorType>().getElementType();
if (aElemTy.isF32() && bElemTy.isF32()) {
return op.allowTF32() && version >= 2;
}
return supportMMA(op.a(), version) && supportMMA(op.b(), version);
}
bool supportMMA(Value value, int version) {
// Tell whether a DotOp support HMMA by the operand type(either $a or $b).
// We cannot get both the operand types(in TypeConverter), here we assume the
// types of both the operands are identical here.
assert((version == 1 || version == 2) &&
"Unexpected MMA layout version found");
auto elemTy = value.getType().cast<RankedTensorType>().getElementType();
return elemTy.isF16() || elemTy.isBF16() ||
(elemTy.isF32() && version >= 2) ||
(elemTy.isInteger(8) && version >= 2);
}
Type getElementType(Value value) {
auto type = value.getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return tensorType.getElementType();
return type;
}
std::string getValueOperandName(Value value, AsmState &state) {
std::string opName;
llvm::raw_string_ostream ss(opName);
value.printAsOperand(ss, state);
return opName;
}
} // namespace mlir