[BACKEND] Extracting numWarps from tritonGPU module (#39)
This commit is contained in:
@@ -5,8 +5,10 @@
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include <numeric>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
@@ -163,7 +165,7 @@ struct FuncOpConversion : public FuncOpConversionBase {
|
||||
}
|
||||
|
||||
private:
|
||||
int NumWarps{-1};
|
||||
int NumWarps{0};
|
||||
};
|
||||
|
||||
struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
|
||||
@@ -188,6 +190,24 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
|
||||
}
|
||||
};
|
||||
|
||||
// Extract numWarps information from TritonGPU module, return 0 if failed.
|
||||
// This is a naive implementation, it assumes that all the blocked layout should
|
||||
// have the same numWarps setting in a module, it just find a blocked layout
|
||||
// encoding and return the warpsPerCTA field.
|
||||
int extractNumWarps(mlir::ModuleOp module) {
|
||||
int numWarps{};
|
||||
if (module->hasAttr(AttrNumWarpsName))
|
||||
numWarps = module->getAttr(AttrNumWarpsName)
|
||||
.dyn_cast<IntegerAttr>()
|
||||
.getValue()
|
||||
.getZExtValue();
|
||||
else
|
||||
llvm::report_fatal_error(
|
||||
"TritonGPU module should contain a triton_gpu.num-warps attribute");
|
||||
|
||||
return numWarps;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
@@ -200,8 +220,6 @@ class ConvertTritonGPUToLLVM
|
||||
: public ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {
|
||||
public:
|
||||
ConvertTritonGPUToLLVM() = default;
|
||||
// Allows manually overriding the numWarps option.
|
||||
explicit ConvertTritonGPUToLLVM(int numWarps) { this->numWarps = numWarps; }
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
@@ -215,6 +233,8 @@ public:
|
||||
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
||||
patterns);
|
||||
|
||||
int numWarps = extractNumWarps(mod);
|
||||
|
||||
populateTritonToLLVMPatterns(typeConverter, patterns, numWarps);
|
||||
|
||||
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
|
||||
|
@@ -5,6 +5,7 @@
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||
#include "llvm/ADT/APSInt.h"
|
||||
#include <numeric>
|
||||
|
||||
using namespace mlir;
|
||||
@@ -375,6 +376,8 @@ void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
class ConvertTritonToTritonGPU
|
||||
: public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
|
||||
public:
|
||||
ConvertTritonToTritonGPU() = default;
|
||||
// constructor with some parameters set explicitly.
|
||||
ConvertTritonToTritonGPU(int numWarps) { this->numWarps = numWarps; }
|
||||
|
||||
void runOnOperation() override {
|
||||
@@ -395,6 +398,13 @@ public:
|
||||
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
|
||||
auto inti = llvm::APSInt(32, false);
|
||||
auto i32_ty = IntegerType::get(mod->getContext(), 32);
|
||||
|
||||
mod->setAttr(
|
||||
AttrNumWarpsName,
|
||||
IntegerAttr::get(i32_ty, llvm::APInt(32, numWarps.getValue())));
|
||||
|
||||
// update layouts
|
||||
// broadcast src => multicast, dst => broadcasted
|
||||
// if (failed(target.refineLayouts(mod, numWarps)))
|
||||
@@ -408,3 +418,8 @@ std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::triton::createConvertTritonToTritonGPUPass(int numWarps) {
|
||||
return std::make_unique<::ConvertTritonToTritonGPU>(numWarps);
|
||||
}
|
||||
|
||||
// Factory overload taking no arguments: builds the Triton -> TritonGPU
// conversion pass through its default constructor.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::triton::createConvertTritonToTritonGPUPass() {
  using PassT = ::ConvertTritonToTritonGPU;
  return std::make_unique<PassT>();
}
|
||||
|
Reference in New Issue
Block a user