[TritonIR] Make mask operand optional (#74)

This commit is contained in:
Shintaro Iwasaki
2022-08-22 22:00:17 -07:00
committed by GitHub
parent de2dd04c8a
commit 0ebef11c77
14 changed files with 113 additions and 102 deletions

View File

@@ -391,14 +391,16 @@ scf::ForOp LoopPipeliner::createNewForOp() {
if (loads.contains(op->getResult(0))) {
auto loadOp = llvm::cast<triton::LoadOp>(op);
Value mask = loadOp.mask();
Value splatCond = builder.create<triton::SplatOp>(
mask.getLoc(), mask.getType(), nextLoopCond);
Value newMask = builder.create<arith::AndIOp>(
mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
// if mask is defined outside the loop, don't update the map more than
// once
if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
nextMapping.map(mask, newMask);
if (mask) {
Value splatCond = builder.create<triton::SplatOp>(
mask.getLoc(), mask.getType(), nextLoopCond);
Value newMask = builder.create<arith::AndIOp>(
mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
// if mask is defined outside the loop, don't update the map more than
// once
if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
nextMapping.map(mask, newMask);
}
// TODO: more elegant way to do this?
nextOp = builder.create<triton::gpu::CopyAsyncOp>(
op->getLoc(), loadsMapping[op->getResult(0)].getType(),