[Triton-MLIR] fix a tiny bug in coalesce pass (#782)

This commit is contained in:
goostavz
2022-10-17 11:29:55 +08:00
committed by GitHub
parent 5898352f97
commit e948a618b3

View File

@@ -63,9 +63,13 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
auto convertType = getTypeConverter(axisInfo, ptr, numWarps); auto convertType = getTypeConverter(axisInfo, ptr, numWarps);
// convert operands // convert operands
SmallVector<Value, 4> newArgs; SmallVector<Value, 4> newArgs;
for (auto v : op->getOperands()) for (auto v : op->getOperands()) {
newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>( if (v.getType().isa<RankedTensorType>())
op->getLoc(), convertType(v.getType()), v)); newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), convertType(v.getType()), v));
else
newArgs.push_back(v);
}
// convert output types // convert output types
SmallVector<Type, 4> newTypes; SmallVector<Type, 4> newTypes;
for (auto t : op->getResultTypes()) { for (auto t : op->getResultTypes()) {