[OPTIMIZER] Fix the load-mask issue with the pipeline pass (#857)
This commit is contained in:
@@ -15,6 +15,14 @@ using namespace mlir;
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
/// Returns a ranked tensor type that mirrors the type of `v` but with a
/// 1-bit integer element type: the shape and the encoding attribute are taken
/// from `v`'s type unchanged.
///
/// `v` is expected to be typed as a RankedTensorType; the type is converted
/// with an unconditional `cast`, so no fallback exists for other types.
static Type getI1SameShape(Value v) {
  Type sourceTy = v.getType();
  // 1-bit integer element type, created in the same context as the source.
  auto boolTy = IntegerType::get(sourceTy.getContext(), 1);
  auto rankedTy = sourceTy.cast<RankedTensorType>();
  auto shape = rankedTy.getShape();
  auto encoding = rankedTy.getEncoding();
  return RankedTensorType::get(shape, boolTy, encoding);
}
|
||||
|
||||
namespace {
|
||||
class LoopPipeliner {
|
||||
/// cache forOp we are working on
|
||||
@@ -262,13 +270,23 @@ void LoopPipeliner::emitPrologue() {
|
||||
loadStageBuffer[op->getResult(0)] = {loadsBuffer[op->getResult(0)]};
|
||||
}
|
||||
// load => copy async
|
||||
// TODO: check if the hardware supports async copy
|
||||
if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
|
||||
Value mask = lookupOrDefault(loadOp.mask(), stage);
|
||||
Value newMask;
|
||||
if (mask) {
|
||||
Value splatCond = builder.create<triton::SplatOp>(
|
||||
mask.getLoc(), mask.getType(), loopCond);
|
||||
newMask =
|
||||
builder.create<arith::AndIOp>(mask.getLoc(), mask, splatCond);
|
||||
} else {
|
||||
newMask = builder.create<triton::SplatOp>(
|
||||
loopCond.getLoc(), getI1SameShape(loadOp), loopCond);
|
||||
}
|
||||
// TODO: check if the hardware supports async copy
|
||||
newOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
|
||||
op->getLoc(), loadsBuffer[loadOp].getType(),
|
||||
lookupOrDefault(loadOp.ptr(), stage),
|
||||
loadStageBuffer[loadOp][stage], pipelineIterIdx,
|
||||
lookupOrDefault(loadOp.mask(), stage),
|
||||
loadStageBuffer[loadOp][stage], pipelineIterIdx, newMask,
|
||||
lookupOrDefault(loadOp.other(), stage), loadOp.cache(),
|
||||
loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
|
||||
loadStageBuffer[loadOp].push_back(newOp->getResult(0));
|
||||
@@ -287,33 +305,6 @@ void LoopPipeliner::emitPrologue() {
|
||||
}
|
||||
}
|
||||
|
||||
// If this is a load/async_copy, we need to update the mask
|
||||
if (Value mask = [&]() {
|
||||
if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(newOp)) {
|
||||
return loadOp.mask();
|
||||
} else if (auto insertSliceAsyncOp =
|
||||
llvm::dyn_cast<triton::gpu::InsertSliceAsyncOp>(
|
||||
newOp)) {
|
||||
return insertSliceAsyncOp.mask();
|
||||
} else {
|
||||
return mlir::Value();
|
||||
}
|
||||
}()) {
|
||||
// assert(I1 or TensorOf<[I1]>);
|
||||
OpBuilder::InsertionGuard g(builder);
|
||||
// TODO: move this out of the loop
|
||||
builder.setInsertionPoint(newOp);
|
||||
Value splatCond = builder.create<triton::SplatOp>(
|
||||
mask.getLoc(), mask.getType(), loopCond);
|
||||
Value newMask =
|
||||
builder.create<arith::AndIOp>(mask.getLoc(), mask, splatCond);
|
||||
// TODO: better way to do this?
|
||||
if (llvm::isa<triton::LoadOp>(newOp))
|
||||
newOp->setOperand(1, newMask);
|
||||
else // InsertSliceAsyncOp
|
||||
newOp->setOperand(3, newMask);
|
||||
}
|
||||
|
||||
// update mapping of results
|
||||
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
|
||||
Value originalResult = op->getResult(dstIdx);
|
||||
@@ -332,7 +323,7 @@ void LoopPipeliner::emitPrologue() {
|
||||
newOp->getResult(dstIdx), stage + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // for (Operation *op : orderedDeps)
|
||||
|
||||
pipelineIterIdx = builder.create<arith::AddIOp>(
|
||||
iv.getLoc(), pipelineIterIdx,
|
||||
@@ -490,26 +481,29 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
|
||||
for (Operation *op : orderedDeps) {
|
||||
Operation *nextOp = nullptr;
|
||||
// TODO(da): does this work if loadOp has no mask?
|
||||
// update loading mask
|
||||
if (loads.contains(op->getResult(0))) {
|
||||
auto loadOp = llvm::cast<triton::LoadOp>(op);
|
||||
Value mask = loadOp.mask();
|
||||
Value newMask;
|
||||
if (mask) {
|
||||
Value splatCond = builder.create<triton::SplatOp>(
|
||||
mask.getLoc(), mask.getType(), nextLoopCond);
|
||||
Value newMask = builder.create<arith::AndIOp>(
|
||||
newMask = builder.create<arith::AndIOp>(
|
||||
mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
|
||||
// if mask is defined outside the loop, don't update the map more than
|
||||
// once
|
||||
if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
|
||||
nextMapping.map(mask, newMask);
|
||||
}
|
||||
newMask = nextMapping.lookupOrDefault(loadOp.mask());
|
||||
} else
|
||||
newMask = builder.create<triton::SplatOp>(
|
||||
loadOp.getLoc(), getI1SameShape(loadOp), nextLoopCond);
|
||||
Value insertAsyncOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
|
||||
op->getLoc(), loadsBuffer[loadOp].getType(),
|
||||
nextMapping.lookupOrDefault(loadOp.ptr()),
|
||||
newForOp.getRegionIterArgs()[bufferIdx + nextBuffers.size()],
|
||||
insertSliceIndex, nextMapping.lookupOrDefault(loadOp.mask()),
|
||||
insertSliceIndex, newMask,
|
||||
nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
|
||||
loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
|
||||
nextBuffers.push_back(insertAsyncOp);
|
||||
|
Reference in New Issue
Block a user