[Triton-MLIR][OPTIMIZER] Cleaned up swizzling (#869)
Swizzling is no longer implemented as a separate pass. It is instead done in a specialized constructor of SharedEncodingAttr, and tested via Google Test (gtest) instead of triton-opt + FileCheck. In the future we may want to implement it as a pass again once we have an additional dialect between TritonGPU and LLVM.
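For reference, callers now obtain the swizzle parameters (vec, perPhase, maxPhase) by constructing the attribute directly. A minimal sketch, mirroring the new unit test at the bottom of this diff (the helper name makeSwizzledLayout is hypothetical; any 16-bit element type yields the same parameters, since the builder only reads the bit width apart from an int8 special case):

    #include "triton/Dialect/TritonGPU/IR/Dialect.h"

    using namespace mlir;

    // Build the swizzled shared layout for the A operand (opIdx = 0) of an
    // MMA-v2 dot on a 128x64 fp16 tile. The new builder computes
    // vec/perPhase/maxPhase internally; for this case the unit test below
    // expects {vec = 8, perPhase = 1, maxPhase = 8}.
    triton::gpu::SharedEncodingAttr makeSwizzledLayout(MLIRContext &ctx) {
      auto parent = triton::gpu::MmaEncodingAttr::get(&ctx, /*version=*/2,
                                                      /*warpsPerCTA=*/{1, 1});
      auto dotOpEnc =
          triton::gpu::DotOperandEncodingAttr::get(&ctx, /*opIdx=*/0, parent);
      Type eltTy = FloatType::getF16(&ctx);
      return triton::gpu::SharedEncodingAttr::get(&ctx, dotOpEnc,
                                                  /*shape=*/{128, 64}, eltTy);
    }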
@@ -71,6 +71,70 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
     ArrayRefParameter<"unsigned", "order of axes by the rate of changing">:$order
   );
 
+  let builders = [
+    AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
+                     "ArrayRef<int64_t>":$shape,
+                     "Type":$eltTy), [{
+        auto mmaEnc = dotOpEnc.getParent().dyn_cast<MmaEncodingAttr>();
+        // Only support row major for now
+        // TODO(Keren): check why column major code crashes
+        SmallVector<unsigned> order = {1, 0};
+
+        if(!mmaEnc)
+          return $_get(context, 1, 1, 1, order);
+
+        int version = mmaEnc.getVersion();
+        int opIdx = dotOpEnc.getOpIdx();
+
+        // number of rows per phase
+        int perPhase = 128 / (shape[order[0]] * (eltTy.getIntOrFloatBitWidth() / 8));
+        perPhase = std::max<int>(perPhase, 1);
+
+        // index of the inner dimension in `order`
+        unsigned inner = (opIdx == 0) ? 0 : 1;
+
+        // ---- begin version 1 ----
+        // TODO: handle rep (see
+        // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L209)
+        if (version == 1) {
+          int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
+          return $_get(context, 1, perPhase, maxPhase, order);
+        }
+
+        // ---- begin version 2 ----
+        if (version == 2) {
+          std::vector<size_t> matShape = {8, 8,
+                                          2 * 64 / eltTy.getIntOrFloatBitWidth()};
+          // for now, disable swizzle when using transposed int8 tensor cores
+          if (eltTy.isInteger(8) && order[0] == inner)
+            return $_get(context, 1, 1, 1, order);
+
+          // --- handle A operand ---
+          if (opIdx == 0) { // compute swizzling for A operand
+            int vec = (order[0] == 1) ? matShape[2] : matShape[0]; // k : m
+            int mmaStride = (order[0] == 1) ? matShape[0] : matShape[2];
+            int maxPhase = mmaStride / perPhase;
+            return $_get(context, vec, perPhase, maxPhase, order);
+          }
+
+          // --- handle B operand ---
+          if (opIdx == 1) {
+            int vec = (order[0] == 1) ? matShape[1] : matShape[2]; // n : k
+            int mmaStride = (order[0] == 1) ? matShape[2] : matShape[1];
+            int maxPhase = mmaStride / perPhase;
+            return $_get(context, vec, perPhase, maxPhase, order);
+          }
+
+          llvm_unreachable("invalid operand index");
+        }
+
+        // ---- not implemented ----
+        llvm_unreachable("unsupported swizzling for provided MMA version");
+    }]>
+  ];
+
   let extraClassDeclaration = extraBaseClassDeclaration;
 }
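Worked instance of the version-2 path, matching the first case of the new unit test: for a 16-bit A operand (opIdx = 0) of shape 128x64 with order = {1, 0}, perPhase = 128 / (64 * 2 bytes) = 1, matShape = {8, 8, 2 * 64 / 16} = {8, 8, 8}, vec = matShape[2] = 8, mmaStride = matShape[0] = 8, and maxPhase = 8 / 1 = 8, i.e. the layout {vec = 8, perPhase = 1, maxPhase = 8}.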
@@ -11,8 +11,6 @@ std::unique_ptr<Pass> createTritonGPUPrefetchPass();
 
 std::unique_ptr<Pass> createTritonGPUCanonicalizeLoopsPass();
 
-std::unique_ptr<Pass> createTritonGPUSwizzlePass();
-
 std::unique_ptr<Pass> createTritonGPUCoalescePass();
 
 std::unique_ptr<Pass> createTritonGPUCombineOpsPass();
@@ -65,18 +65,6 @@ def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
                            "mlir::triton::TritonDialect"];
 }
 
-def TritonGPUSwizzle : Pass<"tritongpu-swizzle", "mlir::ModuleOp"> {
-  let summary = "swizzle";
-
-  let description = [{
-    Inserts conversions to swizzled layout so as to avoid shared memory bank conflicts.
-  }];
-
-  let constructor = "mlir::createTritonGPUSwizzlePass()";
-
-  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
-}
-
 def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> {
   let summary = "canonicalize scf.ForOp ops";
 
@@ -2707,6 +2707,7 @@ public:
       return lowerDistributedToDistributed(op, adaptor, rewriter);
     }
     // TODO: to be implemented
+    llvm_unreachable("unsupported layout conversion");
     return failure();
   }
 
@@ -5763,6 +5764,35 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
 
 class ConvertTritonGPUToLLVM
     : public ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {
+
+private:
+  void decomposeBlockedToDotOperand(ModuleOp mod) {
+    // replace `blocked -> dot_op` with `blocked -> shared -> dot_op`
+    // because the codegen doesn't handle `blocked -> dot_op` directly
+    mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
+      OpBuilder builder(cvtOp);
+      auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
+      auto dstType = cvtOp.getType().cast<RankedTensorType>();
+      auto srcBlocked =
+          srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
+      auto dstDotOp =
+          dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
+      if (srcBlocked && dstDotOp) {
+        auto tmpType = RankedTensorType::get(
+            dstType.getShape(), dstType.getElementType(),
+            triton::gpu::SharedEncodingAttr::get(mod.getContext(), dstDotOp,
+                                                 srcType.getShape(),
+                                                 srcType.getElementType()));
+        auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
+            cvtOp.getLoc(), tmpType, cvtOp.getOperand());
+        auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
+            cvtOp.getLoc(), dstType, tmp);
+        cvtOp.replaceAllUsesWith(newConvert.getResult());
+        cvtOp.erase();
+      }
+    });
+  }
+
 public:
   ConvertTritonGPUToLLVM() = default;
 
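Concretely, this rewrites a conversion such as tensor<128x64xf16, #blocked> -> tensor<128x64xf16, #dot_op> into two conversions through an intermediate tensor<128x64xf16, #shared>, where the #shared encoding is the swizzled layout that the new SharedEncodingAttr builder derives from the destination's dot-operand encoding plus the source shape and element type. (The 128x64xf16 type is illustrative, borrowed from the FileCheck test deleted further down.)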
@@ -5779,15 +5809,19 @@ public:
 
     int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
 
-    // step 1: Allocate shared memories and insert barriers
-    // step 2: Convert SCF to CFG
-    // step 3: Convert FuncOp to LLVMFuncOp via partial conversion
-    // step 4: Convert the rest of ops via partial conversion
+    // step 1: Decompose unoptimized layout conversions to use shared memory
+    // step 2: Allocate shared memories and insert barriers
+    // step 3: Convert SCF to CFG
+    // step 4: Convert FuncOp to LLVMFuncOp via partial conversion
+    // step 5: Convert the rest of ops via partial conversion
     // The reason for putting step 1 before step 2 is that the membar analysis
     // currently only supports SCF but not CFG.
     // The reason for a separation between 1/4 is that, step 3 is out of
     // the scope of Dialect Conversion, thus we need to make sure the smem
     // is not revised during the conversion of step 4.
+
+    decomposeBlockedToDotOperand(mod);
+
     Allocation allocation(mod);
     MembarAnalysis membar(&allocation);
 
@@ -8,7 +8,6 @@ add_mlir_dialect_library(TritonGPUTransforms
   Combine.cpp
   Pipeline.cpp
   Prefetch.cpp
-  Swizzle.cpp
   TritonGPUConversion.cpp
 
 DEPENDS
@@ -74,10 +74,6 @@ class LoopPipeliner {
   /// returns a empty buffer of size <numStages, ...>
   ttg::AllocTensorOp allocateEmptyBuffer(Operation *op, OpBuilder &builder);
 
-  /// compute type of shared buffers (with swizzled shared layouts)
-  RankedTensorType getSwizzleType(ttg::DotOperandEncodingAttr dotOpEnc,
-                                  RankedTensorType tensorType);
-
 public:
   LoopPipeliner(scf::ForOp forOp, int numStages)
       : forOp(forOp), numStages(numStages) {
@@ -148,70 +144,6 @@ ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op,
   llvm_unreachable("Async copy's return should be of RankedTensorType");
 }
 
-// TODO: I copied the code from Swizzle.cpp. Should find a way to unify the
-// code path.
-// Swizzle has to be performed before pipeline for now. If we do swizzle
-// after pipeline, we need to propagate the swizzled layout to all
-// operands that is an alias of the swizzled tensor. The alias analysis
-// component maybe helpful for this purpose.
-RankedTensorType
-LoopPipeliner::getSwizzleType(ttg::DotOperandEncodingAttr dotOpEnc,
-                              RankedTensorType ty) {
-  int opIdx = dotOpEnc.getOpIdx();
-  int vec = 1;
-  int maxPhase = 1;
-  int perPhase = 1;
-  llvm::SmallVector<unsigned> order;
-  if (auto mmaEnc = dotOpEnc.getParent().dyn_cast<ttg::MmaEncodingAttr>()) {
-    // Only support row major for now
-    // TODO(Keren): check why column major code crashes
-    order = {1, 0};
-    int version = mmaEnc.getVersion();
-    auto tyEncoding = ty.getEncoding().cast<ttg::BlockedEncodingAttr>();
-    // number of rows per phase
-    perPhase = 128 / (ty.getShape()[order[0]] *
-                      (ty.getElementType().getIntOrFloatBitWidth() / 8));
-    perPhase = std::max<int>(perPhase, 1);
-
-    // index of the inner dimension in `order`
-    unsigned inner = (opIdx == 0) ? 0 : 1;
-    if (version == 1) {
-      maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
-      // TODO: handle rep (see
-      // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L209)
-    } else if (version == 2) {
-      auto eltTy = ty.getElementType();
-      std::vector<size_t> matShape = {8, 8,
-                                      2 * 64 / eltTy.getIntOrFloatBitWidth()};
-      // for now, disable swizzle when using transposed int8 tensor cores
-      if (ty.getElementType().isInteger(8) && order[0] == inner)
-        perPhase = 1;
-      else {
-        if (opIdx == 0) { // compute swizzling for A operand
-          vec = order[0] == 1 ? matShape[2] : matShape[0]; // k : m
-          int mmaStride = order[0] == 1 ? matShape[0] : matShape[2];
-          maxPhase = mmaStride / perPhase;
-        } else if (opIdx == 1) { // compute swizzling for B operand
-          vec = order[0] == 1 ? matShape[1] : matShape[2]; // n : k
-          int mmaStride = order[0] == 1 ? matShape[2] : matShape[1];
-          maxPhase = mmaStride / perPhase;
-        } else
-          llvm_unreachable("invalid operand index");
-      }
-    } else // version not in [1, 2]
-      llvm_unreachable("unsupported swizzling for provided MMA version");
-  } else { // If the layout of dot is not mma, we don't need to swizzle
-    auto blockedEnc = dotOpEnc.getParent().cast<ttg::BlockedEncodingAttr>();
-    order = llvm::SmallVector<unsigned>(blockedEnc.getOrder().begin(),
-                                        blockedEnc.getOrder().end());
-  }
-  auto newEncoding = ttg::SharedEncodingAttr::get(ty.getContext(), vec,
-                                                  perPhase, maxPhase, order);
-  SmallVector<int64_t> bufferShape(ty.getShape().begin(), ty.getShape().end());
-  bufferShape.insert(bufferShape.begin(), numStages);
-  return RankedTensorType::get(bufferShape, ty.getElementType(), newEncoding);
-}
-
 /// A load instruction can be pipelined if:
 ///   - the load doesn't depend on any other loads (after loop peeling)
 ///   - (?) this load is not a loop-invariant value (we should run LICM before
@@ -264,8 +196,14 @@ LogicalResult LoopPipeliner::initialize() {
                 .dyn_cast<ttg::DotOperandEncodingAttr>()) {
           isCandiate = true;
           loadsMapping[loadOp] = convertLayout;
-          loadsBufferType[loadOp] = getSwizzleType(
-              dotOpEnc, loadOp.getType().cast<RankedTensorType>());
+          auto ty = loadOp.getType().cast<RankedTensorType>();
+          SmallVector<int64_t> bufferShape(ty.getShape().begin(),
+                                           ty.getShape().end());
+          bufferShape.insert(bufferShape.begin(), numStages);
+          auto sharedEnc = ttg::SharedEncodingAttr::get(
+              ty.getContext(), dotOpEnc, ty.getShape(), ty.getElementType());
+          loadsBufferType[loadOp] = RankedTensorType::get(
+              bufferShape, ty.getElementType(), sharedEnc);
         }
       }
     }
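With this change the pipeliner no longer duplicates the swizzle math: it builds the staged buffer type from the same SharedEncodingAttr constructor, which also settles the removed TODO about unifying the code path with Swizzle.cpp.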
@@ -1,134 +0,0 @@
-#include "mlir/Analysis/SliceAnalysis.h"
-#include "triton/Dialect/TritonGPU/IR/Dialect.h"
-#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
-
-using namespace mlir;
-using namespace mlir::triton;
-
-#define GEN_PASS_CLASSES
-#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
-
-namespace {
-
-struct SwizzlePass : public TritonGPUSwizzleBase<SwizzlePass> {
-  SwizzlePass() = default;
-
-  struct SwizzleInfo {
-    int vec;
-    int perPhase;
-    int maxPhase;
-  };
-
-  SwizzleInfo getSwizzleMMA(int opIdx, triton::gpu::MmaEncodingAttr retEncoding,
-                            RankedTensorType ty) {
-    SwizzleInfo noSwizzling = {1, 1, 1};
-    int version = retEncoding.getVersion();
-    auto tyEncoding = ty.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
-    auto order = tyEncoding.getOrder();
-    // number of rows per phase
-    int perPhase = 128 / (ty.getShape()[order[0]] *
-                          (ty.getElementType().getIntOrFloatBitWidth() / 8));
-    perPhase = std::max<int>(perPhase, 1);
-    // index of the inner dimension in `order`
-    size_t inner = (opIdx == 0) ? 0 : 1;
-    if (version == 1) {
-      int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
-      // TODO: handle rep (see
-      // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L209)
-      int vec = 1;
-      return SwizzleInfo{vec, perPhase, maxPhase};
-    } else if (version == 2) {
-      auto eltTy = ty.getElementType();
-      std::vector<size_t> matShape = {8, 8,
-                                      2 * 64 / eltTy.getIntOrFloatBitWidth()};
-      // for now, disable swizzle when using transposed int8 tensor cores
-      bool isInt8Mma = ty.getElementType().isInteger(8);
-      if (isInt8Mma && order[0] == inner)
-        return noSwizzling;
-      // compute swizzling for A operand
-      if (opIdx == 0) {
-        int vec = order[0] == 1 ? matShape[2] : matShape[0]; // k : m
-        int mmaStride = order[0] == 1 ? matShape[0] : matShape[2];
-        int maxPhase = mmaStride / perPhase;
-        return SwizzleInfo{vec, perPhase, maxPhase};
-      }
-      // compute swizzling for B operand
-      else if (opIdx == 1) {
-        int vec = order[0] == 1 ? matShape[1] : matShape[2]; // n : k
-        int mmaStride = order[0] == 1 ? matShape[2] : matShape[1];
-        int maxPhase = mmaStride / perPhase;
-        return SwizzleInfo{vec, perPhase, maxPhase};
-      } else {
-        llvm_unreachable("invalid operand index");
-      }
-    } else
-      llvm_unreachable("unsupported swizzling for provided MMA version");
-  }
-
-  void runOnOperation() override {
-    Operation *op = getOperation();
-    // replace blocked -> dot_op with
-    // blocked -> shared -> dot_op in order to
-    // expose opportunities for swizzling
-    op->walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
-      OpBuilder builder(cvtOp);
-      auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
-      auto dstType = cvtOp.getType().cast<RankedTensorType>();
-      if (srcType.getEncoding().isa<triton::gpu::BlockedEncodingAttr>() &&
-          dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>()) {
-        auto tmpType =
-            RankedTensorType::get(dstType.getShape(), dstType.getElementType(),
-                                  triton::gpu::SharedEncodingAttr::get(
-                                      op->getContext(), 1, 1, 1, {1, 0}));
-        auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
-            cvtOp.getLoc(), tmpType, cvtOp.getOperand());
-        auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
-            cvtOp.getLoc(), dstType, tmp);
-        cvtOp.replaceAllUsesWith(newConvert.getResult());
-      }
-    });
-
-    op->walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
-      OpBuilder builder(cvtOp);
-      auto arg = cvtOp.getOperand();
-      auto retType = cvtOp.getResult().getType().cast<RankedTensorType>();
-      auto retEncoding =
-          retType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
-      auto argType = arg.getType().cast<RankedTensorType>();
-      auto argEncoding =
-          argType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
-      if (!argEncoding || !retEncoding)
-        return;
-      auto opIdx = retEncoding.getOpIdx();
-      // compute new swizzled encoding
-      auto parentEncoding =
-          retEncoding.getParent().dyn_cast<triton::gpu::MmaEncodingAttr>();
-      if (!parentEncoding)
-        return;
-      auto swizzleType = argType;
-      if (arg.getDefiningOp() &&
-          isa<tensor::ExtractSliceOp>(arg.getDefiningOp())) {
-        swizzleType = arg.getDefiningOp()
-                          ->getOperand(0)
-                          .getType()
-                          .cast<RankedTensorType>();
-      }
-      SwizzleInfo swizzle = getSwizzleMMA(opIdx, parentEncoding, swizzleType);
-      auto newEncoding = triton::gpu::SharedEncodingAttr::get(
-          &getContext(), swizzle.vec, swizzle.perPhase, swizzle.maxPhase,
-          argEncoding.getOrder());
-      // create conversion
-      auto newType = RankedTensorType::get(
-          argType.getShape(), argType.getElementType(), newEncoding);
-      Operation *newArg = builder.create<triton::gpu::ConvertLayoutOp>(
-          cvtOp.getLoc(), newType, arg);
-      // bind new op to cvt operand
-      cvtOp->replaceUsesOfWith(arg, newArg->getResult(0));
-    });
-  }
-};
-} // anonymous namespace
-
-std::unique_ptr<Pass> mlir::createTritonGPUSwizzlePass() {
-  return std::make_unique<SwizzlePass>();
-}
@@ -1264,10 +1264,6 @@ void init_triton_ir(py::module &&m) {
            [](mlir::PassManager &self) {
              self.addPass(mlir::createTritonGPUCombineOpsPass());
            })
-      .def("add_triton_gpu_swizzle_pass",
-           [](mlir::PassManager &self) {
-             self.addPass(mlir::createTritonGPUSwizzlePass());
-           })
      .def("add_triton_gpu_to_llvm",
           [](mlir::PassManager &self) {
             self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
@@ -885,7 +885,6 @@ def ttir_to_ttgir(mod, num_warps, num_stages):
     pm.add_coalesce_pass()
     pm.add_triton_gpu_combine_pass()
     pm.add_licm_pass()
-    pm.add_triton_gpu_swizzle_pass()
     pm.add_triton_gpu_combine_pass()
     pm.add_cse_pass()
     pm.run(mod)
@@ -1,90 +0,0 @@
-// RUN: triton-opt %s -split-input-file -tritongpu-swizzle | FileCheck %s
-
-#shared = #triton_gpu.shared<{vec=1, perPhase=1, maxPhase=1 ,order = [1, 0]}>
-#mma1w = #triton_gpu.mma<{version=2, warpsPerCTA=[1, 1]}>
-#mma2w = #triton_gpu.mma<{version=2, warpsPerCTA=[1, 2]}>
-#mma4w = #triton_gpu.mma<{version=2, warpsPerCTA=[2, 2]}>
-#mma8w = #triton_gpu.mma<{version=2, warpsPerCTA=[2, 4]}>
-
-// CHECK: [[shared_v8p1m8:#.*]] = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
-// CHECK: [[shared_v8p2m4:#.*]] = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
-// CHECK: [[shared_v8p4m2:#.*]] = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0]}>
-
-#shared2 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
-#shared3 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0]}>
-
-#mma1w_op0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma1w}>
-#mma1w_op1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma1w}>
-#mma2w_op0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma2w}>
-#mma2w_op1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma2w}>
-#mma4w_op0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma4w}>
-#mma4w_op1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma4w}>
-#mma8w_op0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma8w}>
-#mma8w_op1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma8w}>
-
-
-module attributes {"triton_gpu.num-warps" = 8 : i32} {
-  // CHECK-LABEL: swizzle_mma_f16_128x256x64_w8
-  func @swizzle_mma_f16_128x256x64_w8(%A_SMEM: tensor<128x64xf16, #shared>, %B_SMEM: tensor<64x256xf16, #shared>) {
-    %cst0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma8w>
-    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<128x64xf16, {{.*}}>) -> tensor<128x64xf16, [[shared_v8p1m8]]>
-    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<64x256xf16, {{.*}}>) -> tensor<64x256xf16, [[shared_v8p1m8]]>
-    %A = triton_gpu.convert_layout %A_SMEM : (tensor<128x64xf16, #shared>) -> tensor<128x64xf16, #mma8w_op0>
-    %B = triton_gpu.convert_layout %B_SMEM : (tensor<64x256xf16, #shared>) -> tensor<64x256xf16, #mma8w_op1>
-    %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<128x64xf16, #mma8w_op0> * tensor<64x256xf16, #mma8w_op1> -> tensor<128x256xf32, #mma8w>
-    return
-  }
-}
-
-module attributes {"triton_gpu.num-warps" = 4 : i32} {
-  // CHECK-LABEL: swizzle_mma_f16_128x128x64_w4
-  func @swizzle_mma_f16_128x128x64_w4(%A_SMEM: tensor<128x64xf16, #shared>, %B_SMEM: tensor<64x128xf16, #shared>) {
-    %cst0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma4w>
-    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<128x64xf16, {{.*}}>) -> tensor<128x64xf16, [[shared_v8p1m8]]>
-    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<64x128xf16, {{.*}}>) -> tensor<64x128xf16, [[shared_v8p1m8]]>
-    %A = triton_gpu.convert_layout %A_SMEM : (tensor<128x64xf16, #shared>) -> tensor<128x64xf16, #mma4w_op0>
-    %B = triton_gpu.convert_layout %B_SMEM : (tensor<64x128xf16, #shared>) -> tensor<64x128xf16, #mma4w_op1>
-    %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<128x64xf16, #mma4w_op0> * tensor<64x128xf16, #mma4w_op1> -> tensor<128x128xf32, #mma4w>
-    return
-  }
-}
-
-module attributes {"triton_gpu.num-warps" = 4 : i32} {
-  // CHECK-LABEL: swizzle_mma_f16_128x128x32_w4
-  func @swizzle_mma_f16_128x128x32_w4(%A_SMEM: tensor<128x32xf16, #shared>, %B_SMEM: tensor<32x128xf16, #shared>) {
-    %cst0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma4w>
-    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<128x32xf16, {{.*}}>) -> tensor<128x32xf16, [[shared_v8p2m4]]>
-    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<32x128xf16, {{.*}}>) -> tensor<32x128xf16, [[shared_v8p1m8]]>
-    %A = triton_gpu.convert_layout %A_SMEM : (tensor<128x32xf16, #shared>) -> tensor<128x32xf16, #mma4w_op0>
-    %B = triton_gpu.convert_layout %B_SMEM : (tensor<32x128xf16, #shared>) -> tensor<32x128xf16, #mma4w_op1>
-    %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #mma4w_op0> * tensor<32x128xf16, #mma4w_op1> -> tensor<128x128xf32, #mma4w>
-    return
-  }
-}
-
-module attributes {"triton_gpu.num-warps" = 2 : i32} {
-  // CHECK-LABEL: swizzle_mma_f16_32x32x32_w2
-  func @swizzle_mma_f16_32x32x32_w2(%A_SMEM: tensor<32x32xf16, #shared>, %B_SMEM: tensor<32x32xf16, #shared>) {
-    %cst0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma2w>
-    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<32x32xf16, {{.*}}>) -> tensor<32x32xf16, [[shared_v8p2m4]]>
-    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<32x32xf16, {{.*}}>) -> tensor<32x32xf16, [[shared_v8p2m4]]>
-    %A = triton_gpu.convert_layout %A_SMEM : (tensor<32x32xf16, #shared>) -> tensor<32x32xf16, #mma2w_op0>
-    %B = triton_gpu.convert_layout %B_SMEM : (tensor<32x32xf16, #shared>) -> tensor<32x32xf16, #mma2w_op1>
-    %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<32x32xf16, #mma2w_op0> * tensor<32x32xf16, #mma2w_op1> -> tensor<32x32xf32, #mma2w>
-    return
-  }
-}
-
-module attributes {"triton_gpu.num-warps" = 1 : i32} {
-  // CHECK-LABEL: swizzle_mma_f16_16x16x16_w1
-  func @swizzle_mma_f16_16x16x16_w1(%A_SMEM: tensor<16x16xf16, #shared>, %B_SMEM: tensor<16x16xf16, #shared>) {
-    %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma1w>
-    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<16x16xf16, {{.*}}>) -> tensor<16x16xf16, [[shared_v8p4m2]]>
-    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<16x16xf16, {{.*}}>) -> tensor<16x16xf16, [[shared_v8p4m2]]>
-    %A = triton_gpu.convert_layout %A_SMEM : (tensor<16x16xf16, #shared>) -> tensor<16x16xf16, #mma1w_op0>
-    %B = triton_gpu.convert_layout %B_SMEM : (tensor<16x16xf16, #shared>) -> tensor<16x16xf16, #mma1w_op1>
-    %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #mma1w_op0> * tensor<16x16xf16, #mma1w_op1> -> tensor<16x16xf32, #mma1w>
-    return
-  }
-}
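The FileCheck coverage deleted above is not lost: the parameterized Google Test added below asserts the same {vec, perPhase, maxPhase} triples (8/1/8, 8/2/4, 8/4/2) directly on SharedEncodingAttr, without round-tripping through triton-opt.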
@@ -26,3 +26,4 @@ endfunction()
 
 add_subdirectory(Analysis)
 add_subdirectory(Conversion)
+add_subdirectory(Dialect)
unittest/Dialect/CMakeLists.txt (new file)
@@ -0,0 +1 @@
+add_subdirectory(TritonGPU)
unittest/Dialect/TritonGPU/CMakeLists.txt (new file)
@@ -0,0 +1,6 @@
+
+add_triton_ut(
+  NAME TestSwizzling
+  SRCS SwizzleTest.cpp
+  LIBS TritonGPUIR ${dialect_libs} ${conversion_libs}
+)
unittest/Dialect/TritonGPU/SwizzleTest.cpp (new file)
@@ -0,0 +1,52 @@
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include <gtest/gtest.h>
+
+using namespace mlir;
+using mlir::triton::gpu::SharedEncodingAttr;
+
+struct swizzleParams {
+  int vec;
+  int perPhase;
+  int maxPhase;
+};
+
+struct ParamT {
+  std::array<int64_t, 2> shape;
+  int opIdx;
+  int typeWidth;
+  swizzleParams refSwizzle;
+};
+
+class SwizzleDotOperandTestFixture : public ::testing::TestWithParam<ParamT> {
+protected:
+  ParamType param;
+};
+
+TEST_P(SwizzleDotOperandTestFixture, DotOperands) {
+  auto params = GetParam();
+  // init context
+  MLIRContext ctx;
+  ctx.loadDialect<triton::gpu::TritonGPUDialect>();
+  // create encoding
+  auto parent = triton::gpu::MmaEncodingAttr::get(&ctx, 2, {1, 1});
+  auto encoding =
+      triton::gpu::DotOperandEncodingAttr::get(&ctx, params.opIdx, parent);
+
+  // create element type
+  Type eltType = IntegerType::get(&ctx, params.typeWidth);
+  auto layout = SharedEncodingAttr::get(&ctx, encoding, params.shape, eltType);
+
+  ASSERT_EQ(layout.getVec(), params.refSwizzle.vec);
+  ASSERT_EQ(layout.getPerPhase(), params.refSwizzle.perPhase);
+  ASSERT_EQ(layout.getMaxPhase(), params.refSwizzle.maxPhase);
+}
+
+INSTANTIATE_TEST_SUITE_P(TestDotOperands, SwizzleDotOperandTestFixture,
+                         ::testing::Values(ParamT{{128, 64}, 0, 16, {8, 1, 8}},
+                                           ParamT{{64, 256}, 1, 16, {8, 1, 8}},
+                                           ParamT{{128, 32}, 0, 16, {8, 2, 4}},
+                                           ParamT{{32, 128}, 1, 16, {8, 1, 8}},
+                                           ParamT{{32, 32}, 0, 16, {8, 2, 4}},
+                                           ParamT{{32, 32}, 1, 16, {8, 2, 4}},
+                                           ParamT{{16, 16}, 0, 16, {8, 4, 2}},
+                                           ParamT{{16, 16}, 1, 16, {8, 4, 2}}));
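Assuming the repository's usual gtest setup, the add_triton_ut call above declares a TestSwizzling target, so the suite can be built and run standalone (or through ctest) without any triton-opt/FileCheck tooling.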