[TritonIR] simplify Load/StoreOps when mask is true/false (#79)
* [TritonIR] fix Load/Store/CopyAsyncOp's parsers * [TritonIR] simplify Load/StoreOps when mask is true/false * [TEST] adds tests to check load/store simplification
This commit is contained in:
@@ -94,7 +94,12 @@ def TT_LoadOp : TT_Op<"load",
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
|
||||
];
|
||||
|
||||
let assemblyFormat = "operands attr-dict `:` type($result)";
|
||||
// let assemblyFormat = "operands attr-dict `:` type($result)";
|
||||
let parser = [{ return mlir::triton::parseLoadOp(parser, result); }];
|
||||
|
||||
let printer = [{ return mlir::triton::printLoadOp(p, *this); }];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TT_StoreOp : TT_Op<"store",
|
||||
@@ -114,7 +119,12 @@ def TT_StoreOp : TT_Op<"store",
|
||||
OpBuilder<(ins "Value":$ptr, "Value":$value)>,
|
||||
];
|
||||
|
||||
let assemblyFormat = "operands attr-dict `:` type($value)";
|
||||
// let assemblyFormat = "operands attr-dict `:` type($value)";
|
||||
let parser = [{ return mlir::triton::parseStoreOp(parser, result); }];
|
||||
|
||||
let printer = [{ return mlir::triton::printStoreOp(p, *this); }];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TT_GEPOp : TT_Op<"getelementptr",
|
||||
|
@@ -60,7 +60,10 @@ def TTG_CopyAsyncOp : TTG_Op<"copy_async",
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
|
||||
let assemblyFormat = "operands attr-dict `:` type($ptr) `->` type($result)";
|
||||
// let assemblyFormat = "operands attr-dict `:` type($ptr) `->` type($result)";
|
||||
let parser = [{ return parseCopyAsyncOp(parser, result); }];
|
||||
|
||||
let printer = [{ return printCopyAsyncOp(p, *this); }];
|
||||
|
||||
// result needs to be of shared layout
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
|
@@ -36,6 +36,68 @@ static Type getPointerTypeFromTensor(Type type) {
|
||||
return Type();
|
||||
}
|
||||
|
||||
// Parser & printer for assembly forms
|
||||
ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
|
||||
SmallVector<OpAsmParser::OperandType, 4> allOperands;
|
||||
Type resultTypes[1];
|
||||
SMLoc allOperandLoc = parser.getCurrentLocation();
|
||||
if (parser.parseOperandList(allOperands) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
|
||||
parser.parseCustomTypeWithFallback(resultTypes[0]))
|
||||
return failure();
|
||||
|
||||
result.addTypes(resultTypes);
|
||||
|
||||
SmallVector<Type> operandTypes;
|
||||
operandTypes.push_back(getPointerTypeFromTensor(resultTypes[0])); // ptr
|
||||
if (allOperands.size() >= 2)
|
||||
operandTypes.push_back(getI1SameShape(resultTypes[0])); // mask
|
||||
if (allOperands.size() >= 3)
|
||||
operandTypes.push_back(resultTypes[0]); // other
|
||||
|
||||
if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
|
||||
result.operands))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
void printLoadOp(OpAsmPrinter &printer, LoadOp loadOp) {
|
||||
printer << " ";
|
||||
printer << loadOp.getOperation()->getOperands();
|
||||
printer.printOptionalAttrDict(loadOp->getAttrs(), /*elidedAttrs=*/{});
|
||||
printer << " : ";
|
||||
printer.printStrippedAttrOrType(loadOp.result().getType());
|
||||
}
|
||||
|
||||
ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
|
||||
SmallVector<OpAsmParser::OperandType, 4> allOperands;
|
||||
Type valueType;
|
||||
SMLoc allOperandLoc = parser.getCurrentLocation();
|
||||
if (parser.parseOperandList(allOperands) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
|
||||
parser.parseCustomTypeWithFallback(valueType))
|
||||
return failure();
|
||||
|
||||
SmallVector<Type> operandTypes;
|
||||
operandTypes.push_back(getPointerTypeFromTensor(valueType)); // ptr
|
||||
operandTypes.push_back(valueType); // value
|
||||
if (allOperands.size() >= 3)
|
||||
operandTypes.push_back(getI1SameShape(valueType)); // mask
|
||||
|
||||
if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
|
||||
result.operands))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
void printStoreOp(OpAsmPrinter &printer, StoreOp storeOp) {
|
||||
printer << " ";
|
||||
printer << storeOp.getOperation()->getOperands();
|
||||
printer.printOptionalAttrDict(storeOp->getAttrs(), /*elidedAttrs=*/{});
|
||||
printer << " : ";
|
||||
printer.printStrippedAttrOrType(storeOp.value().getType());
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
|
@@ -92,6 +92,91 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
// load(ptr, splat(1), ...) -> load(ptr, ...)
|
||||
// load(ptr, splat(0), other, ...) -> other
|
||||
struct CanonicalizeMaskedLoadPattern
|
||||
: public mlir::OpRewritePattern<triton::LoadOp> {
|
||||
CanonicalizeMaskedLoadPattern(mlir::MLIRContext *context)
|
||||
: OpRewritePattern<triton::LoadOp>(context, 1) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(triton::LoadOp loadOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto mask = loadOp.mask();
|
||||
if (!mask)
|
||||
return mlir::failure();
|
||||
|
||||
auto constantMask = llvm::dyn_cast<arith::ConstantOp>(mask.getDefiningOp());
|
||||
if (!constantMask)
|
||||
return mlir::failure();
|
||||
|
||||
auto splatMask = constantMask.getValue().dyn_cast<SplatElementsAttr>();
|
||||
if (!splatMask)
|
||||
return mlir::failure();
|
||||
|
||||
if (splatMask.getSplatValue<IntegerAttr>().getValue() == true) {
|
||||
// mask = splat(1)
|
||||
rewriter.replaceOpWithNewOp<triton::LoadOp>(
|
||||
loadOp, loadOp.getType(), loadOp.ptr(), Value(), Value(),
|
||||
loadOp.cache(), loadOp.evict(), loadOp.isVolatile());
|
||||
} else {
|
||||
// mask = splat(0)
|
||||
|
||||
// If there's no "other", the value is "undef". Perhaps we want to
|
||||
// optimize it in the future.x
|
||||
auto otherVal = loadOp.other();
|
||||
if (!otherVal)
|
||||
return mlir::failure();
|
||||
rewriter.replaceOp(loadOp, otherVal);
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
void triton::LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<CanonicalizeMaskedLoadPattern>(context);
|
||||
}
|
||||
|
||||
// store(ptr, value, splat(1), ...) -> store(ptr, value, ...)
|
||||
// store(ptr, value, splat(0), ...) -> [none]
|
||||
struct CanonicalizeMaskedStorePattern
|
||||
: public mlir::OpRewritePattern<triton::StoreOp> {
|
||||
CanonicalizeMaskedStorePattern(mlir::MLIRContext *context)
|
||||
: OpRewritePattern<triton::StoreOp>(context, 1) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(triton::StoreOp storeOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto mask = storeOp.mask();
|
||||
if (!mask)
|
||||
return mlir::failure();
|
||||
|
||||
auto constantMask = llvm::dyn_cast<arith::ConstantOp>(mask.getDefiningOp());
|
||||
if (!constantMask)
|
||||
return mlir::failure();
|
||||
|
||||
auto splatMask = constantMask.getValue().dyn_cast<SplatElementsAttr>();
|
||||
if (!splatMask)
|
||||
return mlir::failure();
|
||||
|
||||
if (splatMask.getSplatValue<IntegerAttr>().getValue() == true) {
|
||||
// mask = splat(1)
|
||||
rewriter.replaceOpWithNewOp<triton::StoreOp>(storeOp, storeOp.ptr(),
|
||||
storeOp.value());
|
||||
} else {
|
||||
// mask = splat(0)
|
||||
rewriter.eraseOp(storeOp);
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
void triton::StoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<CanonicalizeMaskedStorePattern>(context);
|
||||
}
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
|
||||
|
||||
|
@@ -10,6 +10,37 @@
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton::gpu;
|
||||
|
||||
// Utility
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
// Type inference
|
||||
static Type getI1SameShape(Type type) {
|
||||
auto i1Type = IntegerType::get(type.getContext(), 1);
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorType.getShape(), i1Type,
|
||||
tensorType.getEncoding());
|
||||
return Type();
|
||||
}
|
||||
|
||||
static Type getPointeeType(Type type) {
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
||||
// Tensor of pointers
|
||||
auto shape = tensorType.getShape();
|
||||
auto ptrType = tensorType.getElementType().dyn_cast<PointerType>();
|
||||
Type pointeeType = ptrType.getPointeeType();
|
||||
return RankedTensorType::get(shape, pointeeType, tensorType.getEncoding());
|
||||
} else if (auto ptrType = type.dyn_cast<PointerType>()) {
|
||||
// scalar pointer
|
||||
Type pointeeType = ptrType.getPointeeType();
|
||||
return pointeeType;
|
||||
}
|
||||
return Type();
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
static LogicalResult parseIntAttrValue(AsmParser &parser, const Attribute &attr,
|
||||
unsigned &value, StringRef desc) {
|
||||
auto intAttr = attr.dyn_cast<IntegerAttr>();
|
||||
@@ -260,6 +291,44 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const {
|
||||
<< "}>";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CopyAsyncOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ParseResult parseCopyAsyncOp(OpAsmParser &parser, OperationState &result) {
|
||||
SmallVector<OpAsmParser::OperandType, 4> allOperands;
|
||||
Type resultTypes[1], ptrType;
|
||||
SMLoc allOperandLoc = parser.getCurrentLocation();
|
||||
if (parser.parseOperandList(allOperands) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
|
||||
parser.parseCustomTypeWithFallback(ptrType) || parser.parseArrow() ||
|
||||
parser.parseCustomTypeWithFallback(resultTypes[0]))
|
||||
return failure();
|
||||
result.addTypes(resultTypes);
|
||||
|
||||
SmallVector<Type> operandTypes;
|
||||
operandTypes.push_back(ptrType); // ptr
|
||||
if (allOperands.size() >= 2)
|
||||
operandTypes.push_back(triton::getI1SameShape(ptrType)); // mask
|
||||
if (allOperands.size() >= 3)
|
||||
operandTypes.push_back(triton::getPointeeType(ptrType)); // other
|
||||
|
||||
if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
|
||||
result.operands))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
void printCopyAsyncOp(OpAsmPrinter &printer, CopyAsyncOp copyAsyncOp) {
|
||||
printer << " ";
|
||||
printer << copyAsyncOp.getOperation()->getOperands();
|
||||
printer.printOptionalAttrDict(copyAsyncOp->getAttrs(), /*elidedAttrs=*/{});
|
||||
printer << " : ";
|
||||
printer.printStrippedAttrOrType(copyAsyncOp.ptr().getType());
|
||||
printer << " -> ";
|
||||
printer.printStrippedAttrOrType(copyAsyncOp.result().getType());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ASM Interface (i.e.: alias)
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -282,8 +351,7 @@ public:
|
||||
os << "slice";
|
||||
return AliasResult::FinalAlias;
|
||||
} */
|
||||
OpAsmDialectInterface::getAlias(attr, os);
|
||||
return AliasResult::FinalAlias;
|
||||
return OpAsmDialectInterface::getAlias(attr, os);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -299,36 +367,6 @@ void TritonGPUDialect::initialize() {
|
||||
addInterfaces<TritonGPUOpAsmInterface>();
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
// Type inference
|
||||
static Type getI1SameShape(Type type) {
|
||||
auto i1Type = IntegerType::get(type.getContext(), 1);
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorType.getShape(), i1Type,
|
||||
tensorType.getEncoding());
|
||||
return Type();
|
||||
}
|
||||
|
||||
static Type getPointeeType(Type type) {
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
||||
// Tensor of pointers
|
||||
auto shape = tensorType.getShape();
|
||||
auto ptrType = tensorType.getElementType().dyn_cast<PointerType>();
|
||||
Type pointeeType = ptrType.getPointeeType();
|
||||
return RankedTensorType::get(shape, pointeeType, tensorType.getEncoding());
|
||||
} else if (auto ptrType = type.dyn_cast<PointerType>()) {
|
||||
// scalar pointer
|
||||
Type pointeeType = ptrType.getPointeeType();
|
||||
return pointeeType;
|
||||
}
|
||||
return Type();
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Verification
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -352,4 +390,4 @@ LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
|
||||
NamedAttribute attr) {
|
||||
// TODO: fill this.
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
@@ -70,3 +70,38 @@ func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
|
||||
// CHECK-NEXT: return %[[cst]] : tensor<8x2xf32>
|
||||
return %bst_out : tensor<8x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_canonicalize_masked_load_pattern
|
||||
func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
|
||||
%true_mask = arith.constant dense<true> : tensor<8xi1>
|
||||
%false_mask = arith.constant dense<false> : tensor<8xi1>
|
||||
%other_val = arith.constant dense<0.0> : tensor<8xf32>
|
||||
|
||||
// true_mask with other
|
||||
// CHECK: %[[res1:.*]] = tt.load %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||
%x = tt.load %ptr, %true_mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||
|
||||
// true_mask without other
|
||||
// CHECK: %[[res2:.*]] = tt.load %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||
%y = tt.load %ptr, %true_mask, %other_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||
|
||||
// false_mask with other. It should become "other" (i.e., %y)
|
||||
%z = tt.load %ptr, %false_mask, %y {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||
|
||||
// CHECK: return %[[res1]], %[[res2]], %[[res2]] : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
|
||||
return %x, %y, %z: tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_canonicalize_masked_store_pattern
|
||||
func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>) {
|
||||
%true_mask = arith.constant dense<true> : tensor<8xi1>
|
||||
%false_mask = arith.constant dense<false> : tensor<8xi1>
|
||||
|
||||
// CHECK: tt.store %{{.*}}, %{{.*}} : tensor<8xf32>
|
||||
tt.store %ptr, %val, %true_mask : tensor<8xf32>
|
||||
|
||||
// The following store should disappear.
|
||||
// CHECK-NEXT: return
|
||||
tt.store %ptr, %val, %false_mask : tensor<8xf32>
|
||||
return
|
||||
}
|
||||
|
Reference in New Issue
Block a user