[OPTIMIZER] Coalesce pass no longer takes a num-warps
argument (#99)
Improved design to avoid inconsistent `num-warps` value between the pass and the parent module of the operation it processes.
This commit is contained in:
@@ -11,7 +11,8 @@ using namespace mlir::triton;
|
||||
|
||||
struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
|
||||
Attribute getCoalescedEncoding(AxisInfoAnalysis &axisInfo, Value ptr) {
|
||||
Attribute getCoalescedEncoding(AxisInfoAnalysis &axisInfo, Value ptr,
|
||||
int numWarps) {
|
||||
auto origType = ptr.getType().cast<RankedTensorType>();
|
||||
// Get the shape of the tensor.
|
||||
size_t rank = origType.getRank();
|
||||
@@ -36,14 +37,13 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
std::iota(dims.begin(), dims.end(), 0);
|
||||
// create encoding
|
||||
Attribute encoding = triton::gpu::BlockedEncodingAttr::get(
|
||||
&getContext(), origType.getShape(), sizePerThread, order,
|
||||
this->numWarps);
|
||||
&getContext(), origType.getShape(), sizePerThread, order, numWarps);
|
||||
return encoding;
|
||||
}
|
||||
|
||||
std::function<Type(Type)> getTypeConverter(AxisInfoAnalysis &axisInfo,
|
||||
Value ptr) {
|
||||
Attribute encoding = getCoalescedEncoding(axisInfo, ptr);
|
||||
Value ptr, int numWarps) {
|
||||
Attribute encoding = getCoalescedEncoding(axisInfo, ptr, numWarps);
|
||||
return [encoding](Type _type) {
|
||||
RankedTensorType type = _type.cast<RankedTensorType>();
|
||||
return RankedTensorType::get(type.getShape(), type.getElementType(),
|
||||
@@ -57,8 +57,11 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
RankedTensorType ty = ptr.getType().template dyn_cast<RankedTensorType>();
|
||||
if (!ty)
|
||||
return;
|
||||
auto mod = op->getParentOfType<ModuleOp>();
|
||||
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
||||
|
||||
AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue();
|
||||
auto convertType = getTypeConverter(axisInfo, ptr);
|
||||
auto convertType = getTypeConverter(axisInfo, ptr, numWarps);
|
||||
// convert operands
|
||||
SmallVector<Value, 4> newArgs;
|
||||
for (auto v : op->getOperands())
|
||||
|
Reference in New Issue
Block a user