[Triton-MLIR] tt.dot operands now must have DotOperand layout; also added prefetch pass prototype (#712)

Co-authored-by: Jokeren <kerenzhou@openai.com>
Co-authored-by: Phil Tillet <phil@openai.com>
Co-authored-by: Superjomn <yanchunwei@outlook.com>
This commit is contained in:
Da Yan
2022-11-10 13:57:27 +08:00
committed by GitHub
parent 8832e32683
commit 4946167241
29 changed files with 1227 additions and 507 deletions

View File

@@ -7,6 +7,7 @@ add_mlir_dialect_library(TritonGPUTransforms
CanonicalizeLoops.cpp
Combine.cpp
Pipeline.cpp
Prefetch.cpp
Swizzle.cpp
TritonGPUConversion.cpp

View File

@@ -12,21 +12,13 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include <memory>
using namespace mlir;
static bool isSharedLayout(Value v) {
if (auto tensorType = v.getType().dyn_cast<RankedTensorType>()) {
Attribute encoding = tensorType.getEncoding();
return encoding.isa<triton::gpu::SharedEncodingAttr>();
}
return false;
}
namespace {
#include "TritonGPUCombine.inc"
@@ -37,7 +29,7 @@ namespace {
// convert(blocked, dot_operand) ->
// convert(blocked, mma) + convert(mma, dot_operand)
// if this value is itself the result of a dot operation
// this is a hueiristics to accomodate some pattern seen in fused attention
// this is a heuristic to accomodate some pattern seen in fused attention
// kernels.
// TODO: replace this by something more generic, i.e. layout-aware CSE
class DecomposeDotOperand : public mlir::RewritePattern {
@@ -59,9 +51,8 @@ public:
dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>()) {
auto tmpType =
RankedTensorType::get(dstType.getShape(), dstType.getElementType(),
dstType.getEncoding()
.cast<triton::gpu::DotOperandEncodingAttr>()
.getParent());
triton::gpu::SharedEncodingAttr::get(
op->getContext(), 1, 1, 1, {1, 0}));
auto tmp = rewriter.create<triton::gpu::ConvertLayoutOp>(
convert.getLoc(), tmpType, convert.getOperand());
auto newConvert = rewriter.create<triton::gpu::ConvertLayoutOp>(
@@ -87,11 +78,12 @@ public:
if (!llvm::isa<triton::gpu::ConvertLayoutOp>(op))
return mlir::failure();
auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
auto dstType = convert.getType().cast<RankedTensorType>();
// we don't handle conversions to DotOperandEncodingAttr
// this is a heuristics to accomodate fused attention
if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
return mlir::failure();
// if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
// return mlir::failure();
// convert to the same layout -- we can delete
if (op->getResultTypes() == op->getOperandTypes()) {
rewriter.replaceOp(op, op->getOperands());
@@ -122,8 +114,8 @@ public:
rewriter.replaceOpWithNewOp<triton::gpu::InsertSliceAsyncOp>(
op, newType, insert_slice.src(), newArg.getResult(),
insert_slice.index(), insert_slice.mask(), insert_slice.other(),
insert_slice.cache(), insert_slice.evict(),
insert_slice.isVolatile(), insert_slice.axis());
insert_slice.cache(), insert_slice.evict(), insert_slice.isVolatile(),
insert_slice.axis());
return mlir::success();
}
// cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2))
@@ -133,7 +125,10 @@ public:
auto newType = RankedTensorType::get(
origType.getShape(), origType.getElementType(),
op->getResult(0).getType().cast<RankedTensorType>().getEncoding());
auto resType = op->getResult(0).getType().cast<RankedTensorType>();
auto origResType = op->getResult(0).getType().cast<RankedTensorType>();
auto resType = RankedTensorType::get(
origResType.getShape(), origResType.getElementType(),
extract_slice.getType().cast<RankedTensorType>().getEncoding());
// Ensure that the new extract_slice op is placed in the same place as the
// old extract_slice op. Otherwise, the new extract_slice op may be placed
// after the async_wait op, which is not allowed.
@@ -148,8 +143,21 @@ public:
extract_slice.static_strides());
return mlir::success();
}
// cvt(type2, x)
if (llvm::isa<triton::gpu::ConvertLayoutOp>(arg)) {
auto argType = arg->getOperand(0).getType().cast<RankedTensorType>();
if (arg->getOperand(0).getDefiningOp() &&
!argType.getEncoding().isa<triton::gpu::SharedEncodingAttr>() &&
srcType.getEncoding().isa<triton::gpu::SharedEncodingAttr>() &&
!dstType.getEncoding().isa<triton::gpu::SharedEncodingAttr>()) {
return mlir::failure();
}
auto srcShared =
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
if (srcShared && srcShared.getVec() > 1)
return mlir::failure();
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
op, op->getResultTypes().front(), arg->getOperand(0));
return mlir::success();
@@ -253,8 +261,8 @@ public:
if (!op)
return mlir::failure();
// we don't want to rematerialize any conversion to/from shared
if (isSharedLayout(cvt->getResults()[0]) ||
isSharedLayout(cvt->getOperand(0)))
if (isSharedEncoding(cvt->getResults()[0]) ||
isSharedEncoding(cvt->getOperand(0)))
return mlir::failure();
// we don't handle conversions to DotOperandEncodingAttr
// this is a heuristics to accomodate fused attention
@@ -325,7 +333,6 @@ public:
for (Operation *op : tmp)
sortedValues.push_back(op->getResult(0));
// llvm::outs() << "----\n";
BlockAndValueMapping mapping;
for (Value currOperand : sortedValues) {
// unpack information
@@ -346,7 +353,6 @@ public:
newOperand->moveAfter(currOperation);
mapping.map(currOperand, newOperand);
}
// llvm::outs() << cvt->getParentOfType<mlir::FuncOp>() << "\n";
rewriter.replaceOp(cvt, mapping.lookup(cvt->getOperand(0)));
return mlir::success();
}
@@ -356,8 +362,6 @@ public:
//
// -----------------------------------------------------------------------------
// int test = 0;
class MoveConvertOutOfLoop : public mlir::RewritePattern {
public:
MoveConvertOutOfLoop(mlir::MLIRContext *context)
@@ -435,9 +439,25 @@ public:
auto users = iterArg.value().getUsers();
// check first condition
SetVector<Type> cvtTargetTypes;
for (auto user : users)
if (isa<triton::gpu::ConvertLayoutOp>(user))
cvtTargetTypes.insert(user->getResults()[0].getType());
for (auto user : users) {
if (isa<triton::gpu::ConvertLayoutOp>(user)) {
auto newType =
user->getResults()[0].getType().cast<RankedTensorType>();
auto oldType = user->getOperand(0).getType().cast<RankedTensorType>();
if (oldType.getEncoding().isa<triton::gpu::SharedEncodingAttr>() &&
newType.getEncoding()
.isa<triton::gpu::DotOperandEncodingAttr>()) {
continue;
}
if (newType.getEncoding().isa<triton::gpu::SharedEncodingAttr>()) {
if (newType.getEncoding()
.cast<triton::gpu::SharedEncodingAttr>()
.getVec() == 1)
continue;
}
cvtTargetTypes.insert(newType);
}
}
if (cvtTargetTypes.size() != 1)
continue;
// TODO: check second condition
@@ -446,6 +466,7 @@ public:
continue;
}
// check
// llvm::outs() << "replacing " << iterArg.index() << "\n";
for (auto op : iterArg.value().getUsers()) {
auto cvt = dyn_cast<triton::gpu::ConvertLayoutOp>(op);
if (!cvt)
@@ -597,10 +618,23 @@ public:
auto oldAcc = dotOp.getOperand(2);
auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
oldAcc.getLoc(), newRetType, oldAcc);
// convert output
Value a = dotOp.a();
Value b = dotOp.b();
auto oldAType = a.getType().cast<RankedTensorType>();
auto oldBType = b.getType().cast<RankedTensorType>();
auto newAType = RankedTensorType::get(
oldAType.getShape(), oldAType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
newRetType.getEncoding()));
auto newBType = RankedTensorType::get(
oldBType.getShape(), oldBType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
newRetType.getEncoding()));
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
auto newDot = rewriter.create<triton::DotOp>(
dotOp.getLoc(), newRetType, dotOp.getOperand(0), dotOp.getOperand(1),
newAcc, dotOp.allowTF32(), dotOp.transA(), dotOp.transB());
dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.allowTF32(),
dotOp.transA(), dotOp.transB());
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
op, oldRetType, newDot.getResult());
@@ -623,7 +657,7 @@ public:
mlir::RewritePatternSet patterns(context);
patterns.add<SimplifyConversion>(context);
patterns.add<DecomposeDotOperand>(context);
// patterns.add<DecomposeDotOperand>(context);
patterns.add<RematerializeBackward>(context);
patterns.add<RematerializeForward>(context);
patterns.add<MoveConvertOutOfLoop>(context);

View File

@@ -1,3 +1,4 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
@@ -11,6 +12,7 @@
//===----------------------------------------------------------------------===//
using namespace mlir;
namespace ttg = triton::gpu;
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
@@ -24,6 +26,7 @@ static Type getI1SameShape(Value v) {
}
namespace {
class LoopPipeliner {
/// cache forOp we are working on
scf::ForOp forOp;
@@ -37,6 +40,8 @@ class LoopPipeliner {
DenseMap<Value, Value> loadsMapping;
/// load => buffer
DenseMap<Value, Value> loadsBuffer;
/// load => buffer type (with shared layout after swizzling)
DenseMap<Value, RankedTensorType> loadsBufferType;
/// load => buffer at stage N
DenseMap<Value, SmallVector<Value>> loadStageBuffer;
/// load => after extract
@@ -67,8 +72,11 @@ class LoopPipeliner {
Value lookupOrDefault(Value origin, int stage);
/// returns a empty buffer of size <numStages, ...>
triton::gpu::AllocTensorOp allocateEmptyBuffer(Operation *op,
OpBuilder &builder);
ttg::AllocTensorOp allocateEmptyBuffer(Operation *op, OpBuilder &builder);
/// compute type of shared buffers (with swizzled shared layouts)
RankedTensorType getSwizzleType(ttg::DotOperandEncodingAttr dotOpEnc,
RankedTensorType tensorType);
public:
LoopPipeliner(scf::ForOp forOp, int numStages)
@@ -128,25 +136,82 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
}
}
triton::gpu::AllocTensorOp
LoopPipeliner::allocateEmptyBuffer(Operation *op, OpBuilder &builder) {
ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op,
OpBuilder &builder) {
// allocate a buffer for each pipelined tensor
// shape: e.g. (numStages==4), <32x64xbf16> -> <4x32x64xbf16>
Value convertLayout = loadsMapping[op->getResult(0)];
if (auto tensorType = convertLayout.getType().dyn_cast<RankedTensorType>()) {
SmallVector<int64_t> shape(tensorType.getShape().begin(),
tensorType.getShape().end());
shape.insert(shape.begin(), numStages);
Type elementType = tensorType.getElementType();
// The encoding of the buffer is similar to the original tensor
Attribute encoding = tensorType.getEncoding();
auto bufferType = RankedTensorType::get(shape, elementType, encoding);
return builder.create<triton::gpu::AllocTensorOp>(convertLayout.getLoc(),
bufferType);
return builder.create<ttg::AllocTensorOp>(
convertLayout.getLoc(), loadsBufferType[op->getResult(0)]);
}
llvm_unreachable("Async copy's return should be of RankedTensorType");
}
// TODO: I copied the code from Swizzle.cpp. Should find a way to unify the
// code path.
// Swizzle has to be performed before pipeline for now. If we do swizzle
// after pipeline, we need to propagate the swizzled layout to all
// operands that is an alias of the swizzled tensor. The alias analysis
// component maybe helpful for this purpose.
RankedTensorType
LoopPipeliner::getSwizzleType(ttg::DotOperandEncodingAttr dotOpEnc,
RankedTensorType ty) {
int opIdx = dotOpEnc.getOpIdx();
int vec = 1;
int maxPhase = 1;
int perPhase = 1;
llvm::SmallVector<unsigned> order;
if (auto mmaEnc = dotOpEnc.getParent().dyn_cast<ttg::MmaEncodingAttr>()) {
// Only support row major for now
// TODO(Keren): check why column major code crashes
order = {1, 0};
int version = mmaEnc.getVersion();
auto tyEncoding = ty.getEncoding().cast<ttg::BlockedEncodingAttr>();
// number of rows per phase
perPhase = 128 / (ty.getShape()[order[0]] *
(ty.getElementType().getIntOrFloatBitWidth() / 8));
perPhase = std::max<int>(perPhase, 1);
// index of the inner dimension in `order`
unsigned inner = (opIdx == 0) ? 0 : 1;
if (version == 1) {
maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
// TODO: handle rep (see
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L209)
} else if (version == 2) {
auto eltTy = ty.getElementType();
std::vector<size_t> matShape = {8, 8,
2 * 64 / eltTy.getIntOrFloatBitWidth()};
// for now, disable swizzle when using transposed int8 tensor cores
if (ty.getElementType().isInteger(8) && order[0] == inner)
perPhase = 1;
else {
if (opIdx == 0) { // compute swizzling for A operand
vec = order[0] == 1 ? matShape[2] : matShape[0]; // k : m
int mmaStride = order[0] == 1 ? matShape[0] : matShape[2];
maxPhase = mmaStride / perPhase;
} else if (opIdx == 1) { // compute swizzling for B operand
vec = order[0] == 1 ? matShape[1] : matShape[2]; // n : k
int mmaStride = order[0] == 1 ? matShape[2] : matShape[1];
maxPhase = mmaStride / perPhase;
} else
llvm_unreachable("invalid operand index");
}
} else // version not in [1, 2]
llvm_unreachable("unsupported swizzling for provided MMA version");
} else { // If the layout of dot is not mma, we don't need to swizzle
auto blockedEnc = dotOpEnc.getParent().cast<ttg::BlockedEncodingAttr>();
order = llvm::SmallVector<unsigned>(blockedEnc.getOrder().begin(),
blockedEnc.getOrder().end());
}
auto newEncoding = ttg::SharedEncodingAttr::get(ty.getContext(), vec,
perPhase, maxPhase, order);
SmallVector<int64_t> bufferShape(ty.getShape().begin(), ty.getShape().end());
bufferShape.insert(bufferShape.begin(), numStages);
return RankedTensorType::get(bufferShape, ty.getElementType(), newEncoding);
}
/// A load instruction can be pipelined if:
/// - the load doesn't depend on any other loads (after loop peeling)
/// - (?) this load is not a loop-invariant value (we should run LICM before
@@ -186,19 +251,21 @@ LogicalResult LoopPipeliner::initialize() {
}
}
// For now, we only pipeline loads that have one covert_layout (to smem) use
// We only pipeline loads that have one covert_layout (to dot_op) use
// TODO: lift this constraint in the future
if (isCandiate && loadOp.getResult().hasOneUse()) {
isCandiate = false;
Operation *use = *loadOp.getResult().getUsers().begin();
if (auto convertLayout =
llvm::dyn_cast<triton::gpu::ConvertLayoutOp>(use)) {
if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use)) {
if (auto tensorType = convertLayout.getResult()
.getType()
.dyn_cast<RankedTensorType>()) {
if (tensorType.getEncoding().isa<triton::gpu::SharedEncodingAttr>()) {
if (auto dotOpEnc = tensorType.getEncoding()
.dyn_cast<ttg::DotOperandEncodingAttr>()) {
isCandiate = true;
loadsMapping[loadOp] = convertLayout;
loadsBufferType[loadOp] = getSwizzleType(
dotOpEnc, loadOp.getType().cast<RankedTensorType>());
}
}
}
@@ -238,6 +305,9 @@ void LoopPipeliner::emitPrologue() {
setValueMapping(arg, operand.get(), 0);
}
// helper to construct int attribute
auto intAttr = [&](int64_t val) { return builder.getI64IntegerAttr(val); };
// prologue from [0, numStage-1)
Value iv = forOp.getLowerBound();
pipelineIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
@@ -330,14 +400,15 @@ void LoopPipeliner::emitPrologue() {
builder.create<arith::ConstantIntOp>(iv.getLoc(), 1, 32));
} // for (int stage = 0; stage < numStages - 1; ++stage)
auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
// async.wait & extract_slice
builder.create<triton::gpu::AsyncWaitOp>(loads[0].getLoc(),
loads.size() * (numStages - 2));
builder.create<ttg::AsyncWaitOp>(loads[0].getLoc(),
loads.size() * (numStages - 2));
loopIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
for (Value loadOp : loads) {
auto sliceType = loadsMapping[loadOp].getType().cast<RankedTensorType>();
sliceType =
RankedTensorType::get(sliceType.getShape(), sliceType.getElementType(),
loadsBufferType[loadOp].getEncoding());
Value extractSlice = builder.create<tensor::ExtractSliceOp>(
loadOp.getLoc(), sliceType, loadStageBuffer[loadOp][numStages - 1],
SmallVector<OpFoldResult>{intAttr(0), intAttr(0), intAttr(0)},
@@ -366,6 +437,7 @@ void LoopPipeliner::emitEpilogue() {
scf::ForOp LoopPipeliner::createNewForOp() {
OpBuilder builder(forOp);
auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
// order of new args:
// (original args),
@@ -477,8 +549,6 @@ scf::ForOp LoopPipeliner::createNewForOp() {
extractSliceIndex = builder.create<arith::IndexCastOp>(
extractSliceIndex.getLoc(), builder.getIndexType(), extractSliceIndex);
auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
for (Operation *op : orderedDeps) {
Operation *nextOp = nullptr;
// update loading mask
@@ -508,6 +578,9 @@ scf::ForOp LoopPipeliner::createNewForOp() {
loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
nextBuffers.push_back(insertAsyncOp);
auto sliceType = loadsMapping[loadOp].getType().cast<RankedTensorType>();
sliceType = RankedTensorType::get(sliceType.getShape(),
sliceType.getElementType(),
loadsBufferType[loadOp].getEncoding());
nextOp = builder.create<tensor::ExtractSliceOp>(
op->getLoc(), sliceType, insertAsyncOp,
SmallVector<OpFoldResult>{extractSliceIndex, intAttr(0), intAttr(0)},
@@ -534,8 +607,37 @@ scf::ForOp LoopPipeliner::createNewForOp() {
}
}
{
OpBuilder::InsertionGuard guard(builder);
for (Operation &op : *newForOp.getBody()) {
if (auto dotOp = llvm::dyn_cast<triton::DotOp>(&op)) {
builder.setInsertionPoint(&op);
auto dotType = dotOp.getType().cast<RankedTensorType>();
Value a = dotOp.a();
Value b = dotOp.b();
auto layoutCast = [&](Value dotOperand, int opIdx) -> Value {
auto tensorType = dotOperand.getType().cast<RankedTensorType>();
if (!tensorType.getEncoding().isa<ttg::DotOperandEncodingAttr>()) {
auto newEncoding = ttg::DotOperandEncodingAttr::get(
tensorType.getContext(), opIdx, dotType.getEncoding());
auto newType =
RankedTensorType::get(tensorType.getShape(),
tensorType.getElementType(), newEncoding);
return builder.create<ttg::ConvertLayoutOp>(dotOperand.getLoc(),
newType, dotOperand);
}
return dotOperand;
};
a = layoutCast(a, 0);
b = layoutCast(b, 1);
dotOp->setOperand(0, a);
dotOp->setOperand(1, b);
}
}
}
// async.wait & extract_slice
Operation *asyncWait = builder.create<triton::gpu::AsyncWaitOp>(
Operation *asyncWait = builder.create<ttg::AsyncWaitOp>(
loads[0].getLoc(), loads.size() * (numStages - 2));
for (auto it = extractSlices.rbegin(); it != extractSlices.rend(); ++it) {
// move extract_slice after asyncWait

View File

@@ -0,0 +1,304 @@
//===----------------------------------------------------------------------===//
//
// This pass tries to prefetch operands (a and b) of tt.dot.
// Those ConvertLayoutOps will be lowered to shared memory loads.
//
// For example:
// %a: tensor<128x32xf16, #enc>
// scf.for %iv = ... iter_args(%a_arg = %a, ...) {
// %d = tt.dot %a_arg, %b, %c
// ...
// scf.yield %a_next, ...
// }
//
// will be translated to
//
// %a: tensor<128x32xf16, #enc>
// %a_tmp = tensor.extract_slice %a[0, 0] [128, 16]
// %a_prefetch = triton_gpu.convert_layout %a_tmp
// scf.for %iv = ... iter_args(%a_buf = %a, ..., %a_prefetch_arg = %a_prefetch)
// {
// %x = tt.dot %a_arg, %b, %c
// %a_tmp_rem = tensor.extract_slice %a_buf[0, 16] [128, 16]
// %a_prefetch_next = triton_gpu.convert_layout %a_tmp_rem
// ...
// scf.yield %next_a, ..., %a_prefetch_next
// }
//===----------------------------------------------------------------------===//
#include "mlir/IR/BlockAndValueMapping.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
using namespace mlir;
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
namespace {
class Prefetcher {
/// cache the ForOp we are working on
scf::ForOp forOp;
/// cache the YieldOp of this ForOp
scf::YieldOp yieldOp;
///
// TODO: add a hook to infer prefetchWidth
unsigned prefetchWidth = 16;
/// dots to be prefetched
SetVector<Value> dots;
/// dot => dot operand
DenseMap<Value, Value> dot2aLoopArg;
DenseMap<Value, Value> dot2aHeaderDef;
DenseMap<Value, Value> dot2bLoopArg;
DenseMap<Value, Value> dot2bHeaderDef;
DenseMap<Value, Value> dot2aYield;
DenseMap<Value, Value> dot2bYield;
/// operand => defining
DenseMap<Value, Value> operand2headPrefetch;
LogicalResult isForOpOperand(Value v);
Value generatePrefetch(Value v, unsigned opIdx, bool isPrefetch,
Attribute dotEncoding, OpBuilder &builder,
llvm::Optional<int64_t> offsetK = llvm::None,
llvm::Optional<int64_t> shapeK = llvm::None);
public:
Prefetcher() = delete;
Prefetcher(scf::ForOp forOp) : forOp(forOp) {
yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
}
LogicalResult initialize();
void emitPrologue();
scf::ForOp createNewForOp();
};
Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrefetch,
Attribute dotEncoding, OpBuilder &builder,
llvm::Optional<int64_t> offsetK,
llvm::Optional<int64_t> shapeK) {
// opIdx: 0 => a, 1 => b
auto type = v.getType().cast<RankedTensorType>();
SmallVector<int64_t> shape{type.getShape().begin(), type.getShape().end()};
SmallVector<int64_t> offset{0, 0};
Type elementType = type.getElementType();
auto intAttr = [&](int64_t val) { return builder.getI64IntegerAttr(val); };
// k => (prefetchWidth, k - prefetchWidth)
int64_t kIdx = opIdx == 0 ? 1 : 0;
offset[kIdx] = isPrefetch ? 0 : prefetchWidth;
shape[kIdx] = isPrefetch ? prefetchWidth : (shape[kIdx] - prefetchWidth);
if (shapeK)
shape[kIdx] = *shapeK;
if (offsetK)
offset[kIdx] = *offsetK;
Value newSmem = builder.create<tensor::ExtractSliceOp>(
v.getLoc(),
// TODO: encoding?
RankedTensorType::get(shape, elementType, type.getEncoding()), v,
SmallVector<OpFoldResult>{intAttr(offset[0]), intAttr(offset[1])},
SmallVector<OpFoldResult>{intAttr(shape[0]), intAttr(shape[1])},
SmallVector<OpFoldResult>{intAttr(1), intAttr(1)});
auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
builder.getContext(), opIdx, dotEncoding);
Value prefetchSlice = builder.create<triton::gpu::ConvertLayoutOp>(
v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
newSmem);
return prefetchSlice;
}
LogicalResult Prefetcher::initialize() {
Block *loop = forOp.getBody();
SmallVector<triton::DotOp> dotsInFor;
for (Operation &op : *loop)
if (auto dotOp = dyn_cast<triton::DotOp>(op))
dotsInFor.push_back(dotOp);
if (dotsInFor.empty())
return failure();
// returns source of cvt
auto getPrefetchSrc = [](Value v) -> Value {
// TODO: Check if the layout of src is SharedEncodingAttr
if (auto cvt = v.getDefiningOp<triton::gpu::ConvertLayoutOp>())
return cvt.src();
return Value();
};
auto getIncomingOp = [this](Value v) -> Value {
if (auto arg = v.dyn_cast<BlockArgument>())
if (arg.getOwner()->getParentOp() == forOp.getOperation())
return forOp.getOpOperandForRegionIterArg(arg).get();
return Value();
};
auto getYieldOp = [this](Value v) -> Value {
auto arg = v.cast<BlockArgument>();
unsigned yieldIdx = arg.getArgNumber() - forOp.getNumInductionVars();
return yieldOp.getOperand(yieldIdx);
};
for (triton::DotOp dot : dotsInFor) {
Value aSmem = getPrefetchSrc(dot.a());
Value bSmem = getPrefetchSrc(dot.b());
if (aSmem && bSmem) {
Value aHeaderDef = getIncomingOp(aSmem);
Value bHeaderDef = getIncomingOp(bSmem);
// Only prefetch loop arg
if (aHeaderDef && bHeaderDef) {
dots.insert(dot);
dot2aHeaderDef[dot] = aHeaderDef;
dot2bHeaderDef[dot] = bHeaderDef;
dot2aLoopArg[dot] = aSmem;
dot2bLoopArg[dot] = bSmem;
dot2aYield[dot] = getYieldOp(aSmem);
dot2bYield[dot] = getYieldOp(bSmem);
}
}
}
return success();
}
void Prefetcher::emitPrologue() {
OpBuilder builder(forOp);
for (Value dot : dots) {
Attribute dotEncoding =
dot.getType().cast<RankedTensorType>().getEncoding();
Value aPrefetched =
generatePrefetch(dot2aHeaderDef[dot], 0, true, dotEncoding, builder);
operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().a()] = aPrefetched;
Value bPrefetched =
generatePrefetch(dot2bHeaderDef[dot], 1, true, dotEncoding, builder);
operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().b()] = bPrefetched;
}
}
scf::ForOp Prefetcher::createNewForOp() {
OpBuilder builder(forOp);
SmallVector<Value> loopArgs;
for (auto v : forOp.getIterOperands())
loopArgs.push_back(v);
for (Value dot : dots) {
loopArgs.push_back(
operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().a()]);
loopArgs.push_back(
operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().b()]);
}
auto newForOp = builder.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), loopArgs);
auto largestPow2 = [](int64_t n) -> int64_t {
while ((n & (n - 1)) != 0)
n = n & (n - 1);
return n;
};
builder.setInsertionPointToStart(newForOp.getBody());
BlockAndValueMapping mapping;
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
for (Operation &op : forOp.getBody()->without_terminator()) {
Operation *newOp = nullptr;
auto dot = dyn_cast<triton::DotOp>(&op);
if (dots.contains(dot)) {
Attribute dotEncoding =
dot.getType().cast<RankedTensorType>().getEncoding();
// prefetched dot
Operation *firstDot = builder.clone(*dot, mapping);
if (Value a = operand2headPrefetch.lookup(dot.a()))
firstDot->setOperand(
0, newForOp.getRegionIterArgForOpOperand(*a.use_begin()));
if (Value b = operand2headPrefetch.lookup(dot.b()))
firstDot->setOperand(
1, newForOp.getRegionIterArgForOpOperand(*b.use_begin()));
// remaining part
int64_t kOff = prefetchWidth;
int64_t kRem = dot.a().getType().cast<RankedTensorType>().getShape()[1] -
prefetchWidth;
Operation *prevDot = firstDot;
while (kRem != 0) {
int64_t kShape = largestPow2(kRem);
Value aRem =
generatePrefetch(mapping.lookup(dot2aLoopArg[dot]), 0, false,
dotEncoding, builder, kOff, kShape);
Value bRem =
generatePrefetch(mapping.lookup(dot2bLoopArg[dot]), 1, false,
dotEncoding, builder, kOff, kShape);
newOp = builder.clone(*dot, mapping);
newOp->setOperand(0, aRem);
newOp->setOperand(1, bRem);
newOp->setOperand(2, prevDot->getResult(0));
prevDot = newOp;
kOff += kShape;
kRem -= kShape;
}
} else {
newOp = builder.clone(op, mapping);
}
// update mapping of results
for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults()))
mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx));
}
// prefetch next iteration
SmallVector<Value> yieldValues;
for (Value v : forOp.getBody()->getTerminator()->getOperands())
yieldValues.push_back(mapping.lookup(v));
for (Value dot : dots) {
Attribute dotEncoding =
dot.getType().cast<RankedTensorType>().getEncoding();
yieldValues.push_back(generatePrefetch(mapping.lookup(dot2aYield[dot]), 0,
true, dotEncoding, builder));
yieldValues.push_back(generatePrefetch(mapping.lookup(dot2bYield[dot]), 1,
true, dotEncoding, builder));
}
// Update ops of yield
builder.create<scf::YieldOp>(yieldOp.getLoc(), yieldValues);
return newForOp;
}
struct PrefetchPass : public TritonGPUPrefetchBase<PrefetchPass> {
void runOnOperation() override {
getOperation()->walk([&](scf::ForOp forOp) {
Prefetcher prefetcher(forOp);
if (prefetcher.initialize().failed())
return;
prefetcher.emitPrologue();
scf::ForOp newForOp = prefetcher.createNewForOp();
// replace the original loop
for (unsigned i = 0; i < forOp->getNumResults(); ++i)
forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i));
forOp->erase();
});
}
};
} // anonymous namespace
std::unique_ptr<Pass> mlir::createTritonGPUPrefetchPass() {
return std::make_unique<PrefetchPass>();
}

View File

@@ -39,23 +39,23 @@ struct SwizzlePass : public TritonGPUSwizzleBase<SwizzlePass> {
return SwizzleInfo{vec, perPhase, maxPhase};
} else if (version == 2) {
auto eltTy = ty.getElementType();
std::vector<size_t> mat_shape = {8, 8,
2 * 64 / eltTy.getIntOrFloatBitWidth()};
std::vector<size_t> matShape = {8, 8,
2 * 64 / eltTy.getIntOrFloatBitWidth()};
// for now, disable swizzle when using transposed int8 tensor cores
bool is_int8_mma = ty.getElementType().isInteger(8);
if (is_int8_mma && order[0] == inner)
bool isInt8Mma = ty.getElementType().isInteger(8);
if (isInt8Mma && order[0] == inner)
return noSwizzling;
// compute swizzling for A operand
if (opIdx == 0) {
int vec = order[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m
int mmaStride = order[0] == 1 ? mat_shape[0] : mat_shape[2];
int vec = order[0] == 1 ? matShape[2] : matShape[0]; // k : m
int mmaStride = order[0] == 1 ? matShape[0] : matShape[2];
int maxPhase = mmaStride / perPhase;
return SwizzleInfo{vec, perPhase, maxPhase};
}
// compute swizzling for B operand
else if (opIdx == 1) {
int vec = order[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k
int mmaStride = order[0] == 1 ? mat_shape[2] : mat_shape[1];
int vec = order[0] == 1 ? matShape[1] : matShape[2]; // n : k
int mmaStride = order[0] == 1 ? matShape[2] : matShape[1];
int maxPhase = mmaStride / perPhase;
return SwizzleInfo{vec, perPhase, maxPhase};
} else {
@@ -67,32 +67,64 @@ struct SwizzlePass : public TritonGPUSwizzleBase<SwizzlePass> {
void runOnOperation() override {
Operation *op = getOperation();
op->walk([&](triton::DotOp dotOp) -> void {
OpBuilder builder(dotOp);
auto _retEncoding =
dotOp.getResult().getType().cast<RankedTensorType>().getEncoding();
auto retEncoding = _retEncoding.dyn_cast<triton::gpu::MmaEncodingAttr>();
if (!retEncoding)
return;
for (int opIdx : {0, 1}) {
Value op = dotOp.getOperand(opIdx);
auto ty = op.getType().template cast<RankedTensorType>();
// compute new swizzled encoding
SwizzleInfo swizzle = getSwizzleMMA(opIdx, retEncoding, ty);
auto newEncoding = triton::gpu::SharedEncodingAttr::get(
&getContext(), swizzle.vec, swizzle.perPhase, swizzle.maxPhase,
ty.getEncoding()
.cast<triton::gpu::SharedEncodingAttr>()
.getOrder());
// create conversion
auto newType = RankedTensorType::get(ty.getShape(), ty.getElementType(),
newEncoding);
Operation *newOp = builder.create<triton::gpu::ConvertLayoutOp>(
op.getLoc(), newType, op);
// bind new op to dot operand
dotOp->replaceUsesOfWith(op, newOp->getResult(0));
// replace blocked -> dot_op with
// blocked -> shared -> dot_op in order to
// expose opportunities for swizzling
op->walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
OpBuilder builder(cvtOp);
auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvtOp.getType().cast<RankedTensorType>();
if (srcType.getEncoding().isa<triton::gpu::BlockedEncodingAttr>() &&
dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>()) {
auto tmpType =
RankedTensorType::get(dstType.getShape(), dstType.getElementType(),
triton::gpu::SharedEncodingAttr::get(
op->getContext(), 1, 1, 1, {1, 0}));
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), dstType, tmp);
cvtOp.replaceAllUsesWith(newConvert.getResult());
}
});
op->walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
OpBuilder builder(cvtOp);
auto arg = cvtOp.getOperand();
auto retType = cvtOp.getResult().getType().cast<RankedTensorType>();
auto retEncoding =
retType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
auto argType = arg.getType().cast<RankedTensorType>();
auto argEncoding =
argType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
if (!argEncoding || !retEncoding)
return;
auto opIdx = retEncoding.getOpIdx();
// compute new swizzled encoding
auto parentEncoding =
retEncoding.getParent().dyn_cast<triton::gpu::MmaEncodingAttr>();
if (!parentEncoding)
return;
auto swizzleType = argType;
if (arg.getDefiningOp() &&
isa<tensor::ExtractSliceOp>(arg.getDefiningOp())) {
swizzleType = arg.getDefiningOp()
->getOperand(0)
.getType()
.cast<RankedTensorType>();
}
SwizzleInfo swizzle = getSwizzleMMA(opIdx, parentEncoding, swizzleType);
auto newEncoding = triton::gpu::SharedEncodingAttr::get(
&getContext(), swizzle.vec, swizzle.perPhase, swizzle.maxPhase,
argEncoding.getOrder());
// create conversion
auto newType = RankedTensorType::get(
argType.getShape(), argType.getElementType(), newEncoding);
Operation *newArg = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), newType, arg);
// bind new op to cvt operand
cvtOp->replaceUsesOfWith(arg, newArg->getResult(0));
});
}
};
} // anonymous namespace

View File

@@ -95,8 +95,8 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
dotOp.a().getType().cast<RankedTensorType>().getEncoding();
Attribute bEncoding =
dotOp.b().getType().cast<RankedTensorType>().getEncoding();
if (aEncoding && aEncoding.isa<triton::gpu::SharedEncodingAttr>() &&
bEncoding && bEncoding.isa<triton::gpu::SharedEncodingAttr>())
if (aEncoding && aEncoding.isa<triton::gpu::DotOperandEncodingAttr>() &&
bEncoding && bEncoding.isa<triton::gpu::DotOperandEncodingAttr>())
return true;
return false;
});