From a0bab9748efdcf7f0e93af1bfb7191bae6fca47f Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 5 Sep 2022 18:09:02 -0700 Subject: [PATCH] [OPTIMIZER] Coalesce pass no longer takes a `num-warps` argument (#99) Improved design to avoid inconsistent `num-warps` value between the pass and the parent module of the operation it processes. --- .../Dialect/TritonGPU/IR/TritonGPUDialect.td | 12 +++++++++++ .../Dialect/TritonGPU/Transforms/Passes.td | 6 ------ .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 20 +------------------ lib/Dialect/TritonGPU/Transforms/Coalesce.cpp | 15 ++++++++------ test/TritonGPU/coalesce.mlir | 4 ++++ 5 files changed, 26 insertions(+), 31 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td index 2c43e22fd..d875f3c60 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td @@ -18,6 +18,18 @@ def TritonGPU_Dialect : Dialect { "triton::TritonDialect", "mlir::gpu::GPUDialect" ]; + + let extraClassDeclaration = [{ + static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; } + static int getNumWarps(ModuleOp mod) { + if(!mod->hasAttr("triton_gpu.num-warps")) + llvm::report_fatal_error( + "TritonGPU module should contain a triton_gpu.num-warps attribute"); + return mod->getAttr("triton_gpu.num-warps").cast().getInt(); + } + }]; + + } #endif diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index 62afedfdd..22a34f43d 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -33,12 +33,6 @@ def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> { let constructor = "mlir::createTritonGPUCoalescePass()"; let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"]; - - let options = [ - Option<"numWarps", "num-warps", - "int32_t", /*default*/"4", - "number of warps"> - ]; } def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> { diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 9c6ca2488..aba9e71d6 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -795,24 +795,6 @@ private: AxisInfoAnalysis &AxisAnalysisPass; }; -// Extract numWarps information from TritonGPU module, return 0 if failed. -// This is a naive implementation, it assumes that all the blocked layout should -// have the same numWarps setting in a module, it just find a blocked layout -// encoding and return the warpsPerCTA field. -int extractNumWarps(mlir::ModuleOp module) { - int numWarps{}; - if (module->hasAttr(AttrNumWarpsName)) - numWarps = module->getAttr(AttrNumWarpsName) - .dyn_cast() - .getValue() - .getZExtValue(); - else - llvm::report_fatal_error( - "TritonGPU module should contain a triton_gpu.num-warps attribute"); - - return numWarps; -} - struct BroadcastOpConversion : public ConvertTritonGPUOpToLLVMPattern { using ConvertTritonGPUOpToLLVMPattern< @@ -1319,7 +1301,7 @@ public: RewritePatternSet patterns(context); - int numWarps = extractNumWarps(mod); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); auto axisAnalysis = runAxisAnalysis(mod); diff --git a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp index e2dc9a09d..63a0a6d1b 100644 --- a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @@ -11,7 +11,8 @@ using namespace mlir::triton; struct CoalescePass : public TritonGPUCoalesceBase { - Attribute getCoalescedEncoding(AxisInfoAnalysis &axisInfo, Value ptr) { + Attribute getCoalescedEncoding(AxisInfoAnalysis &axisInfo, Value ptr, + int numWarps) { auto origType = ptr.getType().cast(); // Get the shape of the tensor. size_t rank = origType.getRank(); @@ -36,14 +37,13 @@ struct CoalescePass : public TritonGPUCoalesceBase { std::iota(dims.begin(), dims.end(), 0); // create encoding Attribute encoding = triton::gpu::BlockedEncodingAttr::get( - &getContext(), origType.getShape(), sizePerThread, order, - this->numWarps); + &getContext(), origType.getShape(), sizePerThread, order, numWarps); return encoding; } std::function getTypeConverter(AxisInfoAnalysis &axisInfo, - Value ptr) { - Attribute encoding = getCoalescedEncoding(axisInfo, ptr); + Value ptr, int numWarps) { + Attribute encoding = getCoalescedEncoding(axisInfo, ptr, numWarps); return [encoding](Type _type) { RankedTensorType type = _type.cast(); return RankedTensorType::get(type.getShape(), type.getElementType(), @@ -57,8 +57,11 @@ struct CoalescePass : public TritonGPUCoalesceBase { RankedTensorType ty = ptr.getType().template dyn_cast(); if (!ty) return; + auto mod = op->getParentOfType(); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue(); - auto convertType = getTypeConverter(axisInfo, ptr); + auto convertType = getTypeConverter(axisInfo, ptr, numWarps); // convert operands SmallVector newArgs; for (auto v : op->getOperands()) diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir index f2ac1f65c..d6f5c8527 100644 --- a/test/TritonGPU/coalesce.mlir +++ b/test/TritonGPU/coalesce.mlir @@ -4,6 +4,8 @@ #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> // CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}> @@ -44,3 +46,5 @@ func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, tt.store %18, %19, %cst : tensor<64x64xf32, #blocked1> return } + +} \ No newline at end of file