diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
index a4db60380..32dc88e25 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
@@ -269,7 +269,9 @@ scf::ForOp LoopPipeliner::createNewForOp() {
         Value newMask =
             builder.create(mask.getLoc(), splatCond,
                            nextMapping.lookupOrDefault(mask));
-        nextMapping.map(mask, newMask);
+        // if mask is defined outside the loop, don't update the map more than once
+        if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
+          nextMapping.map(mask, newMask);
       }
       Operation *nextOp = builder.clone(*op, nextMapping);
       // update mapping of results
diff --git a/rewrite-test/jit/matmul/matmul.mlir b/rewrite-test/jit/matmul/matmul.mlir
index fd2268ead..a70dfd334 100644
--- a/rewrite-test/jit/matmul/matmul.mlir
+++ b/rewrite-test/jit/matmul/matmul.mlir
@@ -204,16 +204,14 @@ module {
       %91 = arith.addi %arg17, %c128 : index
       %92 = arith.cmpi slt, %91, %47 : index
       %93 = tt.broadcast %92 : (i1) -> tensor<128x128xi1, #triton_gpu<"coalesced encoding">>
-      %94 = tt.broadcast %92 : (i1) -> tensor<128x128xi1, #triton_gpu<"coalesced encoding">>
-      %95 = arith.andi %94, %93 : tensor<128x128xi1, #triton_gpu<"coalesced encoding">>
-      %96 = "triton_gpu.copy_async"(%arg15, %93, %cst_1) {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : (tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128xi1, #triton_gpu<"coalesced encoding">>, tensor<128x128xf16, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>
-      %97 = "triton_gpu.copy_async"(%arg16, %95, %cst_1) {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : (tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128xi1, #triton_gpu<"coalesced encoding">>, tensor<128x128xf16, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>
-      %98 = tt.broadcast %c128_i32 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">>
-      %99 = tt.getelementptr %arg15, %98, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>
-      %100 = arith.muli %arg7, %c128_i32 : i32
-      %101 = tt.broadcast %100 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">>
-      %102 = tt.getelementptr %arg16, %101, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>
-      scf.yield %85, %87, %90, %96, %97, %99, %102, %91 : tensor<128x128xf32, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>, tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, index
+      %94 = "triton_gpu.copy_async"(%arg15, %93, %cst_1) {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : (tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128xi1, #triton_gpu<"coalesced encoding">>, tensor<128x128xf16, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>
+      %95 = "triton_gpu.copy_async"(%arg16, %93, %cst_1) {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : (tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128xi1, #triton_gpu<"coalesced encoding">>, tensor<128x128xf16, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>
+      %96 = tt.broadcast %c128_i32 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">>
+      %97 = tt.getelementptr %arg15, %96, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>
+      %98 = arith.muli %arg7, %c128_i32 : i32
+      %99 = tt.broadcast %98 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">>
+      %100 = tt.getelementptr %arg16, %99, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>
+      scf.yield %85, %87, %90, %94, %95, %97, %100, %91 : tensor<128x128xf32, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>, tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, index
     }
     %56 = arith.truncf %55#0 : tensor<128x128xf32, #triton_gpu<"coalesced encoding">> to tensor<128x128xf16, #triton_gpu<"coalesced encoding">>
     %57 = arith.muli %12, %c128_i32 : i32