[OPTIMIZER] Rewrite patterns for layout conversions (#64)

This commit is contained in:
Philippe Tillet
2022-08-18 12:49:37 -07:00
committed by GitHub
parent e0bedeb44c
commit 192be76b3c
19 changed files with 851 additions and 127 deletions

View File

@@ -32,8 +32,10 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
unsigned alignment = std::min(maxMultiple, maxContig);
unsigned perThread = std::min(alignment, 128 / numBits);
sizePerThread[order[0]] = perThread;
SmallVector<unsigned> dims(rank);
std::iota(dims.begin(), dims.end(), 0);
// create encoding
Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
Attribute encoding = triton::gpu::BlockedEncodingAttr::get(
&getContext(), origType.getShape(), sizePerThread, order,
this->numWarps);
return encoding;
@@ -64,15 +66,20 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
op->getLoc(), convertType(v.getType()), v));
// convert output types
SmallVector<Type, 4> newTypes;
for (auto t : op->getResultTypes())
newTypes.push_back(convertType(t));
for (auto t : op->getResultTypes()) {
bool is_async = std::is_same<T, triton::gpu::CopyAsyncOp>::value;
newTypes.push_back(is_async ? t : convertType(t));
}
// construct new op with the new encoding
Operation *newOp =
builder.create<T>(op->getLoc(), newTypes, newArgs, op->getAttrs());
// cast the results back to the original layout
for (size_t i = 0; i < op->getNumResults(); i++) {
auto newResult = builder.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), op->getResult(i).getType(), newOp->getResult(i));
Value newResult = newOp->getResult(i);
if (newTypes[i] != op->getResultTypes()[i]) {
newResult = builder.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), op->getResult(i).getType(), newResult);
}
op->getResult(i).replaceAllUsesWith(newResult);
}
op->erase();
@@ -97,6 +104,9 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
builder.setInsertionPoint(curr);
if (auto load = dyn_cast<triton::LoadOp>(curr))
coalesceOp<triton::LoadOp>(axisInfo, curr, load.ptr(), builder);
if (auto load = dyn_cast<triton::gpu::CopyAsyncOp>(curr))
coalesceOp<triton::gpu::CopyAsyncOp>(axisInfo, curr, load.ptr(),
builder);
if (auto store = dyn_cast<triton::StoreOp>(curr))
coalesceOp<triton::StoreOp>(axisInfo, curr, store.ptr(), builder);
});