diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index e0656492a..3d3b4d7a3 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -73,14 +73,14 @@ def TT_LoadOp : TT_Op<"load", TypesMatchWith<"infer mask type from result type", "result", "mask", "getI1SameShape($_self)">, - TypesMatchWith<"infer other type from result type", + TypesMatchWith<"infer other type from result type or none", "result", "other", - "$_self">]> { + "$_self", "($_op.getOperands().size() == 2) || std::equal_to<>()">]> { let summary = "load"; - let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, TT_Type:$other, + let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, Optional:$other, TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict, - BoolAttr:$isVolatile, BoolAttr:$isOtherUnspecified); + BoolAttr:$isVolatile); let results = (outs TT_Type:$result); @@ -89,11 +89,10 @@ def TT_LoadOp : TT_Op<"load", OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache, "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache, - "triton::EvictionPolicy":$evict, "bool":$isVolatile, - "bool":$isOtherUnspecified)> + "triton::EvictionPolicy":$evict, "bool":$isVolatile)> ]; - let assemblyFormat = "$ptr`,` $mask`,` $other attr-dict `:` type($result)"; + let assemblyFormat = "operands attr-dict `:` type($result)"; } def TT_StoreOp : TT_Op<"store", diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index 57187cb2f..da0360f9d 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -48,10 +48,14 @@ def TTG_CopyAsyncOp : TTG_Op<"copy_async", "getPointeeType($_self)">]> { let summary = "copy async"; - let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, TT_Type:$other, + let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, Optional:$other, TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict, - BoolAttr:$isVolatile, - BoolAttr:$isOtherUnspecified); + BoolAttr:$isVolatile); + + let builders = [ + OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)> + ]; let results = (outs TT_Type:$result); diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index 128d9521a..2c6442e95 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -229,8 +229,7 @@ struct TritonLoadPattern : public OpConversionPattern { Type retType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp( op, retType, adaptor.ptr(), adaptor.mask(), adaptor.other(), - adaptor.cache(), adaptor.evict(), adaptor.isVolatile(), - adaptor.isOtherUnspecified()); + adaptor.cache(), adaptor.evict(), adaptor.isVolatile()); return success(); } }; diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 0e8fd81ef..cd6612677 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -76,14 +76,9 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()), DenseIntElementsAttr::get( RankedTensorType::get(shape, builder.getI1Type()), true)); - // other Type resultType = RankedTensorType::get(shape, elementType); - ::mlir::Value other = builder.create( - ptr.getLoc(), resultType, - DenseElementsAttr::get(resultType, builder.getZeroAttr(elementType))); state.addOperands(ptr); state.addOperands(mask); - state.addOperands(other); state.addAttribute( cacheAttrName(state.name), ::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache)); @@ -92,8 +87,28 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict)); state.addAttribute(isVolatileAttrName(state.name), builder.getBoolAttr(isVolatile)); - state.addAttribute(isOtherUnspecifiedAttrName(state.name), - builder.getBoolAttr(false)); + state.addTypes({resultType}); +} + +void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, + ::mlir::Value ptr, ::mlir::Value mask, + ::mlir::triton::CacheModifier cache, + ::mlir::triton::EvictionPolicy evict, bool isVolatile) { + TensorType ptrType = ptr.getType().dyn_cast(); + Type elementType = + ptrType.getElementType().dyn_cast().getPointeeType(); + auto shape = ptrType.getShape(); + Type resultType = RankedTensorType::get(shape, elementType); + state.addOperands(ptr); + state.addOperands(mask); + state.addAttribute( + cacheAttrName(state.name), + ::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache)); + state.addAttribute( + evictAttrName(state.name), + ::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict)); + state.addAttribute(isVolatileAttrName(state.name), + builder.getBoolAttr(isVolatile)); state.addTypes({resultType}); } diff --git a/lib/Dialect/Triton/Transforms/Combine.td b/lib/Dialect/Triton/Transforms/Combine.td index fdb93ff81..6decc9539 100644 --- a/lib/Dialect/Triton/Transforms/Combine.td +++ b/lib/Dialect/Triton/Transforms/Combine.td @@ -40,8 +40,8 @@ def CombineGEPPattern : Pat< // select(cond, load(ptrs, broadcast(cond), ???), other) // => load(ptrs, broadcast(cond), other) def CombineSelectMaskedLoadPattern : Pat< - (SelectOp $cond, (TT_LoadOp $ptrs, (TT_BroadcastOp:$bcast_res $cond), $other, $cache, $evict, $isVolatile, $isOtherUnspecified), $falseValue), - (TT_LoadOp $ptrs, $bcast_res, $falseValue, $cache, $evict, $isVolatile, $isOtherUnspecified)>; + (SelectOp $cond, (TT_LoadOp $ptrs, (TT_BroadcastOp:$bcast_res $cond), $other, $cache, $evict, $isVolatile), $falseValue), + (TT_LoadOp $ptrs, $bcast_res, $falseValue, $cache, $evict, $isVolatile)>; // broadcast(cst) => cst def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">; diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index d44a25c71..6fb17950e 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -240,7 +240,7 @@ void LoopPipeliner::emitPrologue() { newOp = builder.create( op->getLoc(), loadsMapping[loadOp].getType(), loadOp.ptr(), loadOp.mask(), loadOp.other(), loadOp.cache(), loadOp.evict(), - loadOp.isVolatile(), loadOp.isOtherUnspecified()); + loadOp.isVolatile()); } else llvm_unreachable("This should be LoadOp"); } else @@ -404,7 +404,7 @@ scf::ForOp LoopPipeliner::createNewForOp() { nextMapping.lookupOrDefault(loadOp.ptr()), nextMapping.lookupOrDefault(loadOp.mask()), nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(), - loadOp.evict(), loadOp.isVolatile(), loadOp.isOtherUnspecified()); + loadOp.evict(), loadOp.isVolatile()); } else nextOp = builder.clone(*op, nextMapping); // llvm::errs() << "epilogue cloning...: " << *op << "\n"; diff --git a/python/src/triton.cc b/python/src/triton.cc index 37b979515..a3d472c23 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1321,16 +1321,10 @@ void init_triton_ir(py::module &&m) { if (other.has_value()) { return self.create( loc, resultType, ptrs, mask, other.value(), cacheModifier, - evictionPolicy, isVolatile, false); + evictionPolicy, isVolatile); } else { - mlir::Value dummy_other = self.create( - loc, resultType, - mlir::DenseElementsAttr::get(resultType, - self.getZeroAttr(elementType))); return self.create( - loc, mlir::RankedTensorType::get(shape, elementType), ptrs, - mask, dummy_other, cacheModifier, evictionPolicy, isVolatile, - true); + loc, ptrs, mask, cacheModifier, evictionPolicy, isVolatile); } }) .def("create_masked_store", diff --git a/test/Analysis/test-alignment.mlir b/test/Analysis/test-alignment.mlir index 9114db766..f9f922846 100644 --- a/test/Analysis/test-alignment.mlir +++ b/test/Analysis/test-alignment.mlir @@ -46,7 +46,7 @@ func @permute_2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {t // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1] %19 = tt.getelementptr %17, %18 : tensor<128x128x!tt.ptr> // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1] - %20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x128xf32> + %20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32> tt.store %19, %20, %cst, : tensor<128x128xf32> return } \ No newline at end of file diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index 9dc541363..ba38ee9d7 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -21,10 +21,10 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { - %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL> + %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> // CHECK: offset = 0, size = 8192 %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> - %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #BL> + %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> // CHECK-NEXT: offset = 8192, size = 8192 %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> @@ -49,17 +49,17 @@ func @reusable(%A : !tt.ptr) { %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> %b_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<32x128x!tt.ptr, #AL> - %a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL> - // CHECK: offset = 0, size = 8192 + %a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> + // CHECK: offset = 0, size = 8192 %a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> - %a2_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #AL> + %a2_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL> // CHECK-NEXT: offset = 8192, size = 8192 %a2 = triton_gpu.convert_layout %a2_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A> - %a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL> + %a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> // CHECK-NEXT: offset = 16384, size = 8192 %a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> %c = tt.dot %a1, %a2, %c_init {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> - %a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #AL> + %a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL> // CHECK-NEXT: offset = 0, size = 8192 %a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A> %c1 = tt.dot %a3, %a4, %c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> @@ -210,5 +210,5 @@ func @multi_blocks_noreuse(%i1 : i1) { // CHECK-NEXT: offset = 1024, size = 1024 %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> return - // CHECK-NEXT: size = 3072 + // CHECK-NEXT: size = 3072 } diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index c2a45dbe2..63bb5e89d 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -21,7 +21,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { func @basic_load(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { // CHECK: llvm.inline_asm // CHECK: llvm.inline_asm - %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<256xf32, #blocked0> + %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> return } } @@ -36,7 +36,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-SAME: ld.global.v4.b32 // CHECK: llvm.inline_asm // CHECK-SAME: ld.global.v4.b32 - %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<256xf32, #blocked0> + %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> return } } @@ -51,7 +51,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-SAME: ld.global.v2.b32 // CHECK: llvm.inline_asm // CHECK-SAME: ld.global.v2.b32 - %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<256xf16, #blocked0> + %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf16, #blocked0> return } } @@ -64,7 +64,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: masked_load_const_other func @masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>) { %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0> - %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<256xf32, #blocked0> + %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> return } } @@ -189,4 +189,4 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { tt.store %ptrs, %vals, %mask, {} : tensor<256xf32, #blocked0> return } -} \ No newline at end of file +} diff --git a/test/Triton/combine.mlir b/test/Triton/combine.mlir index 05ae64153..8bf36af05 100644 --- a/test/Triton/combine.mlir +++ b/test/Triton/combine.mlir @@ -49,12 +49,12 @@ func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr>, %con %mask = tt.broadcast %cond : (i1) -> tensor<8xi1> %false_val = arith.constant dense<0.0> : tensor<8xf32> - // CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<8xf32> - %x = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<8xf32> + // CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32> + %x = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32> %0 = select %cond, %x, %false_val : tensor<8xf32> - // CHECK: %[[res2:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = true, isVolatile = false} : tensor<8xf32> - %y = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = true, isVolatile = false} : tensor<8xf32> + // CHECK: %[[res2:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32> + %y = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32> %1 = select %cond, %y, %false_val : tensor<8xf32> // CHECK: return %[[res1]], %[[res2]] : tensor<8xf32>, tensor<8xf32> diff --git a/test/Triton/vecadd.mlir b/test/Triton/vecadd.mlir index 28ea31238..d99c3ab1f 100644 --- a/test/Triton/vecadd.mlir +++ b/test/Triton/vecadd.mlir @@ -24,10 +24,10 @@ module { %15:3 = scf.for %arg6 = %12 to %13 step %14 iter_args(%arg7 = %11, %arg8 = %8, %arg9 = %10) -> (tensor<256xf32>, tensor<256x!tt.ptr>, tensor<256x!tt.ptr>) { %cst_0 = arith.constant 0.000000e+00 : f32 %18 = tt.broadcast %cst_0 : (f32) -> tensor<256xf32> - %19 = tt.load %arg8, %6, %18 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<256xf32> + %19 = tt.load %arg8, %6, %18 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32> %cst_1 = arith.constant 0.000000e+00 : f32 %20 = tt.broadcast %cst_1 : (f32) -> tensor<256xf32> - %21 = tt.load %arg9, %6, %20 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<256xf32> + %21 = tt.load %arg9, %6, %20 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32> %22 = arith.addf %19, %21 : tensor<256xf32> %23 = arith.addf %arg7, %22 : tensor<256xf32> %24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32> diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir index 8888967ed..d55e5ed76 100644 --- a/test/TritonGPU/coalesce.mlir +++ b/test/TritonGPU/coalesce.mlir @@ -40,7 +40,7 @@ func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %16 = tt.broadcast %14 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2> %17 = triton_gpu.convert_layout %16 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1> %18 = tt.getelementptr %15, %17 : tensor<64x64x!tt.ptr, #blocked1> - %19 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, #blocked1> + %19 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked1> tt.store %18, %19, %cst, : tensor<64x64xf32, #blocked1> return } \ No newline at end of file diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index cc73c154f..e71d20357 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -33,9 +33,9 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { - %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL> + %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> - %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #BL> + %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> %c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> @@ -75,9 +75,9 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr : tensor<32x128xi32, #BL> scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { - %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL> + %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> - %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #BL> + %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> %c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> @@ -106,7 +106,7 @@ func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A : %a_mask = arith.constant dense : tensor<128x32xi1, #AL> %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> - %a_ = tt.load %a_ptr_init, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL> + %a_ = tt.load %a_ptr_init, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> %b_mask = arith.constant dense : tensor<32x128xi1, #BL> @@ -116,7 +116,7 @@ func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A : %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> scf.for %iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { - %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #BL> + %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> %c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_b_ptr = tt.getelementptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>