diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.h b/include/triton/Dialect/TritonGPU/Transforms/Passes.h
index 0f3f055c2..4b3b3e8b6 100644
--- a/include/triton/Dialect/TritonGPU/Transforms/Passes.h
+++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.h
@@ -8,6 +8,8 @@ std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 2);
 
 std::unique_ptr<Pass> createTritonGPUCanonicalizeLoopsPass();
 
+std::unique_ptr<Pass> createTritonGPUSwizzlePass();
+
 std::unique_ptr<Pass> createTritonGPUCoalescePass();
 
 std::unique_ptr<Pass> createTritonGPUCombineOpsPass();
diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td
index a9d501527..c03ce486c 100644
--- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td
+++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td
@@ -51,6 +51,18 @@ def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
                            "mlir::triton::TritonDialect"];
 }
 
+def TritonGPUSwizzle : Pass<"tritongpu-swizzle", "mlir::ModuleOp"> {
+  let summary = "swizzle shared memory layouts";
+
+  let description = [{
+    Inserts conversions to swizzled layouts so as to avoid shared memory bank conflicts.
+  }];
+
+  let constructor = "mlir::createTritonGPUSwizzlePass()";
+
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
+}
+
 def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> {
   let summary = "canonicalize scf.ForOp ops";
 
diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
index a5c0898d3..9f15374ef 100644
--- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
+++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(TritonGPUTransforms
   CanonicalizeLoops.cpp
   Combine.cpp
   Pipeline.cpp
+  Swizzle.cpp
   Verifier.cpp
   TritonGPUConversion.cpp
 
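Note: the TableGen entry above registers the pass under the flag -tritongpu-swizzle, and Passes.h exposes its constructor. A minimal sketch of driving only this pass from C++ (the runSwizzle helper is illustrative, not part of this change; it assumes mod has already been lowered to the TritonGPU dialect):

    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/Pass/PassManager.h"
    #include "triton/Dialect/TritonGPU/Transforms/Passes.h"

    // Run just the new swizzle pass over an already-lowered TritonGPU module.
    mlir::LogicalResult runSwizzle(mlir::ModuleOp mod) {
      mlir::PassManager pm(mod.getContext());
      pm.addPass(mlir::createTritonGPUSwizzlePass());
      return pm.run(mod);
    }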
diff --git a/lib/Dialect/TritonGPU/Transforms/Swizzle.cpp b/lib/Dialect/TritonGPU/Transforms/Swizzle.cpp
new file mode 100644
index 000000000..7a9938238
--- /dev/null
+++ b/lib/Dialect/TritonGPU/Transforms/Swizzle.cpp
@@ -0,0 +1,104 @@
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
+
+#include <algorithm>
+#include <vector>
+
+using namespace mlir;
+using namespace mlir::triton;
+
+#define GEN_PASS_CLASSES
+#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
+
+namespace {
+
+struct SwizzlePass : public TritonGPUSwizzleBase<SwizzlePass> {
+  SwizzlePass() = default;
+
+  struct SwizzleInfo {
+    int vec;
+    int perPhase;
+    int maxPhase;
+  };
+
+  SwizzleInfo getSwizzleMMA(int opIdx, triton::gpu::MmaEncodingAttr retEncoding,
+                            RankedTensorType ty) {
+    SwizzleInfo noSwizzling = {1, 1, 1};
+    int version = retEncoding.getVersion();
+    auto tyEncoding = ty.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
+    auto order = tyEncoding.getOrder();
+    // number of rows per phase: rows that fill one 128-byte shared memory segment
+    int perPhase = 128 / (ty.getShape()[order[0]] *
+                          (ty.getElementType().getIntOrFloatBitWidth() / 8));
+    perPhase = std::max(perPhase, 1);
+    // index of the inner dimension in `order`
+    int inner = (opIdx == 0) ? 0 : 1;
+    if (version == 1) {
+      int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
+      // TODO: handle rep (see
+      // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L209)
+      int vec = 1;
+      return SwizzleInfo{vec, perPhase, maxPhase};
+    } else if (version == 2) {
+      auto eltTy = ty.getElementType();
+      std::vector<int> mat_shape = {8, 8,
+                                    2 * 64 / eltTy.getIntOrFloatBitWidth()};
+      // for now, disable swizzling when using transposed int8 tensor cores
+      bool is_int8_mma = ty.getElementType().isInteger(8);
+      if (is_int8_mma && order[0] == inner)
+        return noSwizzling;
+      // compute swizzling for A operand
+      if (opIdx == 0) {
+        int vec = order[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m
+        int mmaStride = order[0] == 1 ? mat_shape[0] : mat_shape[2];
+        int maxPhase = mmaStride / perPhase;
+        return SwizzleInfo{vec, perPhase, maxPhase};
+      }
+      // compute swizzling for B operand
+      else if (opIdx == 1) {
+        int vec = order[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k
+        int mmaStride = order[0] == 1 ? mat_shape[2] : mat_shape[1];
+        int maxPhase = mmaStride / perPhase;
+        return SwizzleInfo{vec, perPhase, maxPhase};
+      } else {
+        llvm_unreachable("invalid operand index");
+      }
+    } else
+      llvm_unreachable("unsupported swizzling for provided MMA version");
+  }
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    op->walk([&](triton::DotOp dotOp) -> void {
+      OpBuilder builder(dotOp);
+      auto _retEncoding =
+          dotOp.getResult().getType().cast<RankedTensorType>().getEncoding();
+      auto retEncoding = _retEncoding.dyn_cast<triton::gpu::MmaEncodingAttr>();
+      if (!retEncoding)
+        return;
+      for (int opIdx : {0, 1}) {
+        Value operand = dotOp.getOperand(opIdx);
+        auto ty = operand.getType().cast<RankedTensorType>();
+        // compute new swizzled encoding
+        SwizzleInfo swizzle = getSwizzleMMA(opIdx, retEncoding, ty);
+        auto newEncoding = triton::gpu::SharedEncodingAttr::get(
+            &getContext(), swizzle.vec, swizzle.perPhase, swizzle.maxPhase,
+            ty.getEncoding().cast<triton::gpu::SharedEncodingAttr>().getOrder());
+        // create conversion to the swizzled layout
+        auto newType = RankedTensorType::get(ty.getShape(), ty.getElementType(),
+                                             newEncoding);
+        Operation *newOp = builder.create<triton::gpu::ConvertLayoutOp>(
+            operand.getLoc(), newType, operand);
+        // bind the converted value to the dot operand
+        dotOp->replaceUsesOfWith(operand, newOp->getResult(0));
+      }
+    });
+  }
+};
+
+} // anonymous namespace
+
+std::unique_ptr<Pass> mlir::createTritonGPUSwizzlePass() {
+  return std::make_unique<SwizzlePass>();
+}
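To make the version-2 arithmetic concrete before the tests: for f16 operands with order = [1, 0], mat_shape evaluates to {8, 8, 8}, so vec is 8 for both A and B and maxPhase reduces to 8 / perPhase. A self-contained sketch of that special case (swizzleParamsF16RowMajor is an illustrative name, not part of the pass):

    #include <algorithm>
    #include <cstdio>

    // Mirrors getSwizzleMMA for MMA v2, f16, order = [1, 0].
    // innerSize is the extent of the contiguous dimension (K for A, N for B).
    void swizzleParamsF16RowMajor(int innerSize) {
      const int bytesPerElt = 2;         // f16
      const int matShape = 2 * 64 / 16;  // = 8; the m, n and k tiles coincide for f16
      int perPhase = std::max(128 / (innerSize * bytesPerElt), 1);
      int vec = matShape;                // contiguous dim: k for A, n for B
      int maxPhase = matShape / perPhase;
      std::printf("inner=%d -> vec=%d perPhase=%d maxPhase=%d\n",
                  innerSize, vec, perPhase, maxPhase);
    }

    int main() {
      // Inner extents exercised by the tests: 64 -> {8,1,8}, 32 -> {8,2,4},
      // 16 -> {8,4,2}; 128 and 256 clamp perPhase to 1 -> {8,1,8}.
      for (int inner : {64, 32, 16, 128, 256})
        swizzleParamsF16RowMajor(inner);
    }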
diff --git a/test/TritonGPU/swizzle.mlir b/test/TritonGPU/swizzle.mlir
new file mode 100644
index 000000000..8fd4d81db
--- /dev/null
+++ b/test/TritonGPU/swizzle.mlir
@@ -0,0 +1,71 @@
+// RUN: triton-opt %s -split-input-file -tritongpu-swizzle | FileCheck %s
+
+#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#mma1w = #triton_gpu.mma<{version = 2, warpsPerCTA = [1, 1]}>
+#mma2w = #triton_gpu.mma<{version = 2, warpsPerCTA = [1, 2]}>
+#mma4w = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
+#mma8w = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 4]}>
+
+// CHECK: [[shared_v8p1m8:#.*]] = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
+// CHECK: [[shared_v8p2m4:#.*]] = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
+// CHECK: [[shared_v8p4m2:#.*]] = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0]}>
+
+#shared2 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
+#shared3 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0]}>
+
+
+module attributes {"triton_gpu.num-warps" = 8 : i32} {
+  // CHECK-LABEL: swizzle_mma_f16_128x256x64_w8
+  func @swizzle_mma_f16_128x256x64_w8(%A: tensor<128x64xf16, #shared>, %B: tensor<64x256xf16, #shared>) {
+    %cst0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma8w>
+    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<128x64xf16, {{.*}}>) -> tensor<128x64xf16, [[shared_v8p1m8]]>
+    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<64x256xf16, {{.*}}>) -> tensor<64x256xf16, [[shared_v8p1m8]]>
+    %D = tt.dot %A, %B, %cst0 {allowTF32 = true} : tensor<128x64xf16, #shared> * tensor<64x256xf16, #shared> -> tensor<128x256xf32, #mma8w>
+    return
+  }
+}
+
+
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: swizzle_mma_f16_128x128x64_w4
+  func @swizzle_mma_f16_128x128x64_w4(%A: tensor<128x64xf16, #shared>, %B: tensor<64x128xf16, #shared>) {
+    %cst0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma4w>
+    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<128x64xf16, {{.*}}>) -> tensor<128x64xf16, [[shared_v8p1m8]]>
+    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<64x128xf16, {{.*}}>) -> tensor<64x128xf16, [[shared_v8p1m8]]>
+    %D = tt.dot %A, %B, %cst0 {allowTF32 = true} : tensor<128x64xf16, #shared> * tensor<64x128xf16, #shared> -> tensor<128x128xf32, #mma4w>
+    return
+  }
+}
+
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: swizzle_mma_f16_128x128x32_w4
+  func @swizzle_mma_f16_128x128x32_w4(%A: tensor<128x32xf16, #shared>, %B: tensor<32x128xf16, #shared>) {
+    %cst0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma4w>
+    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<128x32xf16, {{.*}}>) -> tensor<128x32xf16, [[shared_v8p2m4]]>
+    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<32x128xf16, {{.*}}>) -> tensor<32x128xf16, [[shared_v8p1m8]]>
+    %D = tt.dot %A, %B, %cst0 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared> -> tensor<128x128xf32, #mma4w>
+    return
+  }
+}
+
+module attributes {"triton_gpu.num-warps" = 2 : i32} {
+  // CHECK-LABEL: swizzle_mma_f16_32x32x32_w2
+  func @swizzle_mma_f16_32x32x32_w2(%A: tensor<32x32xf16, #shared>, %B: tensor<32x32xf16, #shared>) {
+    %cst0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma2w>
+    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<32x32xf16, {{.*}}>) -> tensor<32x32xf16, [[shared_v8p2m4]]>
+    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<32x32xf16, {{.*}}>) -> tensor<32x32xf16, [[shared_v8p2m4]]>
+    %D = tt.dot %A, %B, %cst0 {allowTF32 = true} : tensor<32x32xf16, #shared> * tensor<32x32xf16, #shared> -> tensor<32x32xf32, #mma2w>
+    return
+  }
+}
+
+module attributes {"triton_gpu.num-warps" = 1 : i32} {
+  // CHECK-LABEL: swizzle_mma_f16_16x16x16_w1
+  func @swizzle_mma_f16_16x16x16_w1(%A: tensor<16x16xf16, #shared>, %B: tensor<16x16xf16, #shared>) {
+    %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma1w>
+    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<16x16xf16, {{.*}}>) -> tensor<16x16xf16, [[shared_v8p4m2]]>
+    // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<16x16xf16, {{.*}}>) -> tensor<16x16xf16, [[shared_v8p4m2]]>
+    %D = tt.dot %A, %B, %cst0 {allowTF32 = true} : tensor<16x16xf16, #shared> * tensor<16x16xf16, #shared> -> tensor<16x16xf32, #mma1w>
+    return
+  }
+}
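The FileCheck captures follow directly from the arithmetic above: every f16 operand gets vec = 8, perPhase = max(1, 128 / (2 * innerDim)), and maxPhase = 8 / perPhase, where innerDim is the operand's contiguous extent (K for A, N for B). A table-driven restatement of the expected encodings (illustrative only, not part of the test suite):

    #include <algorithm>
    #include <cassert>

    struct Case { int innerDim, vec, perPhase, maxPhase; };

    int main() {
      const Case cases[] = {
          {64, 8, 1, 8},   // A operands 128x64            -> shared_v8p1m8
          {32, 8, 2, 4},   // A 128x32 and A,B 32x32       -> shared_v8p2m4
          {16, 8, 4, 2},   // A,B 16x16                    -> shared_v8p4m2
          {128, 8, 1, 8},  // B operands 64x128 and 32x128 -> shared_v8p1m8
          {256, 8, 1, 8},  // B operand 64x256             -> shared_v8p1m8
      };
      for (const Case &c : cases) {
        int perPhase = std::max(128 / (c.innerDim * 2), 1);
        assert(c.vec == 8 && c.perPhase == perPhase && c.maxPhase == 8 / perPhase);
      }
    }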