diff --git a/include/triton/Conversion/Passes.td b/include/triton/Conversion/Passes.td index c365e0f02..0f37c76a7 100644 --- a/include/triton/Conversion/Passes.td +++ b/include/triton/Conversion/Passes.td @@ -37,12 +37,6 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp" "mlir::scf::SCFDialect", "mlir::triton::TritonDialect", "mlir::triton::gpu::TritonGPUDialect"]; - - let options = [ - Option<"numWarps", "num-warps", - "int32_t", /*default*/"4", - "number of warps"> - ]; } #endif diff --git a/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h b/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h index bdb058249..a51101f4a 100644 --- a/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h +++ b/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h @@ -10,10 +10,16 @@ template class OperationPass; namespace triton { -std::unique_ptr> -createConvertTritonToTritonGPUPass(int numWarps = 4); +constexpr static char AttrNumWarpsName[] = "triton_gpu.num-warps"; -} +// Create the pass with numWarps passed from cl::opt. +std::unique_ptr> createConvertTritonToTritonGPUPass(); + +// Create the pass with numWarps set explicitly. +std::unique_ptr> +createConvertTritonToTritonGPUPass(int numWarps); + +} // namespace triton } // namespace mlir -#endif \ No newline at end of file +#endif diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 78e51b962..a39f08885 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -5,8 +5,10 @@ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Transforms/DialectConversion.h" +#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include using namespace mlir; using namespace mlir::triton; @@ -163,7 +165,7 @@ struct FuncOpConversion : public FuncOpConversionBase { } private: - int NumWarps{-1}; + int NumWarps{0}; }; struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> { @@ -188,6 +190,24 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> { } }; +// Extract numWarps information from TritonGPU module, return 0 if failed. +// This is a naive implementation, it assumes that all the blocked layout should +// have the same numWarps setting in a module, it just find a blocked layout +// encoding and return the warpsPerCTA field. +int extractNumWarps(mlir::ModuleOp module) { + int numWarps{}; + if (module->hasAttr(AttrNumWarpsName)) + numWarps = module->getAttr(AttrNumWarpsName) + .dyn_cast() + .getValue() + .getZExtValue(); + else + llvm::report_fatal_error( + "TritonGPU module should contain a triton_gpu.num-warps attribute"); + + return numWarps; +} + } // namespace void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, @@ -200,8 +220,6 @@ class ConvertTritonGPUToLLVM : public ConvertTritonGPUToLLVMBase { public: ConvertTritonGPUToLLVM() = default; - // For manually overwrite the numWarps option - explicit ConvertTritonGPUToLLVM(int numWarps) { this->numWarps = numWarps; } void runOnOperation() override { MLIRContext *context = &getContext(); @@ -215,6 +233,8 @@ public: mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns); + int numWarps = extractNumWarps(mod); + populateTritonToLLVMPatterns(typeConverter, patterns, numWarps); if (failed(applyPartialConversion(mod, target, std::move(patterns)))) diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index a8a4ba807..4bddf5649 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -5,6 +5,7 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "llvm/ADT/APSInt.h" #include using namespace mlir; @@ -375,6 +376,8 @@ void populateSCFPatterns(TritonGPUTypeConverter &typeConverter, class ConvertTritonToTritonGPU : public ConvertTritonToTritonGPUBase { public: + ConvertTritonToTritonGPU() = default; + // constructor with some parameters set explicitly. ConvertTritonToTritonGPU(int numWarps) { this->numWarps = numWarps; } void runOnOperation() override { @@ -395,6 +398,13 @@ public: if (failed(applyPartialConversion(mod, target, std::move(patterns)))) return signalPassFailure(); + auto inti = llvm::APSInt(32, false); + auto i32_ty = IntegerType::get(mod->getContext(), 32); + + mod->setAttr( + AttrNumWarpsName, + IntegerAttr::get(i32_ty, llvm::APInt(32, numWarps.getValue()))); + // update layouts // broadcast src => multicast, dst => broadcasted // if (failed(target.refineLayouts(mod, numWarps))) @@ -408,3 +418,8 @@ std::unique_ptr> mlir::triton::createConvertTritonToTritonGPUPass(int numWarps) { return std::make_unique<::ConvertTritonToTritonGPU>(numWarps); } + +std::unique_ptr> +mlir::triton::createConvertTritonToTritonGPUPass() { + return std::make_unique<::ConvertTritonToTritonGPU>(); +} diff --git a/test/Conversion/ops.mlir b/test/Conversion/ops.mlir index a76f326ce..bd04b60dd 100644 --- a/test/Conversion/ops.mlir +++ b/test/Conversion/ops.mlir @@ -1,9 +1,10 @@ -// RUN: triton-opt %s -convert-triton-to-tritongpu +// RUN: triton-opt %s -convert-triton-to-tritongpu=num-warps=2 | FileCheck %s func @ops() { +// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}} %a = arith.constant dense<1.00e+00> : tensor<128x32xf16> %b = arith.constant dense<2.00e+00> : tensor<32x128xf16> %c = arith.constant dense<3.00e+00> : tensor<128x128xf32> %0 = tt.dot %a, %b, %c {allowTF32 = true} : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32> return -} \ No newline at end of file +} diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 33cea027e..8f006097b 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1,9 +1,15 @@ -// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm=num-warps=8 +// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm | FileCheck %s + + +module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr) -// CHECK: attributes {nvvm.maxntidx = 96 : i32} +// Here the 128 comes from the 4 in module attribute multiples 32 +// CHECK: attributes {nvvm.maxntid = 128 : i32} {{.*}} func @test_empty_kernel(%lb : index, %A : !tt.ptr) { // CHECK: llvm.return return } + +} // end module diff --git a/test/Target/tritongpu_to_llvmir.mlir b/test/Target/tritongpu_to_llvmir.mlir index 70291ec2e..96712c112 100644 --- a/test/Target/tritongpu_to_llvmir.mlir +++ b/test/Target/tritongpu_to_llvmir.mlir @@ -6,8 +6,11 @@ // CHECK: !nvvm.annotations // CHECK: !{void (i64, half addrspace(1)*)* @test_empty_kernel, !"maxntidx", i32 128} +module attributes {"triton_gpu.num-warps" = 4 : i32} { + func @test_empty_kernel(%lb : index, %A : !tt.ptr) { return } +} diff --git a/test/Target/tritongpu_to_ptx.mlir b/test/Target/tritongpu_to_ptx.mlir index b09600065..1fa6d85bc 100644 --- a/test/Target/tritongpu_to_ptx.mlir +++ b/test/Target/tritongpu_to_ptx.mlir @@ -5,7 +5,11 @@ // CHECK: .target sm_80 // CHECK: .address_size 64 +module attributes {"triton_gpu.num-warps" = 4 : i32} { + func @test_empty_kernel(%lb : index, %A : !tt.ptr) { return } + +}