[Triton-MLIR][OPTIMIZER] Cleaned up swizzling (#869)
Swizzling is no longer implemented as a separate pass. It is instead done in a specialized constructor of SharedEncodingAttr, and tested via googletest instead of triton-opt + FileCheck. In the future we may want to implement it as a pass again, once we have an additional dialect between TritonGPU and LLVM.
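For reference, a minimal sketch of how the new constructor is meant to be invoked (mirroring the unit test added below; the context setup and the f16 element type are assumptions for illustration, and the TritonGPU dialect headers are assumed to be included):

    // Build a swizzled shared layout directly from a dot-operand encoding,
    // instead of running the old -tritongpu-swizzle pass.
    MLIRContext ctx;
    ctx.loadDialect<triton::gpu::TritonGPUDialect>();
    auto parent = triton::gpu::MmaEncodingAttr::get(&ctx, /*version=*/2,
                                                    /*warpsPerCTA=*/{1, 1});
    auto dotOpEnc =
        triton::gpu::DotOperandEncodingAttr::get(&ctx, /*opIdx=*/0, parent);
    // The builder derives vec/perPhase/maxPhase from shape and element type;
    // for a 128x64 f16 A operand it produces vec = 8, perPhase = 1, maxPhase = 8.
    auto swizzled = triton::gpu::SharedEncodingAttr::get(
        &ctx, dotOpEnc, /*shape=*/{128, 64}, FloatType::getF16(&ctx));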
@@ -71,6 +71,70 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
    ArrayRefParameter<"unsigned", "order of axes by the rate of changing">:$order
  );

  let builders = [
    AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
                     "ArrayRef<int64_t>":$shape,
                     "Type":$eltTy), [{
      auto mmaEnc = dotOpEnc.getParent().dyn_cast<MmaEncodingAttr>();
      // Only support row major for now
      // TODO(Keren): check why column major code crashes
      SmallVector<unsigned> order = {1, 0};

      if (!mmaEnc)
        return $_get(context, 1, 1, 1, order);

      int version = mmaEnc.getVersion();
      int opIdx = dotOpEnc.getOpIdx();

      // number of rows per phase
      int perPhase = 128 / (shape[order[0]] * (eltTy.getIntOrFloatBitWidth() / 8));
      perPhase = std::max<int>(perPhase, 1);

      // index of the inner dimension in `order`
      unsigned inner = (opIdx == 0) ? 0 : 1;

      // ---- begin version 1 ----
      // TODO: handle rep (see
      // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L209)
      if (version == 1) {
        int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
        return $_get(context, 1, perPhase, maxPhase, order);
      }

      // ---- begin version 2 ----
      if (version == 2) {
        std::vector<size_t> matShape = {8, 8,
                                        2 * 64 / eltTy.getIntOrFloatBitWidth()};
        // for now, disable swizzle when using transposed int8 tensor cores
        if (eltTy.isInteger(8) && order[0] == inner)
          return $_get(context, 1, 1, 1, order);

        // --- handle A operand ---
        if (opIdx == 0) { // compute swizzling for A operand
          int vec = (order[0] == 1) ? matShape[2] : matShape[0]; // k : m
          int mmaStride = (order[0] == 1) ? matShape[0] : matShape[2];
          int maxPhase = mmaStride / perPhase;
          return $_get(context, vec, perPhase, maxPhase, order);
        }

        // --- handle B operand ---
        if (opIdx == 1) {
          int vec = (order[0] == 1) ? matShape[1] : matShape[2]; // n : k
          int mmaStride = (order[0] == 1) ? matShape[2] : matShape[1];
          int maxPhase = mmaStride / perPhase;
          return $_get(context, vec, perPhase, maxPhase, order);
        }

        llvm_unreachable("invalid operand index");
      }

      // ---- not implemented ----
      llvm_unreachable("unsupported swizzling for provided MMA version");
    }]>
  ];

  let extraClassDeclaration = extraBaseClassDeclaration;
}
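To make the arithmetic concrete, here is the version-2 path traced for a 128x32 f16 A operand (a case exercised by the unit test added below): order = {1, 0}, so shape[order[0]] = 32 and perPhase = max(1, 128 / (32 * 2)) = 2; matShape = {8, 8, 2 * 64 / 16} = {8, 8, 8}; with opIdx = 0 and order[0] == 1, vec = matShape[2] = 8 and mmaStride = matShape[0] = 8, so maxPhase = 8 / 2 = 4. The resulting (vec, perPhase, maxPhase) = (8, 2, 4) matches the reference values in the new SwizzleTest.cpp.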
@@ -11,8 +11,6 @@ std::unique_ptr<Pass> createTritonGPUPrefetchPass();

std::unique_ptr<Pass> createTritonGPUCanonicalizeLoopsPass();

std::unique_ptr<Pass> createTritonGPUSwizzlePass();

std::unique_ptr<Pass> createTritonGPUCoalescePass();

std::unique_ptr<Pass> createTritonGPUCombineOpsPass();
@@ -65,18 +65,6 @@ def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
                             "mlir::triton::TritonDialect"];
}

def TritonGPUSwizzle : Pass<"tritongpu-swizzle", "mlir::ModuleOp"> {
  let summary = "swizzle";

  let description = [{
    Inserts conversions to swizzled layout so as to avoid shared memory bank conflicts.
  }];

  let constructor = "mlir::createTritonGPUSwizzlePass()";

  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
}

def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> {
  let summary = "canonicalize scf.ForOp ops";
@@ -2707,6 +2707,7 @@ public:
      return lowerDistributedToDistributed(op, adaptor, rewriter);
    }
    // TODO: to be implemented
    llvm_unreachable("unsupported layout conversion");
    return failure();
  }
@@ -5763,6 +5764,35 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,

class ConvertTritonGPUToLLVM
    : public ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {

private:
  void decomposeBlockedToDotOperand(ModuleOp mod) {
    // replace `blocked -> dot_op` with `blocked -> shared -> dot_op`
    // because the codegen doesn't handle `blocked -> dot_op` directly
    mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
      OpBuilder builder(cvtOp);
      auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
      auto dstType = cvtOp.getType().cast<RankedTensorType>();
      auto srcBlocked =
          srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
      auto dstDotOp =
          dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
      if (srcBlocked && dstDotOp) {
        auto tmpType = RankedTensorType::get(
            dstType.getShape(), dstType.getElementType(),
            triton::gpu::SharedEncodingAttr::get(mod.getContext(), dstDotOp,
                                                 srcType.getShape(),
                                                 srcType.getElementType()));
        auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
            cvtOp.getLoc(), tmpType, cvtOp.getOperand());
        auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
            cvtOp.getLoc(), dstType, tmp);
        cvtOp.replaceAllUsesWith(newConvert.getResult());
        cvtOp.erase();
      }
    });
  }

public:
  ConvertTritonGPUToLLVM() = default;
@@ -5779,15 +5809,19 @@ public:

    int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);

    // step 1: Allocate shared memories and insert barriers
    // step 2: Convert SCF to CFG
    // step 3: Convert FuncOp to LLVMFuncOp via partial conversion
    // step 4: Convert the rest of ops via partial conversion
    // step 1: Decompose unoptimized layout conversions to use shared memory
    // step 2: Allocate shared memories and insert barriers
    // step 3: Convert SCF to CFG
    // step 4: Convert FuncOp to LLVMFuncOp via partial conversion
    // step 5: Convert the rest of ops via partial conversion
    // The reason for putting step 1 before step 2 is that the membar analysis
    // currently only supports SCF but not CFG.
    // The reason for a separation between 1/4 is that step 3 is out of
    // the scope of Dialect Conversion; thus we need to make sure the smem
    // is not revised during the conversion of step 4.

    decomposeBlockedToDotOperand(mod);

    Allocation allocation(mod);
    MembarAnalysis membar(&allocation);
@@ -8,7 +8,6 @@ add_mlir_dialect_library(TritonGPUTransforms
  Combine.cpp
  Pipeline.cpp
  Prefetch.cpp
  Swizzle.cpp
  TritonGPUConversion.cpp

  DEPENDS
@@ -74,10 +74,6 @@ class LoopPipeliner {
  /// returns an empty buffer of size <numStages, ...>
  ttg::AllocTensorOp allocateEmptyBuffer(Operation *op, OpBuilder &builder);

  /// compute type of shared buffers (with swizzled shared layouts)
  RankedTensorType getSwizzleType(ttg::DotOperandEncodingAttr dotOpEnc,
                                  RankedTensorType tensorType);

public:
  LoopPipeliner(scf::ForOp forOp, int numStages)
      : forOp(forOp), numStages(numStages) {
@@ -148,70 +144,6 @@ ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op,
  llvm_unreachable("Async copy's return should be of RankedTensorType");
}

// TODO: I copied the code from Swizzle.cpp. Should find a way to unify the
// code path.
// Swizzle has to be performed before pipelining for now. If we do swizzle
// after pipelining, we need to propagate the swizzled layout to all
// operands that are aliases of the swizzled tensor. The alias analysis
// component may be helpful for this purpose.
RankedTensorType
LoopPipeliner::getSwizzleType(ttg::DotOperandEncodingAttr dotOpEnc,
                              RankedTensorType ty) {
  int opIdx = dotOpEnc.getOpIdx();
  int vec = 1;
  int maxPhase = 1;
  int perPhase = 1;
  llvm::SmallVector<unsigned> order;
  if (auto mmaEnc = dotOpEnc.getParent().dyn_cast<ttg::MmaEncodingAttr>()) {
    // Only support row major for now
    // TODO(Keren): check why column major code crashes
    order = {1, 0};
    int version = mmaEnc.getVersion();
    auto tyEncoding = ty.getEncoding().cast<ttg::BlockedEncodingAttr>();
    // number of rows per phase
    perPhase = 128 / (ty.getShape()[order[0]] *
                      (ty.getElementType().getIntOrFloatBitWidth() / 8));
    perPhase = std::max<int>(perPhase, 1);

    // index of the inner dimension in `order`
    unsigned inner = (opIdx == 0) ? 0 : 1;
    if (version == 1) {
      maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
      // TODO: handle rep (see
      // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L209)
    } else if (version == 2) {
      auto eltTy = ty.getElementType();
      std::vector<size_t> matShape = {8, 8,
                                      2 * 64 / eltTy.getIntOrFloatBitWidth()};
      // for now, disable swizzle when using transposed int8 tensor cores
      if (ty.getElementType().isInteger(8) && order[0] == inner)
        perPhase = 1;
      else {
        if (opIdx == 0) { // compute swizzling for A operand
          vec = order[0] == 1 ? matShape[2] : matShape[0]; // k : m
          int mmaStride = order[0] == 1 ? matShape[0] : matShape[2];
          maxPhase = mmaStride / perPhase;
        } else if (opIdx == 1) { // compute swizzling for B operand
          vec = order[0] == 1 ? matShape[1] : matShape[2]; // n : k
          int mmaStride = order[0] == 1 ? matShape[2] : matShape[1];
          maxPhase = mmaStride / perPhase;
        } else
          llvm_unreachable("invalid operand index");
      }
    } else // version not in [1, 2]
      llvm_unreachable("unsupported swizzling for provided MMA version");
  } else { // If the layout of dot is not mma, we don't need to swizzle
    auto blockedEnc = dotOpEnc.getParent().cast<ttg::BlockedEncodingAttr>();
    order = llvm::SmallVector<unsigned>(blockedEnc.getOrder().begin(),
                                        blockedEnc.getOrder().end());
  }
  auto newEncoding = ttg::SharedEncodingAttr::get(ty.getContext(), vec,
                                                  perPhase, maxPhase, order);
  SmallVector<int64_t> bufferShape(ty.getShape().begin(), ty.getShape().end());
  bufferShape.insert(bufferShape.begin(), numStages);
  return RankedTensorType::get(bufferShape, ty.getElementType(), newEncoding);
}

/// A load instruction can be pipelined if:
///   - the load doesn't depend on any other loads (after loop peeling)
///   - (?) this load is not a loop-invariant value (we should run LICM before
@@ -264,8 +196,14 @@ LogicalResult LoopPipeliner::initialize() {
                 .dyn_cast<ttg::DotOperandEncodingAttr>()) {
          isCandiate = true;
          loadsMapping[loadOp] = convertLayout;
          loadsBufferType[loadOp] = getSwizzleType(
              dotOpEnc, loadOp.getType().cast<RankedTensorType>());
          auto ty = loadOp.getType().cast<RankedTensorType>();
          SmallVector<int64_t> bufferShape(ty.getShape().begin(),
                                           ty.getShape().end());
          bufferShape.insert(bufferShape.begin(), numStages);
          auto sharedEnc = ttg::SharedEncodingAttr::get(
              ty.getContext(), dotOpEnc, ty.getShape(), ty.getElementType());
          loadsBufferType[loadOp] = RankedTensorType::get(
              bufferShape, ty.getElementType(), sharedEnc);
        }
      }
    }
@@ -1,134 +0,0 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::triton;

#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

namespace {

struct SwizzlePass : public TritonGPUSwizzleBase<SwizzlePass> {
  SwizzlePass() = default;

  struct SwizzleInfo {
    int vec;
    int perPhase;
    int maxPhase;
  };

  SwizzleInfo getSwizzleMMA(int opIdx, triton::gpu::MmaEncodingAttr retEncoding,
                            RankedTensorType ty) {
    SwizzleInfo noSwizzling = {1, 1, 1};
    int version = retEncoding.getVersion();
    auto tyEncoding = ty.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
    auto order = tyEncoding.getOrder();
    // number of rows per phase
    int perPhase = 128 / (ty.getShape()[order[0]] *
                          (ty.getElementType().getIntOrFloatBitWidth() / 8));
    perPhase = std::max<int>(perPhase, 1);
    // index of the inner dimension in `order`
    size_t inner = (opIdx == 0) ? 0 : 1;
    if (version == 1) {
      int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
      // TODO: handle rep (see
      // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L209)
      int vec = 1;
      return SwizzleInfo{vec, perPhase, maxPhase};
    } else if (version == 2) {
      auto eltTy = ty.getElementType();
      std::vector<size_t> matShape = {8, 8,
                                      2 * 64 / eltTy.getIntOrFloatBitWidth()};
      // for now, disable swizzle when using transposed int8 tensor cores
      bool isInt8Mma = ty.getElementType().isInteger(8);
      if (isInt8Mma && order[0] == inner)
        return noSwizzling;
      // compute swizzling for A operand
      if (opIdx == 0) {
        int vec = order[0] == 1 ? matShape[2] : matShape[0]; // k : m
        int mmaStride = order[0] == 1 ? matShape[0] : matShape[2];
        int maxPhase = mmaStride / perPhase;
        return SwizzleInfo{vec, perPhase, maxPhase};
      }
      // compute swizzling for B operand
      else if (opIdx == 1) {
        int vec = order[0] == 1 ? matShape[1] : matShape[2]; // n : k
        int mmaStride = order[0] == 1 ? matShape[2] : matShape[1];
        int maxPhase = mmaStride / perPhase;
        return SwizzleInfo{vec, perPhase, maxPhase};
      } else {
        llvm_unreachable("invalid operand index");
      }
    } else
      llvm_unreachable("unsupported swizzling for provided MMA version");
  }

  void runOnOperation() override {
    Operation *op = getOperation();
    // replace blocked -> dot_op with
    // blocked -> shared -> dot_op in order to
    // expose opportunities for swizzling
    op->walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
      OpBuilder builder(cvtOp);
      auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
      auto dstType = cvtOp.getType().cast<RankedTensorType>();
      if (srcType.getEncoding().isa<triton::gpu::BlockedEncodingAttr>() &&
          dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>()) {
        auto tmpType =
            RankedTensorType::get(dstType.getShape(), dstType.getElementType(),
                                  triton::gpu::SharedEncodingAttr::get(
                                      op->getContext(), 1, 1, 1, {1, 0}));
        auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
            cvtOp.getLoc(), tmpType, cvtOp.getOperand());
        auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
            cvtOp.getLoc(), dstType, tmp);
        cvtOp.replaceAllUsesWith(newConvert.getResult());
      }
    });

    op->walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
      OpBuilder builder(cvtOp);
      auto arg = cvtOp.getOperand();
      auto retType = cvtOp.getResult().getType().cast<RankedTensorType>();
      auto retEncoding =
          retType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
      auto argType = arg.getType().cast<RankedTensorType>();
      auto argEncoding =
          argType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
      if (!argEncoding || !retEncoding)
        return;
      auto opIdx = retEncoding.getOpIdx();
      // compute new swizzled encoding
      auto parentEncoding =
          retEncoding.getParent().dyn_cast<triton::gpu::MmaEncodingAttr>();
      if (!parentEncoding)
        return;
      auto swizzleType = argType;
      if (arg.getDefiningOp() &&
          isa<tensor::ExtractSliceOp>(arg.getDefiningOp())) {
        swizzleType = arg.getDefiningOp()
                          ->getOperand(0)
                          .getType()
                          .cast<RankedTensorType>();
      }
      SwizzleInfo swizzle = getSwizzleMMA(opIdx, parentEncoding, swizzleType);
      auto newEncoding = triton::gpu::SharedEncodingAttr::get(
          &getContext(), swizzle.vec, swizzle.perPhase, swizzle.maxPhase,
          argEncoding.getOrder());
      // create conversion
      auto newType = RankedTensorType::get(
          argType.getShape(), argType.getElementType(), newEncoding);
      Operation *newArg = builder.create<triton::gpu::ConvertLayoutOp>(
          cvtOp.getLoc(), newType, arg);
      // bind new op to cvt operand
      cvtOp->replaceUsesOfWith(arg, newArg->getResult(0));
    });
  }
};
} // anonymous namespace

std::unique_ptr<Pass> mlir::createTritonGPUSwizzlePass() {
  return std::make_unique<SwizzlePass>();
}
@@ -1264,10 +1264,6 @@ void init_triton_ir(py::module &&m) {
           [](mlir::PassManager &self) {
             self.addPass(mlir::createTritonGPUCombineOpsPass());
           })
      .def("add_triton_gpu_swizzle_pass",
           [](mlir::PassManager &self) {
             self.addPass(mlir::createTritonGPUSwizzlePass());
           })
      .def("add_triton_gpu_to_llvm",
           [](mlir::PassManager &self) {
             self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
@@ -885,7 +885,6 @@ def ttir_to_ttgir(mod, num_warps, num_stages):
    pm.add_coalesce_pass()
    pm.add_triton_gpu_combine_pass()
    pm.add_licm_pass()
    pm.add_triton_gpu_swizzle_pass()
    pm.add_triton_gpu_combine_pass()
    pm.add_cse_pass()
    pm.run(mod)
@@ -1,90 +0,0 @@
// RUN: triton-opt %s -split-input-file -tritongpu-swizzle | FileCheck %s

#shared = #triton_gpu.shared<{vec=1, perPhase=1, maxPhase=1, order = [1, 0]}>
#mma1w = #triton_gpu.mma<{version=2, warpsPerCTA=[1, 1]}>
#mma2w = #triton_gpu.mma<{version=2, warpsPerCTA=[1, 2]}>
#mma4w = #triton_gpu.mma<{version=2, warpsPerCTA=[2, 2]}>
#mma8w = #triton_gpu.mma<{version=2, warpsPerCTA=[2, 4]}>

// CHECK: [[shared_v8p1m8:#.*]] = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
// CHECK: [[shared_v8p2m4:#.*]] = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
// CHECK: [[shared_v8p4m2:#.*]] = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0]}>

#shared2 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#shared3 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0]}>

#mma1w_op0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma1w}>
#mma1w_op1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma1w}>
#mma2w_op0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma2w}>
#mma2w_op1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma2w}>
#mma4w_op0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma4w}>
#mma4w_op1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma4w}>
#mma8w_op0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma8w}>
#mma8w_op1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma8w}>

module attributes {"triton_gpu.num-warps" = 8 : i32} {
  // CHECK-LABEL: swizzle_mma_f16_128x256x64_w8
  func @swizzle_mma_f16_128x256x64_w8(%A_SMEM: tensor<128x64xf16, #shared>, %B_SMEM: tensor<64x256xf16, #shared>) {
    %cst0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma8w>
    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<128x64xf16, {{.*}}>) -> tensor<128x64xf16, [[shared_v8p1m8]]>
    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<64x256xf16, {{.*}}>) -> tensor<64x256xf16, [[shared_v8p1m8]]>
    %A = triton_gpu.convert_layout %A_SMEM : (tensor<128x64xf16, #shared>) -> tensor<128x64xf16, #mma8w_op0>
    %B = triton_gpu.convert_layout %B_SMEM : (tensor<64x256xf16, #shared>) -> tensor<64x256xf16, #mma8w_op1>
    %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<128x64xf16, #mma8w_op0> * tensor<64x256xf16, #mma8w_op1> -> tensor<128x256xf32, #mma8w>
    return
  }
}

module attributes {"triton_gpu.num-warps" = 4 : i32} {
  // CHECK-LABEL: swizzle_mma_f16_128x128x64_w4
  func @swizzle_mma_f16_128x128x64_w4(%A_SMEM: tensor<128x64xf16, #shared>, %B_SMEM: tensor<64x128xf16, #shared>) {
    %cst0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma4w>
    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<128x64xf16, {{.*}}>) -> tensor<128x64xf16, [[shared_v8p1m8]]>
    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<64x128xf16, {{.*}}>) -> tensor<64x128xf16, [[shared_v8p1m8]]>
    %A = triton_gpu.convert_layout %A_SMEM : (tensor<128x64xf16, #shared>) -> tensor<128x64xf16, #mma4w_op0>
    %B = triton_gpu.convert_layout %B_SMEM : (tensor<64x128xf16, #shared>) -> tensor<64x128xf16, #mma4w_op1>
    %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<128x64xf16, #mma4w_op0> * tensor<64x128xf16, #mma4w_op1> -> tensor<128x128xf32, #mma4w>
    return
  }
}

module attributes {"triton_gpu.num-warps" = 4 : i32} {
  // CHECK-LABEL: swizzle_mma_f16_128x128x32_w4
  func @swizzle_mma_f16_128x128x32_w4(%A_SMEM: tensor<128x32xf16, #shared>, %B_SMEM: tensor<32x128xf16, #shared>) {
    %cst0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma4w>
    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<128x32xf16, {{.*}}>) -> tensor<128x32xf16, [[shared_v8p2m4]]>
    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<32x128xf16, {{.*}}>) -> tensor<32x128xf16, [[shared_v8p1m8]]>
    %A = triton_gpu.convert_layout %A_SMEM : (tensor<128x32xf16, #shared>) -> tensor<128x32xf16, #mma4w_op0>
    %B = triton_gpu.convert_layout %B_SMEM : (tensor<32x128xf16, #shared>) -> tensor<32x128xf16, #mma4w_op1>
    %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #mma4w_op0> * tensor<32x128xf16, #mma4w_op1> -> tensor<128x128xf32, #mma4w>
    return
  }
}

module attributes {"triton_gpu.num-warps" = 2 : i32} {
  // CHECK-LABEL: swizzle_mma_f16_32x32x32_w2
  func @swizzle_mma_f16_32x32x32_w2(%A_SMEM: tensor<32x32xf16, #shared>, %B_SMEM: tensor<32x32xf16, #shared>) {
    %cst0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma2w>
    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<32x32xf16, {{.*}}>) -> tensor<32x32xf16, [[shared_v8p2m4]]>
    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<32x32xf16, {{.*}}>) -> tensor<32x32xf16, [[shared_v8p2m4]]>
    %A = triton_gpu.convert_layout %A_SMEM : (tensor<32x32xf16, #shared>) -> tensor<32x32xf16, #mma2w_op0>
    %B = triton_gpu.convert_layout %B_SMEM : (tensor<32x32xf16, #shared>) -> tensor<32x32xf16, #mma2w_op1>
    %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<32x32xf16, #mma2w_op0> * tensor<32x32xf16, #mma2w_op1> -> tensor<32x32xf32, #mma2w>
    return
  }
}

module attributes {"triton_gpu.num-warps" = 1 : i32} {
  // CHECK-LABEL: swizzle_mma_f16_16x16x16_w1
  func @swizzle_mma_f16_16x16x16_w1(%A_SMEM: tensor<16x16xf16, #shared>, %B_SMEM: tensor<16x16xf16, #shared>) {
    %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma1w>
    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<16x16xf16, {{.*}}>) -> tensor<16x16xf16, [[shared_v8p4m2]]>
    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<16x16xf16, {{.*}}>) -> tensor<16x16xf16, [[shared_v8p4m2]]>
    %A = triton_gpu.convert_layout %A_SMEM : (tensor<16x16xf16, #shared>) -> tensor<16x16xf16, #mma1w_op0>
    %B = triton_gpu.convert_layout %B_SMEM : (tensor<16x16xf16, #shared>) -> tensor<16x16xf16, #mma1w_op1>
    %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #mma1w_op0> * tensor<16x16xf16, #mma1w_op1> -> tensor<16x16xf32, #mma1w>
    return
  }
}
@@ -26,3 +26,4 @@ endfunction()

add_subdirectory(Analysis)
add_subdirectory(Conversion)
add_subdirectory(Dialect)
unittest/Dialect/CMakeLists.txt (new file, 1 line)
@@ -0,0 +1 @@
add_subdirectory(TritonGPU)
unittest/Dialect/TritonGPU/CMakeLists.txt (new file, 6 lines)
@@ -0,0 +1,6 @@

add_triton_ut(
  NAME TestSwizzling
  SRCS SwizzleTest.cpp
  LIBS TritonGPUIR ${dialect_libs} ${conversion_libs}
)
unittest/Dialect/TritonGPU/SwizzleTest.cpp (new file, 52 lines)
@@ -0,0 +1,52 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <gtest/gtest.h>

using namespace mlir;
using mlir::triton::gpu::SharedEncodingAttr;

struct swizzleParams {
  int vec;
  int perPhase;
  int maxPhase;
};

struct ParamT {
  std::array<int64_t, 2> shape;
  int opIdx;
  int typeWidth;
  swizzleParams refSwizzle;
};

class SwizzleDotOperandTestFixture : public ::testing::TestWithParam<ParamT> {
protected:
  ParamType param;
};

TEST_P(SwizzleDotOperandTestFixture, DotOperands) {
  auto params = GetParam();
  // init context
  MLIRContext ctx;
  ctx.loadDialect<triton::gpu::TritonGPUDialect>();
  // create encoding
  auto parent = triton::gpu::MmaEncodingAttr::get(&ctx, 2, {1, 1});
  auto encoding =
      triton::gpu::DotOperandEncodingAttr::get(&ctx, params.opIdx, parent);

  // create element type
  Type eltType = IntegerType::get(&ctx, params.typeWidth);
  auto layout = SharedEncodingAttr::get(&ctx, encoding, params.shape, eltType);

  ASSERT_EQ(layout.getVec(), params.refSwizzle.vec);
  ASSERT_EQ(layout.getPerPhase(), params.refSwizzle.perPhase);
  ASSERT_EQ(layout.getMaxPhase(), params.refSwizzle.maxPhase);
}

INSTANTIATE_TEST_SUITE_P(TestDotOperands, SwizzleDotOperandTestFixture,
                         ::testing::Values(ParamT{{128, 64}, 0, 16, {8, 1, 8}},
                                           ParamT{{64, 256}, 1, 16, {8, 1, 8}},
                                           ParamT{{128, 32}, 0, 16, {8, 2, 4}},
                                           ParamT{{32, 128}, 1, 16, {8, 1, 8}},
                                           ParamT{{32, 32}, 0, 16, {8, 2, 4}},
                                           ParamT{{32, 32}, 1, 16, {8, 2, 4}},
                                           ParamT{{16, 16}, 0, 16, {8, 4, 2}},
                                           ParamT{{16, 16}, 1, 16, {8, 4, 2}}));
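Each reference triple follows from the builder's formulas above: for the {64, 256} B operand (f16), perPhase = max(1, 128 / (256 * 2)) = 1 and maxPhase = matShape[2] / perPhase = 8, giving {8, 1, 8}; shrinking the inner dimension to 32 or 16 raises perPhase to 2 or 4 and lowers maxPhase to 4 or 2 accordingly.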