[Triton-MLIR][OPTIMIZER] Cleaned up swizzling (#869)

Swizzling is no longer implemented as a separate pass. It is instead
done in a specialized constructor of SharedEncodingAttr, and tested via
google tests instead of triton-opt + filecheck.

In the future we may want to implement it as a pass again once we have
an additional dialect between TritonGPU and LLVM.
This commit is contained in:
Philippe Tillet
2022-11-10 12:05:46 -08:00
committed by GitHub
parent 2aa538ec2e
commit f40c63fb03
14 changed files with 170 additions and 318 deletions

View File

@@ -74,10 +74,6 @@ class LoopPipeliner {
/// returns an empty buffer of size <numStages, ...>
ttg::AllocTensorOp allocateEmptyBuffer(Operation *op, OpBuilder &builder);
/// compute type of shared buffers (with swizzled shared layouts)
RankedTensorType getSwizzleType(ttg::DotOperandEncodingAttr dotOpEnc,
RankedTensorType tensorType);
public:
LoopPipeliner(scf::ForOp forOp, int numStages)
: forOp(forOp), numStages(numStages) {
@@ -148,70 +144,6 @@ ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op,
llvm_unreachable("Async copy's return should be of RankedTensorType");
}
// TODO: This logic was copied from Swizzle.cpp; the two code paths should be
// unified.
// For now swizzling has to be performed before pipelining. If swizzling were
// done after pipelining, the swizzled layout would have to be propagated to
// every operand that is an alias of the swizzled tensor. The alias analysis
// component may be helpful for that purpose.
//
/// Computes the type of the multi-stage buffer used for a dot operand.
///
/// \param dotOpEnc dot-operand encoding; its operand index and parent
///                 encoding select the swizzling parameters.
/// \param ty       type of the tensor being buffered; supplies the shape,
///                 the element type and the MLIR context.
/// \returns a RankedTensorType with shape <numStages, ty.getShape()...>, the
///          element type of `ty`, and a SharedEncodingAttr built from the
///          computed (vec, perPhase, maxPhase, order).
RankedTensorType
LoopPipeliner::getSwizzleType(ttg::DotOperandEncodingAttr dotOpEnc,
                              RankedTensorType ty) {
  const int opIdx = dotOpEnc.getOpIdx();

  // Defaults correspond to "no swizzling".
  int vec = 1;
  int perPhase = 1;
  int maxPhase = 1;
  llvm::SmallVector<unsigned> order;

  auto mmaEnc = dotOpEnc.getParent().dyn_cast<ttg::MmaEncodingAttr>();
  if (!mmaEnc) {
    // If the layout of dot is not mma, we don't need to swizzle: just reuse
    // the order of the parent blocked layout.
    auto blockedEnc = dotOpEnc.getParent().cast<ttg::BlockedEncodingAttr>();
    order = llvm::SmallVector<unsigned>(blockedEnc.getOrder().begin(),
                                        blockedEnc.getOrder().end());
  } else {
    // Only support row major for now
    // TODO(Keren): check why column major code crashes
    order = {1, 0};
    const int version = mmaEnc.getVersion();
    // NOTE(review): the result is never read; presumably the cast is kept to
    // check that `ty` carries a blocked encoding -- confirm before removing.
    auto tyEncoding = ty.getEncoding().cast<ttg::BlockedEncodingAttr>();

    // number of rows per phase, clamped to at least 1
    const unsigned elemBits = ty.getElementType().getIntOrFloatBitWidth();
    perPhase = 128 / (ty.getShape()[order[0]] * (elemBits / 8));
    perPhase = std::max<int>(perPhase, 1);

    // index of the inner dimension in `order`
    const unsigned inner = (opIdx == 0) ? 0 : 1;

    switch (version) {
    case 1:
      maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
      // TODO: handle rep (see
      // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L209)
      break;
    case 2: {
      // Matrix tile dimensions (m, n, k); k depends on the element width.
      const size_t matM = 8;
      const size_t matN = 8;
      const size_t matK = 2 * 64 / elemBits;
      // for now, disable swizzle when using transposed int8 tensor cores
      if (ty.getElementType().isInteger(8) && order[0] == inner) {
        perPhase = 1;
        break;
      }
      const bool fastestIsDim1 = order[0] == 1;
      int mmaStride = 0;
      if (opIdx == 0) { // compute swizzling for A operand
        vec = fastestIsDim1 ? matK : matM; // k : m
        mmaStride = fastestIsDim1 ? matM : matK;
      } else if (opIdx == 1) { // compute swizzling for B operand
        vec = fastestIsDim1 ? matN : matK; // n : k
        mmaStride = fastestIsDim1 ? matK : matN;
      } else {
        llvm_unreachable("invalid operand index");
      }
      maxPhase = mmaStride / perPhase;
      break;
    }
    default: // version not in [1, 2]
      llvm_unreachable("unsupported swizzling for provided MMA version");
    }
  }

  // Buffer shape is the tensor shape with a leading `numStages` dimension.
  SmallVector<int64_t> bufferShape;
  bufferShape.push_back(numStages);
  bufferShape.append(ty.getShape().begin(), ty.getShape().end());

  auto newEncoding = ttg::SharedEncodingAttr::get(ty.getContext(), vec,
                                                  perPhase, maxPhase, order);
  return RankedTensorType::get(bufferShape, ty.getElementType(), newEncoding);
}
/// A load instruction can be pipelined if:
/// - the load doesn't depend on any other loads (after loop peeling)
/// - (?) this load is not a loop-invariant value (we should run LICM before
@@ -264,8 +196,14 @@ LogicalResult LoopPipeliner::initialize() {
.dyn_cast<ttg::DotOperandEncodingAttr>()) {
isCandiate = true;
loadsMapping[loadOp] = convertLayout;
loadsBufferType[loadOp] = getSwizzleType(
dotOpEnc, loadOp.getType().cast<RankedTensorType>());
auto ty = loadOp.getType().cast<RankedTensorType>();
SmallVector<int64_t> bufferShape(ty.getShape().begin(),
ty.getShape().end());
bufferShape.insert(bufferShape.begin(), numStages);
auto sharedEnc = ttg::SharedEncodingAttr::get(
ty.getContext(), dotOpEnc, ty.getShape(), ty.getElementType());
loadsBufferType[loadOp] = RankedTensorType::get(
bufferShape, ty.getElementType(), sharedEnc);
}
}
}