[BACKEND] Extracting numWarps from tritonGPU module (#39)
This commit is contained in:
@@ -37,12 +37,6 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
|
|||||||
"mlir::scf::SCFDialect",
|
"mlir::scf::SCFDialect",
|
||||||
"mlir::triton::TritonDialect",
|
"mlir::triton::TritonDialect",
|
||||||
"mlir::triton::gpu::TritonGPUDialect"];
|
"mlir::triton::gpu::TritonGPUDialect"];
|
||||||
|
|
||||||
let options = [
|
|
||||||
Option<"numWarps", "num-warps",
|
|
||||||
"int32_t", /*default*/"4",
|
|
||||||
"number of warps">
|
|
||||||
];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
@@ -10,10 +10,16 @@ template <typename T> class OperationPass;
|
|||||||
|
|
||||||
namespace triton {
|
namespace triton {
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
constexpr static char AttrNumWarpsName[] = "triton_gpu.num-warps";
|
||||||
createConvertTritonToTritonGPUPass(int numWarps = 4);
|
|
||||||
|
|
||||||
}
|
// Create the pass with numWarps passed from cl::opt.
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonToTritonGPUPass();
|
||||||
|
|
||||||
|
// Create the pass with numWarps set explicitly.
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
createConvertTritonToTritonGPUPass(int numWarps);
|
||||||
|
|
||||||
|
} // namespace triton
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
#endif
|
#endif
|
@@ -5,8 +5,10 @@
|
|||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::triton;
|
using namespace mlir::triton;
|
||||||
@@ -163,7 +165,7 @@ struct FuncOpConversion : public FuncOpConversionBase {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int NumWarps{-1};
|
int NumWarps{0};
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
|
struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
|
||||||
@@ -188,6 +190,24 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Extract numWarps information from TritonGPU module, return 0 if failed.
|
||||||
|
// This is a naive implementation, it assumes that all the blocked layout should
|
||||||
|
// have the same numWarps setting in a module, it just find a blocked layout
|
||||||
|
// encoding and return the warpsPerCTA field.
|
||||||
|
int extractNumWarps(mlir::ModuleOp module) {
|
||||||
|
int numWarps{};
|
||||||
|
if (module->hasAttr(AttrNumWarpsName))
|
||||||
|
numWarps = module->getAttr(AttrNumWarpsName)
|
||||||
|
.dyn_cast<IntegerAttr>()
|
||||||
|
.getValue()
|
||||||
|
.getZExtValue();
|
||||||
|
else
|
||||||
|
llvm::report_fatal_error(
|
||||||
|
"TritonGPU module should contain a triton_gpu.num-warps attribute");
|
||||||
|
|
||||||
|
return numWarps;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||||
@@ -200,8 +220,6 @@ class ConvertTritonGPUToLLVM
|
|||||||
: public ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {
|
: public ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {
|
||||||
public:
|
public:
|
||||||
ConvertTritonGPUToLLVM() = default;
|
ConvertTritonGPUToLLVM() = default;
|
||||||
// For manually overwrite the numWarps option
|
|
||||||
explicit ConvertTritonGPUToLLVM(int numWarps) { this->numWarps = numWarps; }
|
|
||||||
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
MLIRContext *context = &getContext();
|
MLIRContext *context = &getContext();
|
||||||
@@ -215,6 +233,8 @@ public:
|
|||||||
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
||||||
patterns);
|
patterns);
|
||||||
|
|
||||||
|
int numWarps = extractNumWarps(mod);
|
||||||
|
|
||||||
populateTritonToLLVMPatterns(typeConverter, patterns, numWarps);
|
populateTritonToLLVMPatterns(typeConverter, patterns, numWarps);
|
||||||
|
|
||||||
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
|
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
|
||||||
|
@@ -5,6 +5,7 @@
|
|||||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||||
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||||
|
#include "llvm/ADT/APSInt.h"
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -375,6 +376,8 @@ void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
|
|||||||
class ConvertTritonToTritonGPU
|
class ConvertTritonToTritonGPU
|
||||||
: public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
|
: public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
|
||||||
public:
|
public:
|
||||||
|
ConvertTritonToTritonGPU() = default;
|
||||||
|
// constructor with some parameters set explicitly.
|
||||||
ConvertTritonToTritonGPU(int numWarps) { this->numWarps = numWarps; }
|
ConvertTritonToTritonGPU(int numWarps) { this->numWarps = numWarps; }
|
||||||
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
@@ -395,6 +398,13 @@ public:
|
|||||||
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
|
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
|
|
||||||
|
auto inti = llvm::APSInt(32, false);
|
||||||
|
auto i32_ty = IntegerType::get(mod->getContext(), 32);
|
||||||
|
|
||||||
|
mod->setAttr(
|
||||||
|
AttrNumWarpsName,
|
||||||
|
IntegerAttr::get(i32_ty, llvm::APInt(32, numWarps.getValue())));
|
||||||
|
|
||||||
// update layouts
|
// update layouts
|
||||||
// broadcast src => multicast, dst => broadcasted
|
// broadcast src => multicast, dst => broadcasted
|
||||||
// if (failed(target.refineLayouts(mod, numWarps)))
|
// if (failed(target.refineLayouts(mod, numWarps)))
|
||||||
@@ -408,3 +418,8 @@ std::unique_ptr<OperationPass<ModuleOp>>
|
|||||||
mlir::triton::createConvertTritonToTritonGPUPass(int numWarps) {
|
mlir::triton::createConvertTritonToTritonGPUPass(int numWarps) {
|
||||||
return std::make_unique<::ConvertTritonToTritonGPU>(numWarps);
|
return std::make_unique<::ConvertTritonToTritonGPU>(numWarps);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
mlir::triton::createConvertTritonToTritonGPUPass() {
|
||||||
|
return std::make_unique<::ConvertTritonToTritonGPU>();
|
||||||
|
}
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
// RUN: triton-opt %s -convert-triton-to-tritongpu
|
// RUN: triton-opt %s -convert-triton-to-tritongpu=num-warps=2 | FileCheck %s
|
||||||
|
|
||||||
func @ops() {
|
func @ops() {
|
||||||
|
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}}
|
||||||
%a = arith.constant dense<1.00e+00> : tensor<128x32xf16>
|
%a = arith.constant dense<1.00e+00> : tensor<128x32xf16>
|
||||||
%b = arith.constant dense<2.00e+00> : tensor<32x128xf16>
|
%b = arith.constant dense<2.00e+00> : tensor<32x128xf16>
|
||||||
%c = arith.constant dense<3.00e+00> : tensor<128x128xf32>
|
%c = arith.constant dense<3.00e+00> : tensor<128x128xf32>
|
||||||
|
@@ -1,9 +1,15 @@
|
|||||||
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm=num-warps=8
|
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm | FileCheck %s
|
||||||
|
|
||||||
|
|
||||||
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||||
|
|
||||||
// CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<f16, 1>)
|
// CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<f16, 1>)
|
||||||
// CHECK: attributes {nvvm.maxntidx = 96 : i32}
|
// Here the 128 comes from the 4 in module attribute multiples 32
|
||||||
|
// CHECK: attributes {nvvm.maxntid = 128 : i32} {{.*}}
|
||||||
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
||||||
|
|
||||||
// CHECK: llvm.return
|
// CHECK: llvm.return
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // end module
|
||||||
|
@@ -6,8 +6,11 @@
|
|||||||
// CHECK: !nvvm.annotations
|
// CHECK: !nvvm.annotations
|
||||||
// CHECK: !{void (i64, half addrspace(1)*)* @test_empty_kernel, !"maxntidx", i32 128}
|
// CHECK: !{void (i64, half addrspace(1)*)* @test_empty_kernel, !"maxntidx", i32 128}
|
||||||
|
|
||||||
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||||
|
|
||||||
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
@@ -5,7 +5,11 @@
|
|||||||
// CHECK: .target sm_80
|
// CHECK: .target sm_80
|
||||||
// CHECK: .address_size 64
|
// CHECK: .address_size 64
|
||||||
|
|
||||||
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||||
|
|
||||||
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user