[BACKEND] llvm::dyn_cast -> llvm::dyn_cast_or_null (#689)

This commit is contained in:
Shintaro Iwasaki
2022-09-21 20:26:40 -07:00
committed by GitHub
parent 15bfd0cb79
commit 940ef3f0ac
4 changed files with 55 additions and 8 deletions

View File

@@ -56,7 +56,7 @@ jobs:
- name: Check cpp style
if: ${{ matrix.runner != 'macos-latest' }}
run: |
sudo apt-get install -y clang-format
pip install clang-format
find . -regex '.*\.\(cpp\|hpp\|h\|cc\)' -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file --dry-run -Werror -i ||
(echo '::error title=Style issues:: Please run `find . -regex ".*\.\(cpp\|hpp\|h\|cc\)" -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file -i`' ; exit 1)

View File

@@ -71,7 +71,7 @@ public:
mlir::Value falseValue = selectOp.getFalseValue();
auto *loadOpCandidate = trueValue.getDefiningOp();
auto loadOp = llvm::dyn_cast<triton::LoadOp>(loadOpCandidate);
auto loadOp = llvm::dyn_cast_or_null<triton::LoadOp>(loadOpCandidate);
if (!loadOp)
return mlir::failure();
@@ -81,7 +81,7 @@ public:
auto *broadcastOpCandidate = mask.getDefiningOp();
auto broadcastOp =
llvm::dyn_cast<triton::BroadcastOp>(broadcastOpCandidate);
llvm::dyn_cast_or_null<triton::BroadcastOp>(broadcastOpCandidate);
if (!broadcastOp)
return mlir::failure();
@@ -106,7 +106,8 @@ struct CanonicalizeMaskedLoadPattern
if (!mask)
return mlir::failure();
auto constantMask = llvm::dyn_cast<arith::ConstantOp>(mask.getDefiningOp());
auto constantMask =
llvm::dyn_cast_or_null<arith::ConstantOp>(mask.getDefiningOp());
if (!constantMask)
return mlir::failure();
@@ -152,7 +153,8 @@ struct CanonicalizeMaskedStorePattern
if (!mask)
return mlir::failure();
auto constantMask = llvm::dyn_cast<arith::ConstantOp>(mask.getDefiningOp());
auto constantMask =
llvm::dyn_cast_or_null<arith::ConstantOp>(mask.getDefiningOp());
if (!constantMask)
return mlir::failure();

View File

@@ -301,9 +301,17 @@ void LoopPipeliner::emitPrologue() {
}
// If this is a load/async_copy, we need to update the mask
if (llvm::isa<triton::LoadOp, triton::gpu::InsertSliceAsyncOp>(newOp)) {
Value mask = llvm::isa<triton::LoadOp>(newOp) ? newOp->getOperand(1)
: newOp->getOperand(3);
if (Value mask = [&]() {
if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(newOp)) {
return loadOp.mask();
} else if (auto insertSliceAsyncOp =
llvm::dyn_cast<triton::gpu::InsertSliceAsyncOp>(
newOp)) {
return insertSliceAsyncOp.mask();
} else {
return mlir::Value();
}
}()) {
// assert(I1 or TensorOf<[I1]>);
OpBuilder::InsertionGuard g(builder);
// TODO: move this out of the loop

View File

@@ -61,6 +61,22 @@ func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %con
return %0, %1 : tensor<8xf32>, tensor<8xf32>
}
// CHECK-LABEL: @test_combine_select_masked_load_fail_pattern
func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {
%false_val = arith.constant dense<0.0> : tensor<8xf32>
// Case 1: value at the "load" position is not an "op". Select should not be canonicalized.
// CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
%0 = select %cond, %dummy_load, %false_val : tensor<8xf32>
// Case 2: value at the "broadcast" position is not an "op". Select should not be canonicalized.
%real_load = tt.load %ptr, %dummy_broadcast, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
// CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
%1 = select %cond, %real_load, %false_val : tensor<8xf32>
return %0, %1 : tensor<8xf32>, tensor<8xf32>
}
// CHECK-LABEL: @test_combine_broadcast_constant_pattern
func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
// CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32>
@@ -92,6 +108,19 @@ func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (te
return %x, %y, %z: tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
}
// CHECK-LABEL: @test_canonicalize_masked_load_fail_pattern
func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %mask: tensor<8xi1>) -> (tensor<8xf32>, tensor<8xf32>) {
%other_val = arith.constant dense<0.0> : tensor<8xf32>
// Case: value at the "mask" position is not an "op". Load should not be canonicalized.
// CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
%x = tt.load %ptr, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
// CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
%y = tt.load %ptr, %mask, %other_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
return %x, %y: tensor<8xf32>, tensor<8xf32>
}
// CHECK-LABEL: @test_canonicalize_masked_store_pattern
func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>) {
%true_mask = arith.constant dense<true> : tensor<8xi1>
@@ -105,3 +134,11 @@ func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val:
tt.store %ptr, %val, %false_mask : tensor<8xf32>
return
}
// CHECK-LABEL: @test_canonicalize_masked_store_fail_pattern
func @test_canonicalize_masked_store_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>, %mask: tensor<8xi1>) {
// Case: value at the "mask" position is not an "op". Store should not be canonicalized.
// CHECK: tt.store %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
tt.store %ptr, %val, %mask : tensor<8xf32>
return
}