[TritonIR] Make mask operand optional (#74)
This commit is contained in:
@@ -391,14 +391,16 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
if (loads.contains(op->getResult(0))) {
|
||||
auto loadOp = llvm::cast<triton::LoadOp>(op);
|
||||
Value mask = loadOp.mask();
|
||||
Value splatCond = builder.create<triton::SplatOp>(
|
||||
mask.getLoc(), mask.getType(), nextLoopCond);
|
||||
Value newMask = builder.create<arith::AndIOp>(
|
||||
mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
|
||||
// if mask is defined outside the loop, don't update the map more than
|
||||
// once
|
||||
if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
|
||||
nextMapping.map(mask, newMask);
|
||||
if (mask) {
|
||||
Value splatCond = builder.create<triton::SplatOp>(
|
||||
mask.getLoc(), mask.getType(), nextLoopCond);
|
||||
Value newMask = builder.create<arith::AndIOp>(
|
||||
mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
|
||||
// if mask is defined outside the loop, don't update the map more than
|
||||
// once
|
||||
if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
|
||||
nextMapping.map(mask, newMask);
|
||||
}
|
||||
// TODO: more elegant way to do this?
|
||||
nextOp = builder.create<triton::gpu::CopyAsyncOp>(
|
||||
op->getLoc(), loadsMapping[op->getResult(0)].getType(),
|
||||
|
Reference in New Issue
Block a user