[TritonIR] Make mask operand optional (#74)
This commit is contained in:
@@ -49,58 +49,43 @@ namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
//-- StoreOp --
|
||||
// Default mask
|
||||
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
::mlir::Value ptr, ::mlir::Value value) {
|
||||
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
|
||||
auto shape = ptrType.getShape();
|
||||
::mlir::Value mask = builder.create<arith::ConstantOp>(
|
||||
ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
|
||||
DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(shape, builder.getI1Type()), true));
|
||||
state.addOperands(ptr);
|
||||
state.addOperands(value);
|
||||
state.addOperands(mask);
|
||||
StoreOp::build(builder, state, ptr, value, mlir::Value());
|
||||
}
|
||||
|
||||
//-- LoadOp --
|
||||
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
::mlir::Value ptr, ::mlir::triton::CacheModifier cache,
|
||||
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
|
||||
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
|
||||
Type elementType =
|
||||
ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
|
||||
auto shape = ptrType.getShape();
|
||||
// mask
|
||||
::mlir::Value mask = builder.create<arith::ConstantOp>(
|
||||
ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
|
||||
DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(shape, builder.getI1Type()), true));
|
||||
Type resultType = RankedTensorType::get(shape, elementType);
|
||||
state.addOperands(ptr);
|
||||
state.addOperands(mask);
|
||||
state.addAttribute(
|
||||
cacheAttrName(state.name),
|
||||
::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
|
||||
state.addAttribute(
|
||||
evictAttrName(state.name),
|
||||
::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
|
||||
state.addAttribute(isVolatileAttrName(state.name),
|
||||
builder.getBoolAttr(isVolatile));
|
||||
state.addTypes({resultType});
|
||||
LoadOp::build(builder, state, ptr, mlir::Value(), mlir::Value(), cache, evict,
|
||||
isVolatile);
|
||||
}
|
||||
|
||||
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
::mlir::Value ptr, ::mlir::Value mask,
|
||||
::mlir::triton::CacheModifier cache,
|
||||
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
|
||||
LoadOp::build(builder, state, ptr, mask, mlir::Value(), cache, evict,
|
||||
isVolatile);
|
||||
}
|
||||
|
||||
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
::mlir::Value ptr, ::mlir::Value mask, ::mlir::Value other,
|
||||
::mlir::triton::CacheModifier cache,
|
||||
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
|
||||
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
|
||||
Type elementType =
|
||||
ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
|
||||
auto shape = ptrType.getShape();
|
||||
Type resultType = RankedTensorType::get(shape, elementType);
|
||||
state.addOperands(ptr);
|
||||
state.addOperands(mask);
|
||||
if (mask) {
|
||||
state.addOperands(mask);
|
||||
if (other) {
|
||||
state.addOperands(other);
|
||||
}
|
||||
}
|
||||
state.addAttribute(
|
||||
cacheAttrName(state.name),
|
||||
::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
|
||||
|
@@ -52,6 +52,46 @@ DenseElementsAttr getConstantValue(Builder &builder, Attribute value,
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// select(cond, load(ptrs, broadcast(cond), ???), other)
|
||||
// => load(ptrs, broadcast(cond), other)
|
||||
class CombineSelectMaskedLoadPattern : public mlir::RewritePattern {
|
||||
public:
|
||||
CombineSelectMaskedLoadPattern(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(mlir::SelectOp::getOperationName(), 3, context,
|
||||
{triton::LoadOp::getOperationName()}) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto selectOp = llvm::dyn_cast<mlir::SelectOp>(op);
|
||||
if (!selectOp)
|
||||
return mlir::failure();
|
||||
|
||||
mlir::Value trueValue = selectOp.getTrueValue();
|
||||
mlir::Value falseValue = selectOp.getFalseValue();
|
||||
|
||||
auto *loadOpCandidate = trueValue.getDefiningOp();
|
||||
auto loadOp = llvm::dyn_cast<triton::LoadOp>(loadOpCandidate);
|
||||
if (!loadOp)
|
||||
return mlir::failure();
|
||||
|
||||
mlir::Value mask = loadOp.mask();
|
||||
if (!mask)
|
||||
return mlir::failure();
|
||||
|
||||
auto *broadcastOpCandidate = mask.getDefiningOp();
|
||||
auto broadcastOp =
|
||||
llvm::dyn_cast<triton::BroadcastOp>(broadcastOpCandidate);
|
||||
if (!broadcastOp)
|
||||
return mlir::failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<triton::LoadOp>(
|
||||
op, loadOp.ptr(), loadOp.mask(), falseValue, loadOp.cache(),
|
||||
loadOp.evict(), loadOp.isVolatile());
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
|
||||
|
||||
|
@@ -37,12 +37,6 @@ def CombineGEPPattern : Pat<
|
||||
(TT_GEPOp (TT_GEPOp $ptr, $idx0), $idx1),
|
||||
(TT_GEPOp $ptr, (Arith_AddIOp $idx0, $idx1))>;
|
||||
|
||||
// select(cond, load(ptrs, broadcast(cond), ???), other)
|
||||
// => load(ptrs, broadcast(cond), other)
|
||||
def CombineSelectMaskedLoadPattern : Pat<
|
||||
(SelectOp $cond, (TT_LoadOp $ptrs, (TT_BroadcastOp:$bcast_res $cond), $other, $cache, $evict, $isVolatile), $falseValue),
|
||||
(TT_LoadOp $ptrs, $bcast_res, $falseValue, $cache, $evict, $isVolatile)>;
|
||||
|
||||
// broadcast(cst) => cst
|
||||
def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">;
|
||||
def CombineBroadcastConstantPattern : Pat<
|
||||
|
Reference in New Issue
Block a user