[Triton-MLIR] tt.dot operands now must have DotOperand layout; also added prefetch pass prototype (#712)

Co-authored-by: Jokeren <kerenzhou@openai.com>
Co-authored-by: Phil Tillet <phil@openai.com>
Co-authored-by: Superjomn <yanchunwei@outlook.com>
This commit is contained in:
Da Yan
2022-11-10 13:57:27 +08:00
committed by GitHub
parent 8832e32683
commit 4946167241
29 changed files with 1227 additions and 507 deletions

View File

@@ -1,3 +1,4 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
@@ -11,6 +12,7 @@
//===----------------------------------------------------------------------===//
using namespace mlir;
namespace ttg = triton::gpu;
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
@@ -24,6 +26,7 @@ static Type getI1SameShape(Value v) {
}
namespace {
class LoopPipeliner {
/// cache forOp we are working on
scf::ForOp forOp;
@@ -37,6 +40,8 @@ class LoopPipeliner {
DenseMap<Value, Value> loadsMapping;
/// load => buffer
DenseMap<Value, Value> loadsBuffer;
/// load => buffer type (with shared layout after swizzling)
DenseMap<Value, RankedTensorType> loadsBufferType;
/// load => buffer at stage N
DenseMap<Value, SmallVector<Value>> loadStageBuffer;
/// load => after extract
@@ -67,8 +72,11 @@ class LoopPipeliner {
Value lookupOrDefault(Value origin, int stage);
/// returns a empty buffer of size <numStages, ...>
triton::gpu::AllocTensorOp allocateEmptyBuffer(Operation *op,
OpBuilder &builder);
ttg::AllocTensorOp allocateEmptyBuffer(Operation *op, OpBuilder &builder);
/// compute type of shared buffers (with swizzled shared layouts)
RankedTensorType getSwizzleType(ttg::DotOperandEncodingAttr dotOpEnc,
RankedTensorType tensorType);
public:
LoopPipeliner(scf::ForOp forOp, int numStages)
@@ -128,25 +136,82 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
}
}
triton::gpu::AllocTensorOp
LoopPipeliner::allocateEmptyBuffer(Operation *op, OpBuilder &builder) {
ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op,
OpBuilder &builder) {
// allocate a buffer for each pipelined tensor
// shape: e.g. (numStages==4), <32x64xbf16> -> <4x32x64xbf16>
Value convertLayout = loadsMapping[op->getResult(0)];
if (auto tensorType = convertLayout.getType().dyn_cast<RankedTensorType>()) {
SmallVector<int64_t> shape(tensorType.getShape().begin(),
tensorType.getShape().end());
shape.insert(shape.begin(), numStages);
Type elementType = tensorType.getElementType();
// The encoding of the buffer is similar to the original tensor
Attribute encoding = tensorType.getEncoding();
auto bufferType = RankedTensorType::get(shape, elementType, encoding);
return builder.create<triton::gpu::AllocTensorOp>(convertLayout.getLoc(),
bufferType);
return builder.create<ttg::AllocTensorOp>(
convertLayout.getLoc(), loadsBufferType[op->getResult(0)]);
}
llvm_unreachable("Async copy's return should be of RankedTensorType");
}
// TODO: I copied the code from Swizzle.cpp. Should find a way to unify the
// code path.
// Swizzle has to be performed before pipeline for now. If we do swizzle
// after pipeline, we need to propagate the swizzled layout to all
// operands that is an alias of the swizzled tensor. The alias analysis
// component maybe helpful for this purpose.
RankedTensorType
LoopPipeliner::getSwizzleType(ttg::DotOperandEncodingAttr dotOpEnc,
RankedTensorType ty) {
int opIdx = dotOpEnc.getOpIdx();
int vec = 1;
int maxPhase = 1;
int perPhase = 1;
llvm::SmallVector<unsigned> order;
if (auto mmaEnc = dotOpEnc.getParent().dyn_cast<ttg::MmaEncodingAttr>()) {
// Only support row major for now
// TODO(Keren): check why column major code crashes
order = {1, 0};
int version = mmaEnc.getVersion();
auto tyEncoding = ty.getEncoding().cast<ttg::BlockedEncodingAttr>();
// number of rows per phase
perPhase = 128 / (ty.getShape()[order[0]] *
(ty.getElementType().getIntOrFloatBitWidth() / 8));
perPhase = std::max<int>(perPhase, 1);
// index of the inner dimension in `order`
unsigned inner = (opIdx == 0) ? 0 : 1;
if (version == 1) {
maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
// TODO: handle rep (see
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L209)
} else if (version == 2) {
auto eltTy = ty.getElementType();
std::vector<size_t> matShape = {8, 8,
2 * 64 / eltTy.getIntOrFloatBitWidth()};
// for now, disable swizzle when using transposed int8 tensor cores
if (ty.getElementType().isInteger(8) && order[0] == inner)
perPhase = 1;
else {
if (opIdx == 0) { // compute swizzling for A operand
vec = order[0] == 1 ? matShape[2] : matShape[0]; // k : m
int mmaStride = order[0] == 1 ? matShape[0] : matShape[2];
maxPhase = mmaStride / perPhase;
} else if (opIdx == 1) { // compute swizzling for B operand
vec = order[0] == 1 ? matShape[1] : matShape[2]; // n : k
int mmaStride = order[0] == 1 ? matShape[2] : matShape[1];
maxPhase = mmaStride / perPhase;
} else
llvm_unreachable("invalid operand index");
}
} else // version not in [1, 2]
llvm_unreachable("unsupported swizzling for provided MMA version");
} else { // If the layout of dot is not mma, we don't need to swizzle
auto blockedEnc = dotOpEnc.getParent().cast<ttg::BlockedEncodingAttr>();
order = llvm::SmallVector<unsigned>(blockedEnc.getOrder().begin(),
blockedEnc.getOrder().end());
}
auto newEncoding = ttg::SharedEncodingAttr::get(ty.getContext(), vec,
perPhase, maxPhase, order);
SmallVector<int64_t> bufferShape(ty.getShape().begin(), ty.getShape().end());
bufferShape.insert(bufferShape.begin(), numStages);
return RankedTensorType::get(bufferShape, ty.getElementType(), newEncoding);
}
/// A load instruction can be pipelined if:
/// - the load doesn't depend on any other loads (after loop peeling)
/// - (?) this load is not a loop-invariant value (we should run LICM before
@@ -186,19 +251,21 @@ LogicalResult LoopPipeliner::initialize() {
}
}
// For now, we only pipeline loads that have one covert_layout (to smem) use
// We only pipeline loads that have one covert_layout (to dot_op) use
// TODO: lift this constraint in the future
if (isCandiate && loadOp.getResult().hasOneUse()) {
isCandiate = false;
Operation *use = *loadOp.getResult().getUsers().begin();
if (auto convertLayout =
llvm::dyn_cast<triton::gpu::ConvertLayoutOp>(use)) {
if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use)) {
if (auto tensorType = convertLayout.getResult()
.getType()
.dyn_cast<RankedTensorType>()) {
if (tensorType.getEncoding().isa<triton::gpu::SharedEncodingAttr>()) {
if (auto dotOpEnc = tensorType.getEncoding()
.dyn_cast<ttg::DotOperandEncodingAttr>()) {
isCandiate = true;
loadsMapping[loadOp] = convertLayout;
loadsBufferType[loadOp] = getSwizzleType(
dotOpEnc, loadOp.getType().cast<RankedTensorType>());
}
}
}
@@ -238,6 +305,9 @@ void LoopPipeliner::emitPrologue() {
setValueMapping(arg, operand.get(), 0);
}
// helper to construct int attribute
auto intAttr = [&](int64_t val) { return builder.getI64IntegerAttr(val); };
// prologue from [0, numStage-1)
Value iv = forOp.getLowerBound();
pipelineIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
@@ -330,14 +400,15 @@ void LoopPipeliner::emitPrologue() {
builder.create<arith::ConstantIntOp>(iv.getLoc(), 1, 32));
} // for (int stage = 0; stage < numStages - 1; ++stage)
auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
// async.wait & extract_slice
builder.create<triton::gpu::AsyncWaitOp>(loads[0].getLoc(),
loads.size() * (numStages - 2));
builder.create<ttg::AsyncWaitOp>(loads[0].getLoc(),
loads.size() * (numStages - 2));
loopIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
for (Value loadOp : loads) {
auto sliceType = loadsMapping[loadOp].getType().cast<RankedTensorType>();
sliceType =
RankedTensorType::get(sliceType.getShape(), sliceType.getElementType(),
loadsBufferType[loadOp].getEncoding());
Value extractSlice = builder.create<tensor::ExtractSliceOp>(
loadOp.getLoc(), sliceType, loadStageBuffer[loadOp][numStages - 1],
SmallVector<OpFoldResult>{intAttr(0), intAttr(0), intAttr(0)},
@@ -366,6 +437,7 @@ void LoopPipeliner::emitEpilogue() {
scf::ForOp LoopPipeliner::createNewForOp() {
OpBuilder builder(forOp);
auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
// order of new args:
// (original args),
@@ -477,8 +549,6 @@ scf::ForOp LoopPipeliner::createNewForOp() {
extractSliceIndex = builder.create<arith::IndexCastOp>(
extractSliceIndex.getLoc(), builder.getIndexType(), extractSliceIndex);
auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
for (Operation *op : orderedDeps) {
Operation *nextOp = nullptr;
// update loading mask
@@ -508,6 +578,9 @@ scf::ForOp LoopPipeliner::createNewForOp() {
loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
nextBuffers.push_back(insertAsyncOp);
auto sliceType = loadsMapping[loadOp].getType().cast<RankedTensorType>();
sliceType = RankedTensorType::get(sliceType.getShape(),
sliceType.getElementType(),
loadsBufferType[loadOp].getEncoding());
nextOp = builder.create<tensor::ExtractSliceOp>(
op->getLoc(), sliceType, insertAsyncOp,
SmallVector<OpFoldResult>{extractSliceIndex, intAttr(0), intAttr(0)},
@@ -534,8 +607,37 @@ scf::ForOp LoopPipeliner::createNewForOp() {
}
}
{
OpBuilder::InsertionGuard guard(builder);
for (Operation &op : *newForOp.getBody()) {
if (auto dotOp = llvm::dyn_cast<triton::DotOp>(&op)) {
builder.setInsertionPoint(&op);
auto dotType = dotOp.getType().cast<RankedTensorType>();
Value a = dotOp.a();
Value b = dotOp.b();
auto layoutCast = [&](Value dotOperand, int opIdx) -> Value {
auto tensorType = dotOperand.getType().cast<RankedTensorType>();
if (!tensorType.getEncoding().isa<ttg::DotOperandEncodingAttr>()) {
auto newEncoding = ttg::DotOperandEncodingAttr::get(
tensorType.getContext(), opIdx, dotType.getEncoding());
auto newType =
RankedTensorType::get(tensorType.getShape(),
tensorType.getElementType(), newEncoding);
return builder.create<ttg::ConvertLayoutOp>(dotOperand.getLoc(),
newType, dotOperand);
}
return dotOperand;
};
a = layoutCast(a, 0);
b = layoutCast(b, 1);
dotOp->setOperand(0, a);
dotOp->setOperand(1, b);
}
}
}
// async.wait & extract_slice
Operation *asyncWait = builder.create<triton::gpu::AsyncWaitOp>(
Operation *asyncWait = builder.create<ttg::AsyncWaitOp>(
loads[0].getLoc(), loads.size() * (numStages - 2));
for (auto it = extractSlices.rbegin(); it != extractSlices.rend(); ++it) {
// move extract_slice after asyncWait