From 0ebef11c775d829fcf054fc0a6cbfc4f96cc5a53 Mon Sep 17 00:00:00 2001
From: Shintaro Iwasaki
Date: Mon, 22 Aug 2022 22:00:17 -0700
Subject: [PATCH] [TritonIR] Make mask operand optional (#74)

---
 include/triton/Dialect/Triton/IR/TritonOps.td | 30 ++++++------
 .../Dialect/TritonGPU/IR/TritonGPUOps.td      | 17 ++++---
 .../TritonToTritonGPU/TritonToTritonGPU.cpp   |  5 +-
 lib/Dialect/Triton/IR/Ops.cpp                 | 49 +++++++------------
 lib/Dialect/Triton/Transforms/Combine.cpp     | 40 +++++++++++++++
 lib/Dialect/Triton/Transforms/Combine.td      |  6 ---
 lib/Dialect/TritonGPU/Transforms/Pipeline.cpp | 18 ++++---
 python/src/triton.cc                          | 18 ++-----
 test/Analysis/test-alignment.mlir             |  4 +-
 test/Conversion/triton_to_llvm.mlir           |  2 +-
 test/Conversion/tritongpu_to_llvm.mlir        |  2 +-
 test/Triton/vecadd.mlir                       |  4 +-
 test/TritonGPU/coalesce.mlir                  | 12 ++---
 test/TritonGPU/combine.mlir                   |  8 +--
 14 files changed, 113 insertions(+), 102 deletions(-)

diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td
index 3d3b4d7a3..8ac008940 100644
--- a/include/triton/Dialect/Triton/IR/TritonOps.td
+++ b/include/triton/Dialect/Triton/IR/TritonOps.td
@@ -22,7 +22,7 @@ class TT_Op<string mnemonic, list<Trait> traits = []> :
 //
 // Use cast ops in arith:
 //   bitcast
-//   fptoui, fptosi, uitofp, sitofp, 
+//   fptoui, fptosi, uitofp, sitofp,
 //   extf, tructf,
 //   extui, extsi, tructi
 def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, NoSideEffect,
@@ -66,30 +66,32 @@ def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
 //
 def TT_LoadOp : TT_Op<"load",
                       [SameOperandsAndResultShape,
+                       SameVariadicOperandSize,
                        MemoryEffects<[MemRead]>,
                        TypesMatchWith<"infer ptr type from result type",
                                       "result", "ptr",
                                       "getPointerTypeFromTensor($_self)">,
-                       TypesMatchWith<"infer mask type from result type",
-                                      "result", "mask",
-                                      "getI1SameShape($_self)">,
+                       TypesMatchWith<"infer mask type from result type or none",
+                                      "result", "mask", "getI1SameShape($_self)",
+                                      "($_op.getOperands().size() <= 1) || std::equal_to<>()">,
                        TypesMatchWith<"infer other type from result type or none",
-                                      "result", "other",
-                                      "$_self", "($_op.getOperands().size() == 2) || std::equal_to<>()">]> {
+                                      "result", "other", "$_self",
+                                      "($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
   let summary = "load";
 
-  let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, Optional<TT_Type>:$other,
+  let arguments = (ins TT_PtrTensor:$ptr, Optional<BoolLike>:$mask, Optional<TT_Type>:$other,
                        TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
                        BoolAttr:$isVolatile);
 
   let results = (outs TT_Type:$result);
 
   let builders = [
-    // for args with default values
     OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
                    "triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
     OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
-                   "triton::EvictionPolicy":$evict, "bool":$isVolatile)>
+                   "triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
+    OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache,
+                   "triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
   ];
 
   let assemblyFormat = "operands attr-dict `:` type($result)";
@@ -102,17 +104,17 @@ def TT_StoreOp : TT_Op<"store",
                                        "value", "ptr",
                                        "getPointerTypeFromTensor($_self)">,
                         TypesMatchWith<"infer mask type from value type",
-                                       "value", "mask",
-                                       "getI1SameShape($_self)">]> {
+                                       "value", "mask", "getI1SameShape($_self)",
+                                       "($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
   let summary = "store";
 
-  let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$value, BoolLike:$mask);
+  let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$value, Optional<BoolLike>:$mask);
 
   let builders = [
       OpBuilder<(ins "Value":$ptr, "Value":$value)>,
   ];
 
-  let assemblyFormat = "$ptr `,` $value `,` $mask `,` attr-dict `:` type($value)";
+  let assemblyFormat = "operands attr-dict `:` type($value)";
 }
 
 def TT_GEPOp : TT_Op<"getelementptr",
@@ -257,7 +259,7 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas"> {
   let summary = "atomic cas";
 
   let description = [{
-    compare $cmp with data $old at location $ptr, 
+    compare $cmp with data $old at location $ptr,
 
     if $old == $cmp, store $val to $ptr,
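Note: with $mask and $other both `Optional` and the assembly format switched to the
generic `operands` directive, `tt.load` now parses and prints with one, two, or three
operands, and each guarded `TypesMatchWith` constraint is enforced only when the
corresponding operand is actually present. Illustrative forms (a sketch; the value
names are placeholders, not taken from this patch's tests):

    %a = tt.load %ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32>
    %b = tt.load %ptr, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32>
    %c = tt.load %ptr, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32>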
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
index da0360f9d..d61e6eb36 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -40,26 +40,27 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
 
 def TTG_CopyAsyncOp : TTG_Op<"copy_async",
                              [MemoryEffects<[MemRead, MemWrite]>,
+                              SameVariadicOperandSize,
                               TypesMatchWith<"infer mask type from ptr type",
-                                             "ptr", "mask",
-                                             "getI1SameShape($_self)">,
+                                             "ptr", "mask", "getI1SameShape($_self)",
+                                             "($_op.getOperands().size() <= 1) || std::equal_to<>()">,
                               TypesMatchWith<"infer other type from ptr type",
-                                             "ptr", "other",
-                                             "getPointeeType($_self)">]> {
+                                             "ptr", "other", "getPointeeType($_self)",
+                                             "($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
   let summary = "copy async";
 
-  let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, Optional<TT_Type>:$other,
+  let arguments = (ins TT_PtrTensor:$ptr, Optional<BoolLike>:$mask, Optional<TT_Type>:$other,
                        TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
                        BoolAttr:$isVolatile);
 
   let builders = [
-    OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
-                   "triton::EvictionPolicy":$evict, "bool":$isVolatile)>
+    OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
+                   "triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
   ];
 
   let results = (outs TT_Type:$result);
 
-  let assemblyFormat = "$ptr`,` $mask`,` $other attr-dict `:` type($ptr) `->` type($result)";
+  let assemblyFormat = "operands attr-dict `:` type($ptr) `->` type($result)";
 
   // result needs to be of shared layout
   let verifier = [{ return ::verify(*this); }];
diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
index 2c6442e95..7a4e3add6 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
@@ -226,10 +226,9 @@ struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
   LogicalResult
   matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type retType = getTypeConverter()->convertType(op.getType());
     rewriter.replaceOpWithNewOp<triton::LoadOp>(
-        op, retType, adaptor.ptr(), adaptor.mask(), adaptor.other(),
-        adaptor.cache(), adaptor.evict(), adaptor.isVolatile());
+        op, adaptor.ptr(), adaptor.mask(), adaptor.other(), adaptor.cache(),
+        adaptor.evict(), adaptor.isVolatile());
     return success();
   }
 };
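Note: the conversion pattern above no longer computes `retType` by hand; the new
three-value LoadOp builder (next file) recomputes the result type from the pointer
operand, and `adaptor.mask()` / `adaptor.other()` simply forward null `Value`s when
those operands are absent. Downstream code must therefore guard every use of an
optional operand; a minimal sketch of the idiom (hypothetical variable names):

    // The generated accessor for an absent Optional operand returns a
    // default-constructed (null) mlir::Value, which converts to false.
    if (mlir::Value mask = loadOp.mask()) {
      // masked path: mask is a real i1 tensor here
    } else {
      // unmasked path
    }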
diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp
index cd6612677..007080830 100644
--- a/lib/Dialect/Triton/IR/Ops.cpp
+++ b/lib/Dialect/Triton/IR/Ops.cpp
@@ -49,58 +49,43 @@ namespace mlir {
 namespace triton {
 
 //-- StoreOp --
-// Default mask
 void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                     ::mlir::Value ptr, ::mlir::Value value) {
-  TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
-  auto shape = ptrType.getShape();
-  ::mlir::Value mask = builder.create<arith::ConstantOp>(
-      ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
-      DenseIntElementsAttr::get(
-          RankedTensorType::get(shape, builder.getI1Type()), true));
-  state.addOperands(ptr);
-  state.addOperands(value);
-  state.addOperands(mask);
+  StoreOp::build(builder, state, ptr, value, mlir::Value());
 }
 
 //-- LoadOp --
 void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                    ::mlir::Value ptr, ::mlir::triton::CacheModifier cache,
                    ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
-  TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
-  Type elementType =
-      ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
-  auto shape = ptrType.getShape();
-  // mask
-  ::mlir::Value mask = builder.create<arith::ConstantOp>(
-      ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
-      DenseIntElementsAttr::get(
-          RankedTensorType::get(shape, builder.getI1Type()), true));
-  Type resultType = RankedTensorType::get(shape, elementType);
-  state.addOperands(ptr);
-  state.addOperands(mask);
-  state.addAttribute(
-      cacheAttrName(state.name),
-      ::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
-  state.addAttribute(
-      evictAttrName(state.name),
-      ::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
-  state.addAttribute(isVolatileAttrName(state.name),
-                     builder.getBoolAttr(isVolatile));
-  state.addTypes({resultType});
+  LoadOp::build(builder, state, ptr, mlir::Value(), mlir::Value(), cache, evict,
+                isVolatile);
 }
 
 void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                    ::mlir::Value ptr, ::mlir::Value mask,
                    ::mlir::triton::CacheModifier cache,
                    ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
+  LoadOp::build(builder, state, ptr, mask, mlir::Value(), cache, evict,
+                isVolatile);
+}
+
+void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
+                   ::mlir::Value ptr, ::mlir::Value mask, ::mlir::Value other,
+                   ::mlir::triton::CacheModifier cache,
+                   ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
   TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
   Type elementType =
       ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
   auto shape = ptrType.getShape();
   Type resultType = RankedTensorType::get(shape, elementType);
   state.addOperands(ptr);
-  state.addOperands(mask);
+  if (mask) {
+    state.addOperands(mask);
+    if (other) {
+      state.addOperands(other);
+    }
+  }
   state.addAttribute(
       cacheAttrName(state.name),
       ::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
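Note: the three LoadOp builders now form a delegation chain: the no-mask and
mask-only overloads forward to the full overload, which appends `mask` and `other`
only when they are non-null. As the nested `if` shows, `other` can only be attached
when `mask` is also present. Typical call sites look like this (a sketch assuming
`loc`, `ptr`, `mask`, and `other` are in scope and using the NONE/NORMAL
enumerators of Triton's cache/eviction attributes):

    // Unmasked load: the result type is inferred from the pointer tensor.
    Value plain = builder.create<triton::LoadOp>(
        loc, ptr, triton::CacheModifier::NONE,
        triton::EvictionPolicy::NORMAL, /*isVolatile=*/false);
    // Masked load with an explicit default for masked-off elements.
    Value masked = builder.create<triton::LoadOp>(
        loc, ptr, mask, other, triton::CacheModifier::NONE,
        triton::EvictionPolicy::NORMAL, /*isVolatile=*/false);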
diff --git a/lib/Dialect/Triton/Transforms/Combine.cpp b/lib/Dialect/Triton/Transforms/Combine.cpp
index 69f0ced6a..814e6d41f 100644
--- a/lib/Dialect/Triton/Transforms/Combine.cpp
+++ b/lib/Dialect/Triton/Transforms/Combine.cpp
@@ -52,6 +52,46 @@ DenseElementsAttr getConstantValue(Builder &builder, Attribute value,
 
 } // anonymous namespace
 
+// select(cond, load(ptrs, broadcast(cond), ???), other)
+//   => load(ptrs, broadcast(cond), other)
+class CombineSelectMaskedLoadPattern : public mlir::RewritePattern {
+public:
+  CombineSelectMaskedLoadPattern(mlir::MLIRContext *context)
+      : mlir::RewritePattern(mlir::SelectOp::getOperationName(), 3, context,
+                             {triton::LoadOp::getOperationName()}) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto selectOp = llvm::dyn_cast<mlir::SelectOp>(op);
+    if (!selectOp)
+      return mlir::failure();
+
+    mlir::Value trueValue = selectOp.getTrueValue();
+    mlir::Value falseValue = selectOp.getFalseValue();
+
+    auto *loadOpCandidate = trueValue.getDefiningOp();
+    auto loadOp = llvm::dyn_cast_or_null<triton::LoadOp>(loadOpCandidate);
+    if (!loadOp)
+      return mlir::failure();
+
+    mlir::Value mask = loadOp.mask();
+    if (!mask)
+      return mlir::failure();
+
+    auto *broadcastOpCandidate = mask.getDefiningOp();
+    auto broadcastOp =
+        llvm::dyn_cast_or_null<triton::BroadcastOp>(broadcastOpCandidate);
+    if (!broadcastOp)
+      return mlir::failure();
+
+    rewriter.replaceOpWithNewOp<triton::LoadOp>(
+        op, loadOp.ptr(), loadOp.mask(), falseValue, loadOp.cache(),
+        loadOp.evict(), loadOp.isVolatile());
+    return mlir::success();
+  }
+};
+
 #define GEN_PASS_CLASSES
 #include "triton/Dialect/Triton/Transforms/Passes.h.inc"
diff --git a/lib/Dialect/Triton/Transforms/Combine.td b/lib/Dialect/Triton/Transforms/Combine.td
index 6decc9539..8881b33c2 100644
--- a/lib/Dialect/Triton/Transforms/Combine.td
+++ b/lib/Dialect/Triton/Transforms/Combine.td
@@ -37,12 +37,6 @@ def CombineGEPPattern : Pat<
   (TT_GEPOp (TT_GEPOp $ptr, $idx0), $idx1),
   (TT_GEPOp $ptr, (Arith_AddIOp $idx0, $idx1))>;
 
-// select(cond, load(ptrs, broadcast(cond), ???), other)
-//   => load(ptrs, broadcast(cond), other)
-def CombineSelectMaskedLoadPattern : Pat<
-  (SelectOp $cond, (TT_LoadOp $ptrs, (TT_BroadcastOp:$bcast_res $cond), $other, $cache, $evict, $isVolatile), $falseValue),
-  (TT_LoadOp $ptrs, $bcast_res, $falseValue, $cache, $evict, $isVolatile)>;
-
 // broadcast(cst) => cst
 def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">;
 def CombineBroadcastConstantPattern : Pat<
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
index 3cf54d7eb..4bf30296f 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
@@ -391,14 +391,16 @@ scf::ForOp LoopPipeliner::createNewForOp() {
     if (loads.contains(op->getResult(0))) {
       auto loadOp = llvm::cast<triton::LoadOp>(op);
       Value mask = loadOp.mask();
-      Value splatCond = builder.create<triton::SplatOp>(
-          mask.getLoc(), mask.getType(), nextLoopCond);
-      Value newMask = builder.create<arith::AndIOp>(
-          mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
-      // if mask is defined outside the loop, don't update the map more than
-      // once
-      if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
-        nextMapping.map(mask, newMask);
+      if (mask) {
+        Value splatCond = builder.create<triton::SplatOp>(
+            mask.getLoc(), mask.getType(), nextLoopCond);
+        Value newMask = builder.create<arith::AndIOp>(
+            mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
+        // if mask is defined outside the loop, don't update the map more than
+        // once
+        if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
+          nextMapping.map(mask, newMask);
+      }
       // TODO: more elegant way to do this?
       nextOp = builder.create<triton::gpu::CopyAsyncOp>(
           op->getLoc(), loadsMapping[op->getResult(0)].getType(),
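Note: `CombineSelectMaskedLoadPattern` moves from DRR (Combine.td) to handwritten
C++ above, presumably because a DRR `Pat<>` has no way to bind an operand that may
now be absent; the C++ pattern instead bails out explicitly when the load carries no
mask. Schematically, the rewrite it implements (types and attributes elided; per the
code comment, the mask is expected to be a broadcast of the select condition):

    %mask   = tt.broadcast %cond ...
    %loaded = tt.load %ptrs, %mask ...
    %result = select %cond, %loaded, %other ...
    // folds to:
    %result = tt.load %ptrs, %mask, %other ...

The Pipeline.cpp hunk applies the same discipline, splatting and AND-ing the loop
condition into the mask only when a mask actually exists.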
diff --git a/python/src/triton.cc b/python/src/triton.cc
index f0b78e15f..549498175 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -1442,21 +1442,9 @@ void init_triton_ir(py::module &&m) {
               mlir::triton::EvictionPolicy evictionPolicy,
               bool isVolatile) -> mlir::Value {
              auto loc = self.getUnknownLoc();
-            auto ptrType = ptrs.getType().dyn_cast<mlir::RankedTensorType>();
-            std::vector<int64_t> shape = ptrType.getShape();
-            mlir::Type elementType = ptrType.getElementType()
-                                         .dyn_cast<mlir::triton::PointerType>()
-                                         .getPointeeType();
-            mlir::Type resultType =
-                mlir::RankedTensorType::get(shape, elementType);
-            if (other.has_value()) {
-              return self.create<mlir::triton::LoadOp>(
-                  loc, resultType, ptrs, mask, other.value(), cacheModifier,
-                  evictionPolicy, isVolatile);
-            } else {
-              return self.create<mlir::triton::LoadOp>(
-                  loc, ptrs, mask, cacheModifier, evictionPolicy, isVolatile);
-            }
+            return self.create<mlir::triton::LoadOp>(
+                loc, ptrs, mask, other.value_or(mlir::Value()), cacheModifier,
+                evictionPolicy, isVolatile);
            })
       .def("create_masked_store",
            [](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &val,
diff --git a/test/Analysis/test-alignment.mlir b/test/Analysis/test-alignment.mlir
index f9f922846..c70a81c8e 100644
--- a/test/Analysis/test-alignment.mlir
+++ b/test/Analysis/test-alignment.mlir
@@ -47,6 +47,6 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
   %19 = tt.getelementptr %17, %18 : tensor<128x128x!tt.ptr<f32>>
   // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
   %20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32>
-  tt.store %19, %20, %cst, : tensor<128x128xf32>
+  tt.store %19, %20, %cst : tensor<128x128xf32>
   return
-}
\ No newline at end of file
+}
diff --git a/test/Conversion/triton_to_llvm.mlir b/test/Conversion/triton_to_llvm.mlir
index ec5c470ca..1637d11af 100644
--- a/test/Conversion/triton_to_llvm.mlir
+++ b/test/Conversion/triton_to_llvm.mlir
@@ -31,7 +31,7 @@ func @test_store_splat(%ptr: !tt.ptr<f32>) {
   // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$0 st.global.b32 [ $1 + 0 ], { $2 };",
   // CHECK-SAME: "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
-  tt.store %ptrs, %vs, %mask, {} : tensor<128xf32>
+  tt.store %ptrs, %vs, %mask : tensor<128xf32>
   return
 }
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index 4cad33ff3..cde1f85cd 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -186,7 +186,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
     // CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
     // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
     // CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
-    tt.store %ptrs, %vals, %mask, {} : tensor<256xf32, #blocked0>
+    tt.store %ptrs, %vals, %mask : tensor<256xf32, #blocked0>
     return
   }
 }
diff --git a/test/Triton/vecadd.mlir b/test/Triton/vecadd.mlir
index d99c3ab1f..8f3642373 100644
--- a/test/Triton/vecadd.mlir
+++ b/test/Triton/vecadd.mlir
@@ -38,7 +38,7 @@ module {
     }
     %16 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>>
     %17 = tt.getelementptr %16, %4 : tensor<256x!tt.ptr<f32>>
-    tt.store %17, %15#0, %6, : tensor<256xf32>
+    tt.store %17, %15#0, %6 : tensor<256xf32>
     return
   }
 }
@@ -124,7 +124,7 @@ module {
 // }
 // %53 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">>
 // %54 = tt.getelementptr %53, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">>
-// tt.store %54, %52#0, %6, : tensor<256xf32, #triton_gpu<"coalesced encoding">>
+// tt.store %54, %52#0, %6 : tensor<256xf32, #triton_gpu<"coalesced encoding">>
 // return
 // }
 // }
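Note: the test updates above are mechanical consequences of the new assembly format.
The old hand-written store syntax `$ptr `,` $value `,` $mask `,` attr-dict` always
printed a comma before the attribute dictionary, which is why the old tests read
`tt.store %p, %v, %m, {} : ...` or end in a dangling comma; the generic
`operands attr-dict` directive elides an empty dictionary. Side by side (schematic
operand names):

    tt.store %ptrs, %vs, %mask, {} : tensor<128xf32>   // old hand-written format
    tt.store %ptrs, %vs, %mask : tensor<128xf32>       // new generic format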
diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir
index d55e5ed76..f2ac1f65c 100644
--- a/test/TritonGPU/coalesce.mlir
+++ b/test/TritonGPU/coalesce.mlir
@@ -3,7 +3,7 @@
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
 #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
- 
+
 // CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
 // CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
@@ -15,9 +15,9 @@
 // CHECK: [[store_val:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[col_layout]]>
 // CHECK: [[store_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[col_layout]]>
 // CHECK: tt.store [[store_ptr]], [[store_val]], [[store_mask]]
-func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, 
-                %arg1: i32 {tt.divisibility = 16 : i32}, 
-                %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, 
+func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
+                %arg1: i32 {tt.divisibility = 16 : i32},
+                %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32},
                 %arg3: i32 {tt.divisibility = 16 : i32}) {
   %cst = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
   %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
@@ -41,6 +41,6 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
   %17 = triton_gpu.convert_layout %16 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
   %18 = tt.getelementptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>
   %19 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked1>
-  tt.store %18, %19, %cst, : tensor<64x64xf32, #blocked1>
+  tt.store %18, %19, %cst : tensor<64x64xf32, #blocked1>
   return
-}
\ No newline at end of file
+}
diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir
index c34babdff..5bdc3c10f 100644
--- a/test/TritonGPU/combine.mlir
+++ b/test/TritonGPU/combine.mlir
@@ -86,7 +86,7 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt
   // CHECK: %21 = tt.load %20, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
   // CHECK: %22 = tt.getelementptr %18, %19 : tensor<64x64x!tt.ptr<f32>, [[col_layout]]>
   // CHECK: %23 = triton_gpu.convert_layout %21 : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
-  // CHECK: tt.store %22, %23, %cst_1, : tensor<64x64xf32, [[col_layout]]>
+  // CHECK: tt.store %22, %23, %cst_1 : tensor<64x64xf32, [[col_layout]]>
   // CHECK: return
   %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
   %cst_0 = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
@@ -117,7 +117,7 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt
   %24 = triton_gpu.convert_layout %18 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked4>
   %25 = triton_gpu.convert_layout %23 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked4>
   %26 = triton_gpu.convert_layout %cst_0 : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked4>
-  tt.store %24, %25, %26, : tensor<64x64xf32, #blocked4>
+  tt.store %24, %25, %26 : tensor<64x64xf32, #blocked4>
   return
 }
 
@@ -170,6 +170,6 @@ func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %ar
   %20 = triton_gpu.convert_layout %19 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
   %21 = triton_gpu.convert_layout %11#0 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked1>
   %22 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked1>
-  tt.store %20, %21, %22, : tensor<64x64xf32, #blocked1>
+  tt.store %20, %21, %22 : tensor<64x64xf32, #blocked1>
   return
-}
\ No newline at end of file
+}
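Note: none of the updated tests exercises the fully unmasked forms yet, but after
this patch they are legal and round-trip through the parser. An illustrative
(hypothetical) example, not taken from the test suite:

    func @unmasked_roundtrip(%in: tensor<256x!tt.ptr<f32>>, %out: tensor<256x!tt.ptr<f32>>) {
      %x = tt.load %in {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32>
      tt.store %out, %x : tensor<256xf32>
      return
    }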