[OPTIMIZER] Rewrite patterns for layout conversions (#64)
This commit is contained in:
@@ -32,8 +32,10 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
unsigned alignment = std::min(maxMultiple, maxContig);
|
||||
unsigned perThread = std::min(alignment, 128 / numBits);
|
||||
sizePerThread[order[0]] = perThread;
|
||||
SmallVector<unsigned> dims(rank);
|
||||
std::iota(dims.begin(), dims.end(), 0);
|
||||
// create encoding
|
||||
Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
|
||||
Attribute encoding = triton::gpu::BlockedEncodingAttr::get(
|
||||
&getContext(), origType.getShape(), sizePerThread, order,
|
||||
this->numWarps);
|
||||
return encoding;
|
||||
@@ -64,15 +66,20 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
op->getLoc(), convertType(v.getType()), v));
|
||||
// convert output types
|
||||
SmallVector<Type, 4> newTypes;
|
||||
for (auto t : op->getResultTypes())
|
||||
newTypes.push_back(convertType(t));
|
||||
for (auto t : op->getResultTypes()) {
|
||||
bool is_async = std::is_same<T, triton::gpu::CopyAsyncOp>::value;
|
||||
newTypes.push_back(is_async ? t : convertType(t));
|
||||
}
|
||||
// construct new op with the new encoding
|
||||
Operation *newOp =
|
||||
builder.create<T>(op->getLoc(), newTypes, newArgs, op->getAttrs());
|
||||
// cast the results back to the original layout
|
||||
for (size_t i = 0; i < op->getNumResults(); i++) {
|
||||
auto newResult = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
op->getLoc(), op->getResult(i).getType(), newOp->getResult(i));
|
||||
Value newResult = newOp->getResult(i);
|
||||
if (newTypes[i] != op->getResultTypes()[i]) {
|
||||
newResult = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
op->getLoc(), op->getResult(i).getType(), newResult);
|
||||
}
|
||||
op->getResult(i).replaceAllUsesWith(newResult);
|
||||
}
|
||||
op->erase();
|
||||
@@ -97,6 +104,9 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
builder.setInsertionPoint(curr);
|
||||
if (auto load = dyn_cast<triton::LoadOp>(curr))
|
||||
coalesceOp<triton::LoadOp>(axisInfo, curr, load.ptr(), builder);
|
||||
if (auto load = dyn_cast<triton::gpu::CopyAsyncOp>(curr))
|
||||
coalesceOp<triton::gpu::CopyAsyncOp>(axisInfo, curr, load.ptr(),
|
||||
builder);
|
||||
if (auto store = dyn_cast<triton::StoreOp>(curr))
|
||||
coalesceOp<triton::StoreOp>(axisInfo, curr, store.ptr(), builder);
|
||||
});
|
||||
|
Reference in New Issue
Block a user