[Triton-MLIR][OPTIMIZER] Cleaned up swizzling (#869)
Swizzling is no longer implemented as a separate pass. It is instead done in a specialized constructor of SharedEncodingAttr, and tested via google tests instead of triton-opt + filecheck. In the future we may want to implement it as a pass again once we have an additional dialect between TritonGPU and LLVM.
This commit is contained in:
@@ -8,7 +8,6 @@ add_mlir_dialect_library(TritonGPUTransforms
|
||||
Combine.cpp
|
||||
Pipeline.cpp
|
||||
Prefetch.cpp
|
||||
Swizzle.cpp
|
||||
TritonGPUConversion.cpp
|
||||
|
||||
DEPENDS
|
||||
|
@@ -74,10 +74,6 @@ class LoopPipeliner {
|
||||
/// returns a empty buffer of size <numStages, ...>
|
||||
ttg::AllocTensorOp allocateEmptyBuffer(Operation *op, OpBuilder &builder);
|
||||
|
||||
/// compute type of shared buffers (with swizzled shared layouts)
|
||||
RankedTensorType getSwizzleType(ttg::DotOperandEncodingAttr dotOpEnc,
|
||||
RankedTensorType tensorType);
|
||||
|
||||
public:
|
||||
LoopPipeliner(scf::ForOp forOp, int numStages)
|
||||
: forOp(forOp), numStages(numStages) {
|
||||
@@ -148,70 +144,6 @@ ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op,
|
||||
llvm_unreachable("Async copy's return should be of RankedTensorType");
|
||||
}
|
||||
|
||||
// TODO: I copied the code from Swizzle.cpp. Should find a way to unify the
|
||||
// code path.
|
||||
// Swizzle has to be performed before pipeline for now. If we do swizzle
|
||||
// after pipeline, we need to propagate the swizzled layout to all
|
||||
// operands that is an alias of the swizzled tensor. The alias analysis
|
||||
// component maybe helpful for this purpose.
|
||||
RankedTensorType
|
||||
LoopPipeliner::getSwizzleType(ttg::DotOperandEncodingAttr dotOpEnc,
|
||||
RankedTensorType ty) {
|
||||
int opIdx = dotOpEnc.getOpIdx();
|
||||
int vec = 1;
|
||||
int maxPhase = 1;
|
||||
int perPhase = 1;
|
||||
llvm::SmallVector<unsigned> order;
|
||||
if (auto mmaEnc = dotOpEnc.getParent().dyn_cast<ttg::MmaEncodingAttr>()) {
|
||||
// Only support row major for now
|
||||
// TODO(Keren): check why column major code crashes
|
||||
order = {1, 0};
|
||||
int version = mmaEnc.getVersion();
|
||||
auto tyEncoding = ty.getEncoding().cast<ttg::BlockedEncodingAttr>();
|
||||
// number of rows per phase
|
||||
perPhase = 128 / (ty.getShape()[order[0]] *
|
||||
(ty.getElementType().getIntOrFloatBitWidth() / 8));
|
||||
perPhase = std::max<int>(perPhase, 1);
|
||||
|
||||
// index of the inner dimension in `order`
|
||||
unsigned inner = (opIdx == 0) ? 0 : 1;
|
||||
if (version == 1) {
|
||||
maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
|
||||
// TODO: handle rep (see
|
||||
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L209)
|
||||
} else if (version == 2) {
|
||||
auto eltTy = ty.getElementType();
|
||||
std::vector<size_t> matShape = {8, 8,
|
||||
2 * 64 / eltTy.getIntOrFloatBitWidth()};
|
||||
// for now, disable swizzle when using transposed int8 tensor cores
|
||||
if (ty.getElementType().isInteger(8) && order[0] == inner)
|
||||
perPhase = 1;
|
||||
else {
|
||||
if (opIdx == 0) { // compute swizzling for A operand
|
||||
vec = order[0] == 1 ? matShape[2] : matShape[0]; // k : m
|
||||
int mmaStride = order[0] == 1 ? matShape[0] : matShape[2];
|
||||
maxPhase = mmaStride / perPhase;
|
||||
} else if (opIdx == 1) { // compute swizzling for B operand
|
||||
vec = order[0] == 1 ? matShape[1] : matShape[2]; // n : k
|
||||
int mmaStride = order[0] == 1 ? matShape[2] : matShape[1];
|
||||
maxPhase = mmaStride / perPhase;
|
||||
} else
|
||||
llvm_unreachable("invalid operand index");
|
||||
}
|
||||
} else // version not in [1, 2]
|
||||
llvm_unreachable("unsupported swizzling for provided MMA version");
|
||||
} else { // If the layout of dot is not mma, we don't need to swizzle
|
||||
auto blockedEnc = dotOpEnc.getParent().cast<ttg::BlockedEncodingAttr>();
|
||||
order = llvm::SmallVector<unsigned>(blockedEnc.getOrder().begin(),
|
||||
blockedEnc.getOrder().end());
|
||||
}
|
||||
auto newEncoding = ttg::SharedEncodingAttr::get(ty.getContext(), vec,
|
||||
perPhase, maxPhase, order);
|
||||
SmallVector<int64_t> bufferShape(ty.getShape().begin(), ty.getShape().end());
|
||||
bufferShape.insert(bufferShape.begin(), numStages);
|
||||
return RankedTensorType::get(bufferShape, ty.getElementType(), newEncoding);
|
||||
}
|
||||
|
||||
/// A load instruction can be pipelined if:
|
||||
/// - the load doesn't depend on any other loads (after loop peeling)
|
||||
/// - (?) this load is not a loop-invariant value (we should run LICM before
|
||||
@@ -264,8 +196,14 @@ LogicalResult LoopPipeliner::initialize() {
|
||||
.dyn_cast<ttg::DotOperandEncodingAttr>()) {
|
||||
isCandiate = true;
|
||||
loadsMapping[loadOp] = convertLayout;
|
||||
loadsBufferType[loadOp] = getSwizzleType(
|
||||
dotOpEnc, loadOp.getType().cast<RankedTensorType>());
|
||||
auto ty = loadOp.getType().cast<RankedTensorType>();
|
||||
SmallVector<int64_t> bufferShape(ty.getShape().begin(),
|
||||
ty.getShape().end());
|
||||
bufferShape.insert(bufferShape.begin(), numStages);
|
||||
auto sharedEnc = ttg::SharedEncodingAttr::get(
|
||||
ty.getContext(), dotOpEnc, ty.getShape(), ty.getElementType());
|
||||
loadsBufferType[loadOp] = RankedTensorType::get(
|
||||
bufferShape, ty.getElementType(), sharedEnc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -1,134 +0,0 @@
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
namespace {
|
||||
|
||||
struct SwizzlePass : public TritonGPUSwizzleBase<SwizzlePass> {
|
||||
SwizzlePass() = default;
|
||||
|
||||
struct SwizzleInfo {
|
||||
int vec;
|
||||
int perPhase;
|
||||
int maxPhase;
|
||||
};
|
||||
|
||||
SwizzleInfo getSwizzleMMA(int opIdx, triton::gpu::MmaEncodingAttr retEncoding,
|
||||
RankedTensorType ty) {
|
||||
SwizzleInfo noSwizzling = {1, 1, 1};
|
||||
int version = retEncoding.getVersion();
|
||||
auto tyEncoding = ty.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
|
||||
auto order = tyEncoding.getOrder();
|
||||
// number of rows per phase
|
||||
int perPhase = 128 / (ty.getShape()[order[0]] *
|
||||
(ty.getElementType().getIntOrFloatBitWidth() / 8));
|
||||
perPhase = std::max<int>(perPhase, 1);
|
||||
// index of the inner dimension in `order`
|
||||
size_t inner = (opIdx == 0) ? 0 : 1;
|
||||
if (version == 1) {
|
||||
int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
|
||||
// TODO: handle rep (see
|
||||
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L209)
|
||||
int vec = 1;
|
||||
return SwizzleInfo{vec, perPhase, maxPhase};
|
||||
} else if (version == 2) {
|
||||
auto eltTy = ty.getElementType();
|
||||
std::vector<size_t> matShape = {8, 8,
|
||||
2 * 64 / eltTy.getIntOrFloatBitWidth()};
|
||||
// for now, disable swizzle when using transposed int8 tensor cores
|
||||
bool isInt8Mma = ty.getElementType().isInteger(8);
|
||||
if (isInt8Mma && order[0] == inner)
|
||||
return noSwizzling;
|
||||
// compute swizzling for A operand
|
||||
if (opIdx == 0) {
|
||||
int vec = order[0] == 1 ? matShape[2] : matShape[0]; // k : m
|
||||
int mmaStride = order[0] == 1 ? matShape[0] : matShape[2];
|
||||
int maxPhase = mmaStride / perPhase;
|
||||
return SwizzleInfo{vec, perPhase, maxPhase};
|
||||
}
|
||||
// compute swizzling for B operand
|
||||
else if (opIdx == 1) {
|
||||
int vec = order[0] == 1 ? matShape[1] : matShape[2]; // n : k
|
||||
int mmaStride = order[0] == 1 ? matShape[2] : matShape[1];
|
||||
int maxPhase = mmaStride / perPhase;
|
||||
return SwizzleInfo{vec, perPhase, maxPhase};
|
||||
} else {
|
||||
llvm_unreachable("invalid operand index");
|
||||
}
|
||||
} else
|
||||
llvm_unreachable("unsupported swizzling for provided MMA version");
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
Operation *op = getOperation();
|
||||
// replace blocked -> dot_op with
|
||||
// blocked -> shared -> dot_op in order to
|
||||
// expose opportunities for swizzling
|
||||
op->walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
|
||||
OpBuilder builder(cvtOp);
|
||||
auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = cvtOp.getType().cast<RankedTensorType>();
|
||||
if (srcType.getEncoding().isa<triton::gpu::BlockedEncodingAttr>() &&
|
||||
dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>()) {
|
||||
auto tmpType =
|
||||
RankedTensorType::get(dstType.getShape(), dstType.getElementType(),
|
||||
triton::gpu::SharedEncodingAttr::get(
|
||||
op->getContext(), 1, 1, 1, {1, 0}));
|
||||
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
|
||||
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), dstType, tmp);
|
||||
cvtOp.replaceAllUsesWith(newConvert.getResult());
|
||||
}
|
||||
});
|
||||
|
||||
op->walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
|
||||
OpBuilder builder(cvtOp);
|
||||
auto arg = cvtOp.getOperand();
|
||||
auto retType = cvtOp.getResult().getType().cast<RankedTensorType>();
|
||||
auto retEncoding =
|
||||
retType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
auto argType = arg.getType().cast<RankedTensorType>();
|
||||
auto argEncoding =
|
||||
argType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
|
||||
if (!argEncoding || !retEncoding)
|
||||
return;
|
||||
auto opIdx = retEncoding.getOpIdx();
|
||||
// compute new swizzled encoding
|
||||
auto parentEncoding =
|
||||
retEncoding.getParent().dyn_cast<triton::gpu::MmaEncodingAttr>();
|
||||
if (!parentEncoding)
|
||||
return;
|
||||
auto swizzleType = argType;
|
||||
if (arg.getDefiningOp() &&
|
||||
isa<tensor::ExtractSliceOp>(arg.getDefiningOp())) {
|
||||
swizzleType = arg.getDefiningOp()
|
||||
->getOperand(0)
|
||||
.getType()
|
||||
.cast<RankedTensorType>();
|
||||
}
|
||||
SwizzleInfo swizzle = getSwizzleMMA(opIdx, parentEncoding, swizzleType);
|
||||
auto newEncoding = triton::gpu::SharedEncodingAttr::get(
|
||||
&getContext(), swizzle.vec, swizzle.perPhase, swizzle.maxPhase,
|
||||
argEncoding.getOrder());
|
||||
// create conversion
|
||||
auto newType = RankedTensorType::get(
|
||||
argType.getShape(), argType.getElementType(), newEncoding);
|
||||
Operation *newArg = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), newType, arg);
|
||||
// bind new op to cvt operand
|
||||
cvtOp->replaceUsesOfWith(arg, newArg->getResult(0));
|
||||
});
|
||||
}
|
||||
};
|
||||
} // anonymous namespace
|
||||
|
||||
std::unique_ptr<Pass> mlir::createTritonGPUSwizzlePass() {
|
||||
return std::make_unique<SwizzlePass>();
|
||||
}
|
Reference in New Issue
Block a user