[OPTIMIZER] Rewrite patterns for layout conversions (#64)
This commit is contained in:
@@ -1,8 +1,7 @@
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements loop software pipelining
|
||||
@@ -168,8 +167,7 @@ LogicalResult LoopPipeliner::initialize() {
|
||||
if (auto tensorType = convertLayout.getResult()
|
||||
.getType()
|
||||
.dyn_cast<RankedTensorType>()) {
|
||||
if (tensorType.getEncoding()
|
||||
.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
|
||||
if (tensorType.getEncoding().isa<triton::gpu::SharedEncodingAttr>()) {
|
||||
isCandiate = true;
|
||||
loadsMapping[loadOp] = convertLayout;
|
||||
}
|
||||
@@ -263,7 +261,7 @@ void LoopPipeliner::emitPrologue() {
|
||||
// assert(I1 or TensorOf<[I1]>);
|
||||
OpBuilder::InsertionGuard g(builder);
|
||||
builder.setInsertionPoint(newOp);
|
||||
Value splatCond = builder.create<triton::BroadcastOp>(
|
||||
Value splatCond = builder.create<triton::SplatOp>(
|
||||
mask.getLoc(), mask.getType(), loopCond);
|
||||
Value newMask =
|
||||
builder.create<arith::AndIOp>(mask.getLoc(), mask, splatCond);
|
||||
@@ -356,6 +354,9 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
Value loadUse = load.getUsers().begin()->getResult(0);
|
||||
mapping.lookup(loadUse).replaceAllUsesWith(
|
||||
newForOp.getRegionIterArgs()[loadIdx + idx * (numStages - 1)]);
|
||||
// delete old load and layout conversion
|
||||
mapping.lookup(loadUse).getDefiningOp()->erase();
|
||||
mapping.lookup(load).getDefiningOp()->erase();
|
||||
}
|
||||
|
||||
// 4. prefetch the next iteration
|
||||
@@ -389,7 +390,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
if (loads.contains(op->getResult(0))) {
|
||||
auto loadOp = llvm::cast<triton::LoadOp>(op);
|
||||
Value mask = loadOp.mask();
|
||||
Value splatCond = builder.create<triton::BroadcastOp>(
|
||||
Value splatCond = builder.create<triton::SplatOp>(
|
||||
mask.getLoc(), mask.getType(), nextLoopCond);
|
||||
Value newMask = builder.create<arith::AndIOp>(
|
||||
mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
|
||||
@@ -442,9 +443,10 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
yieldValues.push_back(
|
||||
depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
|
||||
yieldValues.push_back(nextIV);
|
||||
|
||||
builder.setInsertionPointToEnd(newForOp.getBody());
|
||||
builder.create<scf::YieldOp>(forOp.getBody()->getTerminator()->getLoc(),
|
||||
yieldValues);
|
||||
auto test = builder.create<scf::YieldOp>(
|
||||
forOp.getBody()->getTerminator()->getLoc(), yieldValues);
|
||||
return newForOp;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user