From 84aa7d025a52dfd9ee94d8cf5b3b5cfde1bb1e65 Mon Sep 17 00:00:00 2001 From: Shintaro Iwasaki Date: Wed, 24 Aug 2022 12:55:49 -0700 Subject: [PATCH] [TritonIR] simplify Load/StoreOps when mask is true/false (#79) * [TritonIR] fix Load/Store/CopyAsyncOp's parsers * [TritonIR] simplify Load/StoreOps when mask is true/false * [TEST] adds tests to check load/store simplification --- include/triton/Dialect/Triton/IR/TritonOps.td | 14 ++- .../Dialect/TritonGPU/IR/TritonGPUOps.td | 5 +- lib/Dialect/Triton/IR/Ops.cpp | 62 +++++++++++ lib/Dialect/Triton/Transforms/Combine.cpp | 85 ++++++++++++++ lib/Dialect/TritonGPU/IR/Dialect.cpp | 104 ++++++++++++------ test/Triton/combine.mlir | 35 ++++++ 6 files changed, 269 insertions(+), 36 deletions(-) diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 8ac008940..32ebe4123 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -94,7 +94,12 @@ def TT_LoadOp : TT_Op<"load", "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, ]; - let assemblyFormat = "operands attr-dict `:` type($result)"; + // let assemblyFormat = "operands attr-dict `:` type($result)"; + let parser = [{ return mlir::triton::parseLoadOp(parser, result); }]; + + let printer = [{ return mlir::triton::printLoadOp(p, *this); }]; + + let hasCanonicalizer = 1; } def TT_StoreOp : TT_Op<"store", @@ -114,7 +119,12 @@ def TT_StoreOp : TT_Op<"store", OpBuilder<(ins "Value":$ptr, "Value":$value)>, ]; - let assemblyFormat = "operands attr-dict `:` type($value)"; + // let assemblyFormat = "operands attr-dict `:` type($value)"; + let parser = [{ return mlir::triton::parseStoreOp(parser, result); }]; + + let printer = [{ return mlir::triton::printStoreOp(p, *this); }]; + + let hasCanonicalizer = 1; } def TT_GEPOp : TT_Op<"getelementptr", diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index d61e6eb36..3bdccf8db 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -60,7 +60,10 @@ def TTG_CopyAsyncOp : TTG_Op<"copy_async", let results = (outs TT_Type:$result); - let assemblyFormat = "operands attr-dict `:` type($ptr) `->` type($result)"; + // let assemblyFormat = "operands attr-dict `:` type($ptr) `->` type($result)"; + let parser = [{ return parseCopyAsyncOp(parser, result); }]; + + let printer = [{ return printCopyAsyncOp(p, *this); }]; // result needs to be of shared layout let verifier = [{ return ::verify(*this); }]; diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 007080830..501f19cf8 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -36,6 +36,68 @@ static Type getPointerTypeFromTensor(Type type) { return Type(); } +// Parser & printer for assembly forms +ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) { + SmallVector allOperands; + Type resultTypes[1]; + SMLoc allOperandLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(allOperands) || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || + parser.parseCustomTypeWithFallback(resultTypes[0])) + return failure(); + + result.addTypes(resultTypes); + + SmallVector operandTypes; + operandTypes.push_back(getPointerTypeFromTensor(resultTypes[0])); // ptr + if (allOperands.size() >= 2) + operandTypes.push_back(getI1SameShape(resultTypes[0])); // mask + if (allOperands.size() >= 3) + operandTypes.push_back(resultTypes[0]); // other + + if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc, + result.operands)) + return failure(); + return success(); +} + +void printLoadOp(OpAsmPrinter &printer, LoadOp loadOp) { + printer << " "; + printer << loadOp.getOperation()->getOperands(); + printer.printOptionalAttrDict(loadOp->getAttrs(), /*elidedAttrs=*/{}); + printer << " : "; + printer.printStrippedAttrOrType(loadOp.result().getType()); +} + +ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) { + SmallVector allOperands; + Type valueType; + SMLoc allOperandLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(allOperands) || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || + parser.parseCustomTypeWithFallback(valueType)) + return failure(); + + SmallVector operandTypes; + operandTypes.push_back(getPointerTypeFromTensor(valueType)); // ptr + operandTypes.push_back(valueType); // value + if (allOperands.size() >= 3) + operandTypes.push_back(getI1SameShape(valueType)); // mask + + if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc, + result.operands)) + return failure(); + return success(); +} + +void printStoreOp(OpAsmPrinter &printer, StoreOp storeOp) { + printer << " "; + printer << storeOp.getOperation()->getOperands(); + printer.printOptionalAttrDict(storeOp->getAttrs(), /*elidedAttrs=*/{}); + printer << " : "; + printer.printStrippedAttrOrType(storeOp.value().getType()); +} + } // namespace triton } // namespace mlir diff --git a/lib/Dialect/Triton/Transforms/Combine.cpp b/lib/Dialect/Triton/Transforms/Combine.cpp index 814e6d41f..58233e015 100644 --- a/lib/Dialect/Triton/Transforms/Combine.cpp +++ b/lib/Dialect/Triton/Transforms/Combine.cpp @@ -92,6 +92,91 @@ public: } }; +// load(ptr, splat(1), ...) -> load(ptr, ...) +// load(ptr, splat(0), other, ...) -> other +struct CanonicalizeMaskedLoadPattern + : public mlir::OpRewritePattern { + CanonicalizeMaskedLoadPattern(mlir::MLIRContext *context) + : OpRewritePattern(context, 1) {} + + mlir::LogicalResult + matchAndRewrite(triton::LoadOp loadOp, + mlir::PatternRewriter &rewriter) const override { + auto mask = loadOp.mask(); + if (!mask) + return mlir::failure(); + + auto constantMask = llvm::dyn_cast(mask.getDefiningOp()); + if (!constantMask) + return mlir::failure(); + + auto splatMask = constantMask.getValue().dyn_cast(); + if (!splatMask) + return mlir::failure(); + + if (splatMask.getSplatValue().getValue() == true) { + // mask = splat(1) + rewriter.replaceOpWithNewOp( + loadOp, loadOp.getType(), loadOp.ptr(), Value(), Value(), + loadOp.cache(), loadOp.evict(), loadOp.isVolatile()); + } else { + // mask = splat(0) + + // If there's no "other", the value is "undef". Perhaps we want to + // optimize it in the future.x + auto otherVal = loadOp.other(); + if (!otherVal) + return mlir::failure(); + rewriter.replaceOp(loadOp, otherVal); + } + return mlir::success(); + } +}; + +void triton::LoadOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +// store(ptr, value, splat(1), ...) -> store(ptr, value, ...) +// store(ptr, value, splat(0), ...) -> [none] +struct CanonicalizeMaskedStorePattern + : public mlir::OpRewritePattern { + CanonicalizeMaskedStorePattern(mlir::MLIRContext *context) + : OpRewritePattern(context, 1) {} + + mlir::LogicalResult + matchAndRewrite(triton::StoreOp storeOp, + mlir::PatternRewriter &rewriter) const override { + auto mask = storeOp.mask(); + if (!mask) + return mlir::failure(); + + auto constantMask = llvm::dyn_cast(mask.getDefiningOp()); + if (!constantMask) + return mlir::failure(); + + auto splatMask = constantMask.getValue().dyn_cast(); + if (!splatMask) + return mlir::failure(); + + if (splatMask.getSplatValue().getValue() == true) { + // mask = splat(1) + rewriter.replaceOpWithNewOp(storeOp, storeOp.ptr(), + storeOp.value()); + } else { + // mask = splat(0) + rewriter.eraseOp(storeOp); + } + return mlir::success(); + } +}; + +void triton::StoreOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + #define GEN_PASS_CLASSES #include "triton/Dialect/Triton/Transforms/Passes.h.inc" diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index d6275dc89..91fec7d34 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -10,6 +10,37 @@ using namespace mlir; using namespace mlir::triton::gpu; +// Utility +namespace mlir { +namespace triton { + +// Type inference +static Type getI1SameShape(Type type) { + auto i1Type = IntegerType::get(type.getContext(), 1); + if (auto tensorType = type.dyn_cast()) + return RankedTensorType::get(tensorType.getShape(), i1Type, + tensorType.getEncoding()); + return Type(); +} + +static Type getPointeeType(Type type) { + if (auto tensorType = type.dyn_cast()) { + // Tensor of pointers + auto shape = tensorType.getShape(); + auto ptrType = tensorType.getElementType().dyn_cast(); + Type pointeeType = ptrType.getPointeeType(); + return RankedTensorType::get(shape, pointeeType, tensorType.getEncoding()); + } else if (auto ptrType = type.dyn_cast()) { + // scalar pointer + Type pointeeType = ptrType.getPointeeType(); + return pointeeType; + } + return Type(); +} + +} // namespace triton +} // namespace mlir + static LogicalResult parseIntAttrValue(AsmParser &parser, const Attribute &attr, unsigned &value, StringRef desc) { auto intAttr = attr.dyn_cast(); @@ -260,6 +291,44 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const { << "}>"; } +//===----------------------------------------------------------------------===// +// CopyAsyncOp +//===----------------------------------------------------------------------===// + +ParseResult parseCopyAsyncOp(OpAsmParser &parser, OperationState &result) { + SmallVector allOperands; + Type resultTypes[1], ptrType; + SMLoc allOperandLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(allOperands) || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || + parser.parseCustomTypeWithFallback(ptrType) || parser.parseArrow() || + parser.parseCustomTypeWithFallback(resultTypes[0])) + return failure(); + result.addTypes(resultTypes); + + SmallVector operandTypes; + operandTypes.push_back(ptrType); // ptr + if (allOperands.size() >= 2) + operandTypes.push_back(triton::getI1SameShape(ptrType)); // mask + if (allOperands.size() >= 3) + operandTypes.push_back(triton::getPointeeType(ptrType)); // other + + if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc, + result.operands)) + return failure(); + return success(); +} + +void printCopyAsyncOp(OpAsmPrinter &printer, CopyAsyncOp copyAsyncOp) { + printer << " "; + printer << copyAsyncOp.getOperation()->getOperands(); + printer.printOptionalAttrDict(copyAsyncOp->getAttrs(), /*elidedAttrs=*/{}); + printer << " : "; + printer.printStrippedAttrOrType(copyAsyncOp.ptr().getType()); + printer << " -> "; + printer.printStrippedAttrOrType(copyAsyncOp.result().getType()); +} + //===----------------------------------------------------------------------===// // ASM Interface (i.e.: alias) //===----------------------------------------------------------------------===// @@ -282,8 +351,7 @@ public: os << "slice"; return AliasResult::FinalAlias; } */ - OpAsmDialectInterface::getAlias(attr, os); - return AliasResult::FinalAlias; + return OpAsmDialectInterface::getAlias(attr, os); } }; @@ -299,36 +367,6 @@ void TritonGPUDialect::initialize() { addInterfaces(); } -namespace mlir { -namespace triton { - -// Type inference -static Type getI1SameShape(Type type) { - auto i1Type = IntegerType::get(type.getContext(), 1); - if (auto tensorType = type.dyn_cast()) - return RankedTensorType::get(tensorType.getShape(), i1Type, - tensorType.getEncoding()); - return Type(); -} - -static Type getPointeeType(Type type) { - if (auto tensorType = type.dyn_cast()) { - // Tensor of pointers - auto shape = tensorType.getShape(); - auto ptrType = tensorType.getElementType().dyn_cast(); - Type pointeeType = ptrType.getPointeeType(); - return RankedTensorType::get(shape, pointeeType, tensorType.getEncoding()); - } else if (auto ptrType = type.dyn_cast()) { - // scalar pointer - Type pointeeType = ptrType.getPointeeType(); - return pointeeType; - } - return Type(); -} - -} // namespace triton -} // namespace mlir - //===----------------------------------------------------------------------===// // Verification //===----------------------------------------------------------------------===// @@ -352,4 +390,4 @@ LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op, NamedAttribute attr) { // TODO: fill this. return success(); -} \ No newline at end of file +} diff --git a/test/Triton/combine.mlir b/test/Triton/combine.mlir index 8bf36af05..3b84966ba 100644 --- a/test/Triton/combine.mlir +++ b/test/Triton/combine.mlir @@ -70,3 +70,38 @@ func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> { // CHECK-NEXT: return %[[cst]] : tensor<8x2xf32> return %bst_out : tensor<8x2xf32> } + +// CHECK-LABEL: @test_canonicalize_masked_load_pattern +func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) { + %true_mask = arith.constant dense : tensor<8xi1> + %false_mask = arith.constant dense : tensor<8xi1> + %other_val = arith.constant dense<0.0> : tensor<8xf32> + + // true_mask with other + // CHECK: %[[res1:.*]] = tt.load %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32> + %x = tt.load %ptr, %true_mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32> + + // true_mask without other + // CHECK: %[[res2:.*]] = tt.load %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32> + %y = tt.load %ptr, %true_mask, %other_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32> + + // false_mask with other. It should become "other" (i.e., %y) + %z = tt.load %ptr, %false_mask, %y {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32> + + // CHECK: return %[[res1]], %[[res2]], %[[res2]] : tensor<8xf32>, tensor<8xf32>, tensor<8xf32> + return %x, %y, %z: tensor<8xf32>, tensor<8xf32>, tensor<8xf32> +} + +// CHECK-LABEL: @test_canonicalize_masked_store_pattern +func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr>, %val: tensor<8xf32>) { + %true_mask = arith.constant dense : tensor<8xi1> + %false_mask = arith.constant dense : tensor<8xi1> + + // CHECK: tt.store %{{.*}}, %{{.*}} : tensor<8xf32> + tt.store %ptr, %val, %true_mask : tensor<8xf32> + + // The following store should disappear. + // CHECK-NEXT: return + tt.store %ptr, %val, %false_mask : tensor<8xf32> + return +}