[OPTIMIZER] Coalesce pass no longer takes a num-warps argument (#99)

Improves the design to avoid an inconsistent `num-warps` value between the pass and the parent module of the operation it processes.
This commit is contained in:
Philippe Tillet
2022-09-05 18:09:02 -07:00
committed by GitHub
parent ea175f689e
commit a0bab9748e
5 changed files with 26 additions and 31 deletions

View File

@@ -795,24 +795,6 @@ private:
AxisInfoAnalysis &AxisAnalysisPass;
};
// Read the number of warps from the `AttrNumWarpsName` attribute attached to
// a TritonGPU module.
//
// This function does not return on failure: if the attribute is missing, or is
// present but is not an IntegerAttr, it aborts through
// llvm::report_fatal_error.
int extractNumWarps(mlir::ModuleOp module) {
  if (!module->hasAttr(AttrNumWarpsName))
    llvm::report_fatal_error(
        "TritonGPU module should contain a triton_gpu.num-warps attribute");

  // dyn_cast yields a null attribute when the stored attribute has another
  // type; check it before dereferencing instead of crashing in getValue().
  auto numWarpsAttr =
      module->getAttr(AttrNumWarpsName).dyn_cast<IntegerAttr>();
  if (!numWarpsAttr)
    llvm::report_fatal_error(
        "triton_gpu.num-warps attribute should be an integer attribute");

  // getZExtValue() returns a 64-bit unsigned value; make the narrowing to the
  // declared return type explicit.
  return static_cast<int>(numWarpsAttr.getValue().getZExtValue());
}
struct BroadcastOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::BroadcastOp> {
using ConvertTritonGPUOpToLLVMPattern<
@@ -1319,7 +1301,7 @@ public:
RewritePatternSet patterns(context);
int numWarps = extractNumWarps(mod);
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
auto axisAnalysis = runAxisAnalysis(mod);

View File

@@ -11,7 +11,8 @@ using namespace mlir::triton;
struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
Attribute getCoalescedEncoding(AxisInfoAnalysis &axisInfo, Value ptr) {
Attribute getCoalescedEncoding(AxisInfoAnalysis &axisInfo, Value ptr,
int numWarps) {
auto origType = ptr.getType().cast<RankedTensorType>();
// Get the shape of the tensor.
size_t rank = origType.getRank();
@@ -36,14 +37,13 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
std::iota(dims.begin(), dims.end(), 0);
// create encoding
Attribute encoding = triton::gpu::BlockedEncodingAttr::get(
&getContext(), origType.getShape(), sizePerThread, order,
this->numWarps);
&getContext(), origType.getShape(), sizePerThread, order, numWarps);
return encoding;
}
std::function<Type(Type)> getTypeConverter(AxisInfoAnalysis &axisInfo,
Value ptr) {
Attribute encoding = getCoalescedEncoding(axisInfo, ptr);
Value ptr, int numWarps) {
Attribute encoding = getCoalescedEncoding(axisInfo, ptr, numWarps);
return [encoding](Type _type) {
RankedTensorType type = _type.cast<RankedTensorType>();
return RankedTensorType::get(type.getShape(), type.getElementType(),
@@ -57,8 +57,11 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
RankedTensorType ty = ptr.getType().template dyn_cast<RankedTensorType>();
if (!ty)
return;
auto mod = op->getParentOfType<ModuleOp>();
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue();
auto convertType = getTypeConverter(axisInfo, ptr);
auto convertType = getTypeConverter(axisInfo, ptr, numWarps);
// convert operands
SmallVector<Value, 4> newArgs;
for (auto v : op->getOperands())