[OPTIMIZER] Coalesce pass no longer takes a num-warps
argument (#99)
Improved design to avoid inconsistent `num-warps` value between the pass and the parent module of the operation it processes.
This commit is contained in:
@@ -795,24 +795,6 @@ private:
|
||||
AxisInfoAnalysis &AxisAnalysisPass;
|
||||
};
|
||||
|
||||
// Extract numWarps information from TritonGPU module, return 0 if failed.
|
||||
// This is a naive implementation, it assumes that all the blocked layout should
|
||||
// have the same numWarps setting in a module, it just find a blocked layout
|
||||
// encoding and return the warpsPerCTA field.
|
||||
int extractNumWarps(mlir::ModuleOp module) {
|
||||
int numWarps{};
|
||||
if (module->hasAttr(AttrNumWarpsName))
|
||||
numWarps = module->getAttr(AttrNumWarpsName)
|
||||
.dyn_cast<IntegerAttr>()
|
||||
.getValue()
|
||||
.getZExtValue();
|
||||
else
|
||||
llvm::report_fatal_error(
|
||||
"TritonGPU module should contain a triton_gpu.num-warps attribute");
|
||||
|
||||
return numWarps;
|
||||
}
|
||||
|
||||
struct BroadcastOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::BroadcastOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
@@ -1319,7 +1301,7 @@ public:
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
|
||||
int numWarps = extractNumWarps(mod);
|
||||
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
||||
|
||||
auto axisAnalysis = runAxisAnalysis(mod);
|
||||
|
||||
|
Reference in New Issue
Block a user