[OPTIMIZER] Added swizzling pass (#758)
This commit is contained in:
@@ -8,6 +8,8 @@ std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 2);
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUCanonicalizeLoopsPass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUSwizzlePass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUCoalescePass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUCombineOpsPass();
|
||||
|
@@ -51,6 +51,18 @@ def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
|
||||
"mlir::triton::TritonDialect"];
|
||||
}
|
||||
|
||||
def TritonGPUSwizzle : Pass<"tritongpu-swizzle", "mlir::ModuleOp"> {
|
||||
let summary = "swizzle";
|
||||
|
||||
let description = [{
|
||||
Inserts conversions to swizzled layout so as to avoid shared memory bank conflicts.
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUSwizzlePass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
|
||||
}
|
||||
|
||||
def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> {
|
||||
let summary = "canonicalize scf.ForOp ops";
|
||||
|
||||
|
@@ -7,6 +7,7 @@ add_mlir_dialect_library(TritonGPUTransforms
|
||||
CanonicalizeLoops.cpp
|
||||
Combine.cpp
|
||||
Pipeline.cpp
|
||||
Swizzle.cpp
|
||||
Verifier.cpp
|
||||
TritonGPUConversion.cpp
|
||||
|
||||
|
105
lib/Dialect/TritonGPU/Transforms/Swizzle.cpp
Normal file
105
lib/Dialect/TritonGPU/Transforms/Swizzle.cpp
Normal file
@@ -0,0 +1,105 @@
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
namespace {
|
||||
|
||||
struct SwizzlePass : public TritonGPUSwizzleBase<SwizzlePass> {
|
||||
SwizzlePass() = default;
|
||||
|
||||
struct SwizzleInfo {
|
||||
int vec;
|
||||
int perPhase;
|
||||
int maxPhase;
|
||||
};
|
||||
|
||||
SwizzleInfo getSwizzleMMA(int opIdx, triton::gpu::MmaEncodingAttr retEncoding,
|
||||
RankedTensorType ty) {
|
||||
SwizzleInfo noSwizzling = {1, 1, 1};
|
||||
int version = retEncoding.getVersion();
|
||||
auto tyEncoding = ty.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
|
||||
auto order = tyEncoding.getOrder();
|
||||
// number of rows per phase
|
||||
int perPhase = 128 / (ty.getShape()[order[0]] *
|
||||
(ty.getElementType().getIntOrFloatBitWidth() / 8));
|
||||
perPhase = std::max<int>(perPhase, 1);
|
||||
// index of the inner dimension in `order`
|
||||
int inner = (opIdx == 0) ? 0 : 1;
|
||||
if (version == 1) {
|
||||
int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
|
||||
// TODO: handle rep (see
|
||||
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L209)
|
||||
int vec = 1;
|
||||
return SwizzleInfo{vec, perPhase, maxPhase};
|
||||
} else if (version == 2) {
|
||||
auto eltTy = ty.getElementType();
|
||||
std::vector<size_t> mat_shape = {8, 8,
|
||||
2 * 64 / eltTy.getIntOrFloatBitWidth()};
|
||||
// for now, disable swizzle when using transposed int8 tensor cores
|
||||
bool is_int8_mma = ty.getElementType().isInteger(8);
|
||||
if (is_int8_mma && order[0] == inner)
|
||||
return noSwizzling;
|
||||
// compute swizzling for A operand
|
||||
if (opIdx == 0) {
|
||||
int vec = order[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m
|
||||
int mmaStride = order[0] == 1 ? mat_shape[0] : mat_shape[2];
|
||||
int maxPhase = mmaStride / perPhase;
|
||||
std::cout << perPhase << " " << mat_shape[0] << " " << mat_shape[1]
|
||||
<< " " << mat_shape[2] << std::endl;
|
||||
return SwizzleInfo{vec, perPhase, maxPhase};
|
||||
}
|
||||
// compute swizzling for B operand
|
||||
else if (opIdx == 1) {
|
||||
int vec = order[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k
|
||||
int mmaStride = order[0] == 1 ? mat_shape[2] : mat_shape[1];
|
||||
int maxPhase = mmaStride / perPhase;
|
||||
return SwizzleInfo{vec, perPhase, maxPhase};
|
||||
} else {
|
||||
llvm_unreachable("invalid operand index");
|
||||
}
|
||||
} else
|
||||
llvm_unreachable("unsupported swizzling for provided MMA version");
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
Operation *op = getOperation();
|
||||
MLIRContext *context = &getContext();
|
||||
op->walk([&](triton::DotOp dotOp) -> void {
|
||||
OpBuilder builder(dotOp);
|
||||
auto _retEncoding =
|
||||
dotOp.getResult().getType().cast<RankedTensorType>().getEncoding();
|
||||
auto retEncoding = _retEncoding.dyn_cast<triton::gpu::MmaEncodingAttr>();
|
||||
if (!retEncoding)
|
||||
return;
|
||||
for (int opIdx : {0, 1}) {
|
||||
Value op = dotOp.getOperand(opIdx);
|
||||
auto ty = op.getType().template cast<RankedTensorType>();
|
||||
// compute new swizzled encoding
|
||||
SwizzleInfo swizzle = getSwizzleMMA(opIdx, retEncoding, ty);
|
||||
auto newEncoding = triton::gpu::SharedEncodingAttr::get(
|
||||
&getContext(), swizzle.vec, swizzle.perPhase, swizzle.maxPhase,
|
||||
ty.getEncoding()
|
||||
.cast<triton::gpu::SharedEncodingAttr>()
|
||||
.getOrder());
|
||||
// create conversion
|
||||
auto newType = RankedTensorType::get(ty.getShape(), ty.getElementType(),
|
||||
newEncoding);
|
||||
Operation *newOp = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
op.getLoc(), newType, op);
|
||||
// bind new op to dot operand
|
||||
dotOp->replaceUsesOfWith(op, newOp->getResult(0));
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
} // anonymous namespace
|
||||
|
||||
std::unique_ptr<Pass> mlir::createTritonGPUSwizzlePass() {
|
||||
return std::make_unique<SwizzlePass>();
|
||||
}
|
71
test/TritonGPU/swizzle.mlir
Normal file
71
test/TritonGPU/swizzle.mlir
Normal file
@@ -0,0 +1,71 @@
|
||||
// RUN: triton-opt %s -split-input-file -tritongpu-swizzle | FileCheck %s
|
||||
|
||||
#shared = #triton_gpu.shared<{vec=1, perPhase=1, maxPhase=1 ,order = [1, 0]}>
|
||||
#mma1w = #triton_gpu.mma<{version=2, warpsPerCTA=[1, 1]}>
|
||||
#mma2w = #triton_gpu.mma<{version=2, warpsPerCTA=[1, 2]}>
|
||||
#mma4w = #triton_gpu.mma<{version=2, warpsPerCTA=[2, 2]}>
|
||||
#mma8w = #triton_gpu.mma<{version=2, warpsPerCTA=[2, 4]}>
|
||||
|
||||
// CHECK: [[shared_v8p1m8:#.*]] = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
|
||||
// CHECK: [[shared_v8p2m4:#.*]] = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||
// CHECK: [[shared_v8p4m2:#.*]] = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0]}>
|
||||
|
||||
#shared2 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||
#shared3 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0]}>
|
||||
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 8 : i32} {
|
||||
// CHECK-LABEL: swizzle_mma_f16_128x256x64_w8
|
||||
func @swizzle_mma_f16_128x256x64_w8(%A: tensor<128x64xf16, #shared>, %B: tensor<64x256xf16, #shared>) {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma8w>
|
||||
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<128x64xf16, {{.*}}>) -> tensor<128x64xf16, [[shared_v8p1m8]]>
|
||||
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<64x256xf16, {{.*}}>) -> tensor<64x256xf16, [[shared_v8p1m8]]>
|
||||
%D = tt.dot %A, %B, %cst0 {allowTF32 = true} : tensor<128x64xf16, #shared> * tensor<64x256xf16, #shared> -> tensor<128x256xf32, #mma8w>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: swizzle_mma_f16_128x128x64_w4
|
||||
func @swizzle_mma_f16_128x128x64_w4(%A: tensor<128x64xf16, #shared>, %B: tensor<64x128xf16, #shared>) {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma4w>
|
||||
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<128x64xf16, {{.*}}>) -> tensor<128x64xf16, [[shared_v8p1m8]]>
|
||||
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<64x128xf16, {{.*}}>) -> tensor<64x128xf16, [[shared_v8p1m8]]>
|
||||
%D = tt.dot %A, %B, %cst0 {allowTF32 = true} : tensor<128x64xf16, #shared> * tensor<64x128xf16, #shared> -> tensor<128x128xf32, #mma4w>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: swizzle_mma_f16_128x128x32_w4
|
||||
func @swizzle_mma_f16_128x128x32_w4(%A: tensor<128x32xf16, #shared>, %B: tensor<32x128xf16, #shared>) {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma4w>
|
||||
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<128x32xf16, {{.*}}>) -> tensor<128x32xf16, [[shared_v8p2m4]]>
|
||||
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<32x128xf16, {{.*}}>) -> tensor<32x128xf16, [[shared_v8p1m8]]>
|
||||
%D = tt.dot %A, %B, %cst0 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared> -> tensor<128x128xf32, #mma4w>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
// CHECK-LABEL: swizzle_mma_f16_32x32x32_w2
|
||||
func @swizzle_mma_f16_32x32x32_w2(%A: tensor<32x32xf16, #shared>, %B: tensor<32x32xf16, #shared>) {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma2w>
|
||||
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<32x32xf16, {{.*}}>) -> tensor<32x32xf16, [[shared_v8p2m4]]>
|
||||
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<32x32xf16, {{.*}}>) -> tensor<32x32xf16, [[shared_v8p2m4]]>
|
||||
%D = tt.dot %A, %B, %cst0 {allowTF32 = true} : tensor<32x32xf16, #shared> * tensor<32x32xf16, #shared> -> tensor<32x32xf32, #mma2w>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: swizzle_mma_f16_16x16x16_w1
|
||||
func @swizzle_mma_f16_16x16x16_w1(%A: tensor<16x16xf16, #shared>, %B: tensor<16x16xf16, #shared>) {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma1w>
|
||||
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<16x16xf16, {{.*}}>) -> tensor<16x16xf16, [[shared_v8p4m2]]>
|
||||
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<16x16xf16, {{.*}}>) -> tensor<16x16xf16, [[shared_v8p4m2]]>
|
||||
%D = tt.dot %A, %B, %cst0 {allowTF32 = true} : tensor<16x16xf16, #shared> * tensor<16x16xf16, #shared> -> tensor<16x16xf32, #mma1w>
|
||||
return
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user