[Triton-MLIR] fix a tiny bug in coalesce pass (#782)
This commit is contained in:
@@ -63,9 +63,13 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
auto convertType = getTypeConverter(axisInfo, ptr, numWarps);
|
||||
// convert operands
|
||||
SmallVector<Value, 4> newArgs;
|
||||
for (auto v : op->getOperands())
|
||||
newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
op->getLoc(), convertType(v.getType()), v));
|
||||
for (auto v : op->getOperands()) {
|
||||
if (v.getType().isa<RankedTensorType>())
|
||||
newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
op->getLoc(), convertType(v.getType()), v));
|
||||
else
|
||||
newArgs.push_back(v);
|
||||
}
|
||||
// convert output types
|
||||
SmallVector<Type, 4> newTypes;
|
||||
for (auto t : op->getResultTypes()) {
|
||||
|
Reference in New Issue
Block a user