diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 788522ba5..51203e0b9 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -71,6 +71,70 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] / ArrayRefParameter<"unsigned", "order of axes by the rate of changing">:$order ); + let builders = [ + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "Type":$eltTy), [{ + auto mmaEnc = dotOpEnc.getParent().dyn_cast(); + // Only support row major for now + // TODO(Keren): check why column major code crashes + SmallVector order = {1, 0}; + + if(!mmaEnc) + return $_get(context, 1, 1, 1, order); + + int version = mmaEnc.getVersion(); + int opIdx = dotOpEnc.getOpIdx(); + + // number of rows per phase + int perPhase = 128 / (shape[order[0]] * (eltTy.getIntOrFloatBitWidth() / 8)); + perPhase = std::max(perPhase, 1); + + // index of the inner dimension in `order` + unsigned inner = (opIdx == 0) ? 0 : 1; + + // ---- begin version 1 ---- + // TODO: handle rep (see + // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L209) + if (version == 1) { + int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase; + return $_get(context, 1, perPhase, maxPhase, order); + } + + // ---- begin version 2 ---- + if (version == 2) { + std::vector matShape = {8, 8, + 2 * 64 / eltTy.getIntOrFloatBitWidth()}; + // for now, disable swizzle when using transposed int8 tensor cores + if (eltTy.isInteger(8) && order[0] == inner) + return $_get(context, 1, 1, 1, order); + + // --- handle A operand --- + if (opIdx == 0) { // compute swizzling for A operand + int vec = (order[0] == 1) ? matShape[2] : matShape[0]; // k : m + int mmaStride = (order[0] == 1) ? matShape[0] : matShape[2]; + int maxPhase = mmaStride / perPhase; + return $_get(context, vec, perPhase, maxPhase, order); + } + + // --- handle B operand --- + if (opIdx == 1) { + int vec = (order[0] == 1) ? matShape[1] : matShape[2]; // n : k + int mmaStride = (order[0] == 1) ? matShape[2] : matShape[1]; + int maxPhase = mmaStride / perPhase; + return $_get(context, vec, perPhase, maxPhase, order); + } + + llvm_unreachable("invalid operand index"); + } + + // ---- not implemented ---- + llvm_unreachable("unsupported swizzling for provided MMA version"); + + + }]> + ]; + let extraClassDeclaration = extraBaseClassDeclaration; } diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.h b/include/triton/Dialect/TritonGPU/Transforms/Passes.h index f70350d89..e570d60d5 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.h @@ -11,8 +11,6 @@ std::unique_ptr createTritonGPUPrefetchPass(); std::unique_ptr createTritonGPUCanonicalizeLoopsPass(); -std::unique_ptr createTritonGPUSwizzlePass(); - std::unique_ptr createTritonGPUCoalescePass(); std::unique_ptr createTritonGPUCombineOpsPass(); diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index 8f3f0f32f..caa85a950 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -65,18 +65,6 @@ def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> { "mlir::triton::TritonDialect"]; } -def TritonGPUSwizzle : Pass<"tritongpu-swizzle", "mlir::ModuleOp"> { - let summary = "swizzle"; - - let description = [{ - Inserts conversions to swizzled layout so as to avoid shared memory bank conflicts. - }]; - - let constructor = "mlir::createTritonGPUSwizzlePass()"; - - let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"]; -} - def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> { let summary = "canonicalize scf.ForOp ops"; diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 0e21b925d..c77a29c21 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -2707,6 +2707,7 @@ public: return lowerDistributedToDistributed(op, adaptor, rewriter); } // TODO: to be implemented + llvm_unreachable("unsupported layout conversion"); return failure(); } @@ -5763,6 +5764,35 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, class ConvertTritonGPUToLLVM : public ConvertTritonGPUToLLVMBase { + +private: + void decomposeBlockedToDotOperand(ModuleOp mod) { + // replace `blocked -> dot_op` with `blocked -> shared -> dot_op` + // because the codegen doesn't handle `blocked -> dot_op` directly + mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { + OpBuilder builder(cvtOp); + auto srcType = cvtOp.getOperand().getType().cast(); + auto dstType = cvtOp.getType().cast(); + auto srcBlocked = + srcType.getEncoding().dyn_cast(); + auto dstDotOp = + dstType.getEncoding().dyn_cast(); + if (srcBlocked && dstDotOp) { + auto tmpType = RankedTensorType::get( + dstType.getShape(), dstType.getElementType(), + triton::gpu::SharedEncodingAttr::get(mod.getContext(), dstDotOp, + srcType.getShape(), + srcType.getElementType())); + auto tmp = builder.create( + cvtOp.getLoc(), tmpType, cvtOp.getOperand()); + auto newConvert = builder.create( + cvtOp.getLoc(), dstType, tmp); + cvtOp.replaceAllUsesWith(newConvert.getResult()); + cvtOp.erase(); + } + }); + } + public: ConvertTritonGPUToLLVM() = default; @@ -5779,15 +5809,19 @@ public: int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); - // step 1: Allocate shared memories and insert barriers - // step 2: Convert SCF to CFG - // step 3: Convert FuncOp to LLVMFuncOp via partial conversion - // step 4: Convert the rest of ops via partial conversion + // step 1: Decompose unoptimized layout conversions to use shared memory + // step 2: Allocate shared memories and insert barriers + // step 3: Convert SCF to CFG + // step 4: Convert FuncOp to LLVMFuncOp via partial conversion + // step 5: Convert the rest of ops via partial conversion // The reason for putting step 1 before step 2 is that the membar analysis // currently only supports SCF but not CFG. // The reason for a separation between 1/4 is that, step 3 is out of // the scope of Dialect Conversion, thus we need to make sure the smem // is not revised during the conversion of step 4. + + decomposeBlockedToDotOperand(mod); + Allocation allocation(mod); MembarAnalysis membar(&allocation); diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index 6f440df5d..aabcc1901 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -8,7 +8,6 @@ add_mlir_dialect_library(TritonGPUTransforms Combine.cpp Pipeline.cpp Prefetch.cpp - Swizzle.cpp TritonGPUConversion.cpp DEPENDS diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index aa27c1aad..0e88c1fad 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -74,10 +74,6 @@ class LoopPipeliner { /// returns a empty buffer of size ttg::AllocTensorOp allocateEmptyBuffer(Operation *op, OpBuilder &builder); - /// compute type of shared buffers (with swizzled shared layouts) - RankedTensorType getSwizzleType(ttg::DotOperandEncodingAttr dotOpEnc, - RankedTensorType tensorType); - public: LoopPipeliner(scf::ForOp forOp, int numStages) : forOp(forOp), numStages(numStages) { @@ -148,70 +144,6 @@ ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op, llvm_unreachable("Async copy's return should be of RankedTensorType"); } -// TODO: I copied the code from Swizzle.cpp. Should find a way to unify the -// code path. -// Swizzle has to be performed before pipeline for now. If we do swizzle -// after pipeline, we need to propagate the swizzled layout to all -// operands that is an alias of the swizzled tensor. The alias analysis -// component maybe helpful for this purpose. -RankedTensorType -LoopPipeliner::getSwizzleType(ttg::DotOperandEncodingAttr dotOpEnc, - RankedTensorType ty) { - int opIdx = dotOpEnc.getOpIdx(); - int vec = 1; - int maxPhase = 1; - int perPhase = 1; - llvm::SmallVector order; - if (auto mmaEnc = dotOpEnc.getParent().dyn_cast()) { - // Only support row major for now - // TODO(Keren): check why column major code crashes - order = {1, 0}; - int version = mmaEnc.getVersion(); - auto tyEncoding = ty.getEncoding().cast(); - // number of rows per phase - perPhase = 128 / (ty.getShape()[order[0]] * - (ty.getElementType().getIntOrFloatBitWidth() / 8)); - perPhase = std::max(perPhase, 1); - - // index of the inner dimension in `order` - unsigned inner = (opIdx == 0) ? 0 : 1; - if (version == 1) { - maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase; - // TODO: handle rep (see - // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L209) - } else if (version == 2) { - auto eltTy = ty.getElementType(); - std::vector matShape = {8, 8, - 2 * 64 / eltTy.getIntOrFloatBitWidth()}; - // for now, disable swizzle when using transposed int8 tensor cores - if (ty.getElementType().isInteger(8) && order[0] == inner) - perPhase = 1; - else { - if (opIdx == 0) { // compute swizzling for A operand - vec = order[0] == 1 ? matShape[2] : matShape[0]; // k : m - int mmaStride = order[0] == 1 ? matShape[0] : matShape[2]; - maxPhase = mmaStride / perPhase; - } else if (opIdx == 1) { // compute swizzling for B operand - vec = order[0] == 1 ? matShape[1] : matShape[2]; // n : k - int mmaStride = order[0] == 1 ? matShape[2] : matShape[1]; - maxPhase = mmaStride / perPhase; - } else - llvm_unreachable("invalid operand index"); - } - } else // version not in [1, 2] - llvm_unreachable("unsupported swizzling for provided MMA version"); - } else { // If the layout of dot is not mma, we don't need to swizzle - auto blockedEnc = dotOpEnc.getParent().cast(); - order = llvm::SmallVector(blockedEnc.getOrder().begin(), - blockedEnc.getOrder().end()); - } - auto newEncoding = ttg::SharedEncodingAttr::get(ty.getContext(), vec, - perPhase, maxPhase, order); - SmallVector bufferShape(ty.getShape().begin(), ty.getShape().end()); - bufferShape.insert(bufferShape.begin(), numStages); - return RankedTensorType::get(bufferShape, ty.getElementType(), newEncoding); -} - /// A load instruction can be pipelined if: /// - the load doesn't depend on any other loads (after loop peeling) /// - (?) this load is not a loop-invariant value (we should run LICM before @@ -264,8 +196,14 @@ LogicalResult LoopPipeliner::initialize() { .dyn_cast()) { isCandiate = true; loadsMapping[loadOp] = convertLayout; - loadsBufferType[loadOp] = getSwizzleType( - dotOpEnc, loadOp.getType().cast()); + auto ty = loadOp.getType().cast(); + SmallVector bufferShape(ty.getShape().begin(), + ty.getShape().end()); + bufferShape.insert(bufferShape.begin(), numStages); + auto sharedEnc = ttg::SharedEncodingAttr::get( + ty.getContext(), dotOpEnc, ty.getShape(), ty.getElementType()); + loadsBufferType[loadOp] = RankedTensorType::get( + bufferShape, ty.getElementType(), sharedEnc); } } } diff --git a/lib/Dialect/TritonGPU/Transforms/Swizzle.cpp b/lib/Dialect/TritonGPU/Transforms/Swizzle.cpp deleted file mode 100644 index a519e32db..000000000 --- a/lib/Dialect/TritonGPU/Transforms/Swizzle.cpp +++ /dev/null @@ -1,134 +0,0 @@ -#include "mlir/Analysis/SliceAnalysis.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Dialect/TritonGPU/Transforms/Passes.h" - -using namespace mlir; -using namespace mlir::triton; - -#define GEN_PASS_CLASSES -#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" - -namespace { - -struct SwizzlePass : public TritonGPUSwizzleBase { - SwizzlePass() = default; - - struct SwizzleInfo { - int vec; - int perPhase; - int maxPhase; - }; - - SwizzleInfo getSwizzleMMA(int opIdx, triton::gpu::MmaEncodingAttr retEncoding, - RankedTensorType ty) { - SwizzleInfo noSwizzling = {1, 1, 1}; - int version = retEncoding.getVersion(); - auto tyEncoding = ty.getEncoding().cast(); - auto order = tyEncoding.getOrder(); - // number of rows per phase - int perPhase = 128 / (ty.getShape()[order[0]] * - (ty.getElementType().getIntOrFloatBitWidth() / 8)); - perPhase = std::max(perPhase, 1); - // index of the inner dimension in `order` - size_t inner = (opIdx == 0) ? 0 : 1; - if (version == 1) { - int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase; - // TODO: handle rep (see - // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L209) - int vec = 1; - return SwizzleInfo{vec, perPhase, maxPhase}; - } else if (version == 2) { - auto eltTy = ty.getElementType(); - std::vector matShape = {8, 8, - 2 * 64 / eltTy.getIntOrFloatBitWidth()}; - // for now, disable swizzle when using transposed int8 tensor cores - bool isInt8Mma = ty.getElementType().isInteger(8); - if (isInt8Mma && order[0] == inner) - return noSwizzling; - // compute swizzling for A operand - if (opIdx == 0) { - int vec = order[0] == 1 ? matShape[2] : matShape[0]; // k : m - int mmaStride = order[0] == 1 ? matShape[0] : matShape[2]; - int maxPhase = mmaStride / perPhase; - return SwizzleInfo{vec, perPhase, maxPhase}; - } - // compute swizzling for B operand - else if (opIdx == 1) { - int vec = order[0] == 1 ? matShape[1] : matShape[2]; // n : k - int mmaStride = order[0] == 1 ? matShape[2] : matShape[1]; - int maxPhase = mmaStride / perPhase; - return SwizzleInfo{vec, perPhase, maxPhase}; - } else { - llvm_unreachable("invalid operand index"); - } - } else - llvm_unreachable("unsupported swizzling for provided MMA version"); - } - - void runOnOperation() override { - Operation *op = getOperation(); - // replace blocked -> dot_op with - // blocked -> shared -> dot_op in order to - // expose opportunities for swizzling - op->walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { - OpBuilder builder(cvtOp); - auto srcType = cvtOp.getOperand().getType().cast(); - auto dstType = cvtOp.getType().cast(); - if (srcType.getEncoding().isa() && - dstType.getEncoding().isa()) { - auto tmpType = - RankedTensorType::get(dstType.getShape(), dstType.getElementType(), - triton::gpu::SharedEncodingAttr::get( - op->getContext(), 1, 1, 1, {1, 0})); - auto tmp = builder.create( - cvtOp.getLoc(), tmpType, cvtOp.getOperand()); - auto newConvert = builder.create( - cvtOp.getLoc(), dstType, tmp); - cvtOp.replaceAllUsesWith(newConvert.getResult()); - } - }); - - op->walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { - OpBuilder builder(cvtOp); - auto arg = cvtOp.getOperand(); - auto retType = cvtOp.getResult().getType().cast(); - auto retEncoding = - retType.getEncoding().dyn_cast(); - auto argType = arg.getType().cast(); - auto argEncoding = - argType.getEncoding().dyn_cast(); - if (!argEncoding || !retEncoding) - return; - auto opIdx = retEncoding.getOpIdx(); - // compute new swizzled encoding - auto parentEncoding = - retEncoding.getParent().dyn_cast(); - if (!parentEncoding) - return; - auto swizzleType = argType; - if (arg.getDefiningOp() && - isa(arg.getDefiningOp())) { - swizzleType = arg.getDefiningOp() - ->getOperand(0) - .getType() - .cast(); - } - SwizzleInfo swizzle = getSwizzleMMA(opIdx, parentEncoding, swizzleType); - auto newEncoding = triton::gpu::SharedEncodingAttr::get( - &getContext(), swizzle.vec, swizzle.perPhase, swizzle.maxPhase, - argEncoding.getOrder()); - // create conversion - auto newType = RankedTensorType::get( - argType.getShape(), argType.getElementType(), newEncoding); - Operation *newArg = builder.create( - cvtOp.getLoc(), newType, arg); - // bind new op to cvt operand - cvtOp->replaceUsesOfWith(arg, newArg->getResult(0)); - }); - } -}; -} // anonymous namespace - -std::unique_ptr mlir::createTritonGPUSwizzlePass() { - return std::make_unique(); -} \ No newline at end of file diff --git a/python/src/triton.cc b/python/src/triton.cc index c04cc9da6..31925c1f6 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1264,10 +1264,6 @@ void init_triton_ir(py::module &&m) { [](mlir::PassManager &self) { self.addPass(mlir::createTritonGPUCombineOpsPass()); }) - .def("add_triton_gpu_swizzle_pass", - [](mlir::PassManager &self) { - self.addPass(mlir::createTritonGPUSwizzlePass()); - }) .def("add_triton_gpu_to_llvm", [](mlir::PassManager &self) { self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass()); diff --git a/python/triton/compiler.py b/python/triton/compiler.py index c1dbbfd16..5da966e61 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -885,7 +885,6 @@ def ttir_to_ttgir(mod, num_warps, num_stages): pm.add_coalesce_pass() pm.add_triton_gpu_combine_pass() pm.add_licm_pass() - pm.add_triton_gpu_swizzle_pass() pm.add_triton_gpu_combine_pass() pm.add_cse_pass() pm.run(mod) diff --git a/test/TritonGPU/swizzle.mlir b/test/TritonGPU/swizzle.mlir deleted file mode 100644 index 256b4f1b9..000000000 --- a/test/TritonGPU/swizzle.mlir +++ /dev/null @@ -1,90 +0,0 @@ -// RUN: triton-opt %s -split-input-file -tritongpu-swizzle | FileCheck %s - -#shared = #triton_gpu.shared<{vec=1, perPhase=1, maxPhase=1 ,order = [1, 0]}> -#mma1w = #triton_gpu.mma<{version=2, warpsPerCTA=[1, 1]}> -#mma2w = #triton_gpu.mma<{version=2, warpsPerCTA=[1, 2]}> -#mma4w = #triton_gpu.mma<{version=2, warpsPerCTA=[2, 2]}> -#mma8w = #triton_gpu.mma<{version=2, warpsPerCTA=[2, 4]}> - -// CHECK: [[shared_v8p1m8:#.*]] = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}> -// CHECK: [[shared_v8p2m4:#.*]] = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}> -// CHECK: [[shared_v8p4m2:#.*]] = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0]}> - -#shared2 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#shared3 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0]}> - -#mma1w_op0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma1w}> -#mma1w_op1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma1w}> -#mma2w_op0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma2w}> -#mma2w_op1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma2w}> -#mma4w_op0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma4w}> -#mma4w_op1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma4w}> -#mma8w_op0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma8w}> -#mma8w_op1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma8w}> - - -module attributes {"triton_gpu.num-warps" = 8 : i32} { - // CHECK-LABEL: swizzle_mma_f16_128x256x64_w8 - func @swizzle_mma_f16_128x256x64_w8(%A_SMEM: tensor<128x64xf16, #shared>, %B_SMEM: tensor<64x256xf16, #shared>) { - %cst0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma8w> - // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<128x64xf16, {{.*}}>) -> tensor<128x64xf16, [[shared_v8p1m8]]> - // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<64x256xf16, {{.*}}>) -> tensor<64x256xf16, [[shared_v8p1m8]]> - %A = triton_gpu.convert_layout %A_SMEM : (tensor<128x64xf16, #shared>) -> tensor<128x64xf16, #mma8w_op0> - %B = triton_gpu.convert_layout %B_SMEM : (tensor<64x256xf16, #shared>) -> tensor<64x256xf16, #mma8w_op1> - %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<128x64xf16, #mma8w_op0> * tensor<64x256xf16, #mma8w_op1> -> tensor<128x256xf32, #mma8w> - return - } -} - - -module attributes {"triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: swizzle_mma_f16_128x128x64_w4 - func @swizzle_mma_f16_128x128x64_w4(%A_SMEM: tensor<128x64xf16, #shared>, %B_SMEM: tensor<64x128xf16, #shared>) { - %cst0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma4w> - // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<128x64xf16, {{.*}}>) -> tensor<128x64xf16, [[shared_v8p1m8]]> - // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<64x128xf16, {{.*}}>) -> tensor<64x128xf16, [[shared_v8p1m8]]> - %A = triton_gpu.convert_layout %A_SMEM : (tensor<128x64xf16, #shared>) -> tensor<128x64xf16, #mma4w_op0> - %B = triton_gpu.convert_layout %B_SMEM : (tensor<64x128xf16, #shared>) -> tensor<64x128xf16, #mma4w_op1> - %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<128x64xf16, #mma4w_op0> * tensor<64x128xf16, #mma4w_op1> -> tensor<128x128xf32, #mma4w> - return - } -} - -module attributes {"triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: swizzle_mma_f16_128x128x32_w4 - func @swizzle_mma_f16_128x128x32_w4(%A_SMEM: tensor<128x32xf16, #shared>, %B_SMEM: tensor<32x128xf16, #shared>) { - %cst0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma4w> - // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<128x32xf16, {{.*}}>) -> tensor<128x32xf16, [[shared_v8p2m4]]> - // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<32x128xf16, {{.*}}>) -> tensor<32x128xf16, [[shared_v8p1m8]]> - %A = triton_gpu.convert_layout %A_SMEM : (tensor<128x32xf16, #shared>) -> tensor<128x32xf16, #mma4w_op0> - %B = triton_gpu.convert_layout %B_SMEM : (tensor<32x128xf16, #shared>) -> tensor<32x128xf16, #mma4w_op1> - %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #mma4w_op0> * tensor<32x128xf16, #mma4w_op1> -> tensor<128x128xf32, #mma4w> - return - } -} - -module attributes {"triton_gpu.num-warps" = 2 : i32} { - // CHECK-LABEL: swizzle_mma_f16_32x32x32_w2 - func @swizzle_mma_f16_32x32x32_w2(%A_SMEM: tensor<32x32xf16, #shared>, %B_SMEM: tensor<32x32xf16, #shared>) { - %cst0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma2w> - // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<32x32xf16, {{.*}}>) -> tensor<32x32xf16, [[shared_v8p2m4]]> - // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<32x32xf16, {{.*}}>) -> tensor<32x32xf16, [[shared_v8p2m4]]> - %A = triton_gpu.convert_layout %A_SMEM : (tensor<32x32xf16, #shared>) -> tensor<32x32xf16, #mma2w_op0> - %B = triton_gpu.convert_layout %B_SMEM : (tensor<32x32xf16, #shared>) -> tensor<32x32xf16, #mma2w_op1> - %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<32x32xf16, #mma2w_op0> * tensor<32x32xf16, #mma2w_op1> -> tensor<32x32xf32, #mma2w> - return - } -} - -module attributes {"triton_gpu.num-warps" = 1 : i32} { - // CHECK-LABEL: swizzle_mma_f16_16x16x16_w1 - func @swizzle_mma_f16_16x16x16_w1(%A_SMEM: tensor<16x16xf16, #shared>, %B_SMEM: tensor<16x16xf16, #shared>) { - %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma1w> - // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<16x16xf16, {{.*}}>) -> tensor<16x16xf16, [[shared_v8p4m2]]> - // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<16x16xf16, {{.*}}>) -> tensor<16x16xf16, [[shared_v8p4m2]]> - %A = triton_gpu.convert_layout %A_SMEM : (tensor<16x16xf16, #shared>) -> tensor<16x16xf16, #mma1w_op0> - %B = triton_gpu.convert_layout %B_SMEM : (tensor<16x16xf16, #shared>) -> tensor<16x16xf16, #mma1w_op1> - %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #mma1w_op0> * tensor<16x16xf16, #mma1w_op1> -> tensor<16x16xf32, #mma1w> - return - } -} diff --git a/unittest/CMakeLists.txt b/unittest/CMakeLists.txt index 0f5feb7b4..dfc47a66e 100644 --- a/unittest/CMakeLists.txt +++ b/unittest/CMakeLists.txt @@ -26,3 +26,4 @@ endfunction() add_subdirectory(Analysis) add_subdirectory(Conversion) +add_subdirectory(Dialect) diff --git a/unittest/Dialect/CMakeLists.txt b/unittest/Dialect/CMakeLists.txt new file mode 100644 index 000000000..eba47a67c --- /dev/null +++ b/unittest/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(TritonGPU) diff --git a/unittest/Dialect/TritonGPU/CMakeLists.txt b/unittest/Dialect/TritonGPU/CMakeLists.txt new file mode 100644 index 000000000..43f9e6801 --- /dev/null +++ b/unittest/Dialect/TritonGPU/CMakeLists.txt @@ -0,0 +1,6 @@ + +add_triton_ut( + NAME TestSwizzling + SRCS SwizzleTest.cpp + LIBS TritonGPUIR ${dialect_libs} ${conversion_libs} +) \ No newline at end of file diff --git a/unittest/Dialect/TritonGPU/SwizzleTest.cpp b/unittest/Dialect/TritonGPU/SwizzleTest.cpp new file mode 100644 index 000000000..ea2109552 --- /dev/null +++ b/unittest/Dialect/TritonGPU/SwizzleTest.cpp @@ -0,0 +1,52 @@ +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include + +using namespace mlir; +using mlir::triton::gpu::SharedEncodingAttr; + +struct swizzleParams { + int vec; + int perPhase; + int maxPhase; +}; + +struct ParamT { + std::array shape; + int opIdx; + int typeWidth; + swizzleParams refSwizzle; +}; + +class SwizzleDotOperandTestFixture : public ::testing::TestWithParam { +protected: + ParamType param; +}; + +TEST_P(SwizzleDotOperandTestFixture, DotOperands) { + auto params = GetParam(); + // init context + MLIRContext ctx; + ctx.loadDialect(); + // create encoding + auto parent = triton::gpu::MmaEncodingAttr::get(&ctx, 2, {1, 1}); + auto encoding = + triton::gpu::DotOperandEncodingAttr::get(&ctx, params.opIdx, parent); + + // create element type + Type eltType = IntegerType::get(&ctx, params.typeWidth); + auto layout = SharedEncodingAttr::get(&ctx, encoding, params.shape, eltType); + + ASSERT_EQ(layout.getVec(), params.refSwizzle.vec); + ASSERT_EQ(layout.getPerPhase(), params.refSwizzle.perPhase); + ASSERT_EQ(layout.getMaxPhase(), params.refSwizzle.maxPhase); +} + +INSTANTIATE_TEST_SUITE_P(TestDotOperands, SwizzleDotOperandTestFixture, + ::testing::Values(ParamT{{128, 64}, 0, 16, {8, 1, 8}}, + ParamT{{64, 256}, 1, 16, {8, 1, 8}}, + ParamT{{128, 32}, 0, 16, {8, 2, 4}}, + ParamT{{32, 128}, 1, 16, {8, 1, 8}}, + ParamT{{32, 32}, 0, 16, {8, 2, 4}}, + ParamT{{32, 32}, 1, 16, {8, 2, 4}}, + ParamT{{16, 16}, 0, 16, {8, 4, 2}}, + ParamT{{16, 16}, 1, 16, {8, 4, 2}})); \ No newline at end of file