[TritonIR] make other optional and remove isOtherUnspecified (#67)

[Triton] make other optional and remove isOtherUnspecified
This commit is contained in:
Shintaro Iwasaki
2022-08-18 18:19:55 -07:00
committed by GitHub
parent 192be76b3c
commit 9aa00249a6
14 changed files with 69 additions and 58 deletions

View File

@@ -76,14 +76,9 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
DenseIntElementsAttr::get(
RankedTensorType::get(shape, builder.getI1Type()), true));
// other
Type resultType = RankedTensorType::get(shape, elementType);
::mlir::Value other = builder.create<arith::ConstantOp>(
ptr.getLoc(), resultType,
DenseElementsAttr::get(resultType, builder.getZeroAttr(elementType)));
state.addOperands(ptr);
state.addOperands(mask);
state.addOperands(other);
state.addAttribute(
cacheAttrName(state.name),
::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
@@ -92,8 +87,28 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
state.addAttribute(isVolatileAttrName(state.name),
builder.getBoolAttr(isVolatile));
state.addAttribute(isOtherUnspecifiedAttrName(state.name),
builder.getBoolAttr(false));
state.addTypes({resultType});
}
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::Value mask,
::mlir::triton::CacheModifier cache,
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
Type elementType =
ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
auto shape = ptrType.getShape();
Type resultType = RankedTensorType::get(shape, elementType);
state.addOperands(ptr);
state.addOperands(mask);
state.addAttribute(
cacheAttrName(state.name),
::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
state.addAttribute(
evictAttrName(state.name),
::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
state.addAttribute(isVolatileAttrName(state.name),
builder.getBoolAttr(isVolatile));
state.addTypes({resultType});
}

View File

@@ -40,8 +40,8 @@ def CombineGEPPattern : Pat<
// select(cond, load(ptrs, broadcast(cond), ???), other)
// => load(ptrs, broadcast(cond), other)
def CombineSelectMaskedLoadPattern : Pat<
(SelectOp $cond, (TT_LoadOp $ptrs, (TT_BroadcastOp:$bcast_res $cond), $other, $cache, $evict, $isVolatile, $isOtherUnspecified), $falseValue),
(TT_LoadOp $ptrs, $bcast_res, $falseValue, $cache, $evict, $isVolatile, $isOtherUnspecified)>;
(SelectOp $cond, (TT_LoadOp $ptrs, (TT_BroadcastOp:$bcast_res $cond), $other, $cache, $evict, $isVolatile), $falseValue),
(TT_LoadOp $ptrs, $bcast_res, $falseValue, $cache, $evict, $isVolatile)>;
// broadcast(cst) => cst
def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">;

View File

@@ -240,7 +240,7 @@ void LoopPipeliner::emitPrologue() {
newOp = builder.create<triton::gpu::CopyAsyncOp>(
op->getLoc(), loadsMapping[loadOp].getType(), loadOp.ptr(),
loadOp.mask(), loadOp.other(), loadOp.cache(), loadOp.evict(),
loadOp.isVolatile(), loadOp.isOtherUnspecified());
loadOp.isVolatile());
} else
llvm_unreachable("This should be LoadOp");
} else
@@ -404,7 +404,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
nextMapping.lookupOrDefault(loadOp.ptr()),
nextMapping.lookupOrDefault(loadOp.mask()),
nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
loadOp.evict(), loadOp.isVolatile(), loadOp.isOtherUnspecified());
loadOp.evict(), loadOp.isVolatile());
} else
nextOp = builder.clone(*op, nextMapping);
// llvm::errs() << "epilogue cloning...: " << *op << "\n";