From b2b793dfb5467d4bd3ac915cf44ad611f2f9b4ce Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 6 Dec 2022 23:29:50 -0800 Subject: [PATCH] [FRONTEND][BACKEND] Fixes for cat / reshape / addptr (#959) Most notably, this PR: - changes the traits (and assembly format) of addptr so it can handle offsets that have arbitrary integer width. - adds support for `cat` --- include/triton/Dialect/Triton/IR/TritonOps.td | 9 +-- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 38 ++++++++++++ .../TritonToTritonGPU/TritonToTritonGPU.cpp | 20 +++++- lib/Dialect/Triton/IR/Traits.cpp | 2 +- lib/Dialect/Triton/Transforms/Combine.cpp | 2 +- lib/Dialect/Triton/Transforms/Combine.td | 9 +-- python/tests/test_backend.py | 4 +- python/triton/compiler.py | 2 +- python/triton/language/__init__.py | 4 ++ python/triton/language/core.py | 11 +++- python/triton/language/semantic.py | 8 ++- test/Analysis/test-alias.mlir | 4 +- test/Analysis/test-alignment.mlir | 24 +++---- test/Analysis/test-allocation.mlir | 4 +- test/Analysis/test-membar.mlir | 4 +- test/Conversion/triton_ops.mlir | 6 +- test/Conversion/tritongpu_to_llvm.mlir | 34 +++++----- test/Triton/combine.mlir | 16 ++--- test/Triton/vecadd.mlir | 36 +++++------ test/TritonGPU/coalesce.mlir | 8 +-- test/TritonGPU/combine.mlir | 62 +++++++++---------- test/TritonGPU/loop-pipeline.mlir | 10 +-- test/TritonGPU/matmul.mlir | 10 +-- test/TritonGPU/prefetch.mlir | 4 +- 24 files changed, 199 insertions(+), 132 deletions(-) diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 258cd41b2..fe744339c 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -103,15 +103,12 @@ def TT_AddPtrOp : TT_Op<"addptr", SameOperandsAndResultShape, SameOperandsAndResultEncoding, TypesMatchWith<"result type matches ptr type", - "result", "ptr", "$_self">, - TypesMatchWith<"result shape matches offset shape", - "result", "offset", - "getI32SameShape($_self)">]> { - let arguments = (ins TT_PtrLike:$ptr, TT_I32Like:$offset); + "result", "ptr", "$_self">]> { + let arguments = (ins TT_PtrLike:$ptr, TT_IntLike:$offset); let results = (outs TT_PtrLike:$result); - let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result)"; + let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result) `,` type($offset)"; } diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 33fd3d932..8af05b0f3 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -1707,6 +1707,43 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast( /// ====================== reduce codegen end ========================== +/// ====================== cat codegen begin ========================== + +struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern { + using OpAdaptor = typename CatOp::Adaptor; + + explicit CatOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) + : ConvertTritonGPUOpToLLVMPattern(typeConverter, benefit) {} + + LogicalResult + matchAndRewrite(CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto resultTy = op.getType().template cast(); + unsigned elems = getElemsPerThread(resultTy); + Type elemTy = + this->getTypeConverter()->convertType(resultTy.getElementType()); + SmallVector types(elems, elemTy); + // unpack input values + auto lhsVals = 
getElementsFromStruct(loc, adaptor.lhs(), rewriter); + auto rhsVals = getElementsFromStruct(loc, adaptor.rhs(), rewriter); + // concatenate (and potentially reorder) values + SmallVector retVals; + for(Value v: lhsVals) + retVals.push_back(v); + for(Value v: rhsVals) + retVals.push_back(v); + // pack and replace + Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types); + Value ret = getStructFromElements(loc, retVals, rewriter, structTy); + rewriter.replaceOp(op, ret); + return success(); + } +}; + +/// ====================== cat codegen end ========================== + template struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern { using OpAdaptor = typename SourceOp::Adaptor; @@ -4537,6 +4574,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, benefit); patterns.add(typeConverter, allocation, smem, benefit); patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); } diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index e4a3e8064..42150c362 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -251,6 +251,22 @@ struct TritonDotPattern : public OpConversionPattern { } }; +struct TritonCatPattern : public OpConversionPattern { + + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // For now, this behaves like generic, but this will evolve when + // we add support for `can_reorder=False` + Type retType = this->getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, retType, adaptor.getOperands()); + return success(); + } + +}; + struct TritonTransPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -433,7 +449,9 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, TritonGenericPattern, TritonGenericPattern, TritonGenericPattern, TritonBroadcastPattern, - TritonGenericPattern, TritonReducePattern, + TritonGenericPattern, + TritonCatPattern, + TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern, TritonAtomicRMWPattern>( diff --git a/lib/Dialect/Triton/IR/Traits.cpp b/lib/Dialect/Triton/IR/Traits.cpp index dabd6b9fc..eede3f406 100644 --- a/lib/Dialect/Triton/IR/Traits.cpp +++ b/lib/Dialect/Triton/IR/Traits.cpp @@ -19,7 +19,7 @@ mlir::OpTrait::impl::verifySameOperandsAndResultEncoding(Operation *op) { for (auto resultType : op->getResultTypes()) if (failed(verifySameEncoding(resultType, type))) return op->emitOpError() - << "requires the same shape for all operands and results"; + << "requires the same encoding for all operands and results"; return verifySameOperandsEncoding(op); } diff --git a/lib/Dialect/Triton/Transforms/Combine.cpp b/lib/Dialect/Triton/Transforms/Combine.cpp index 33bc1d0fb..2b72c14a4 100644 --- a/lib/Dialect/Triton/Transforms/Combine.cpp +++ b/lib/Dialect/Triton/Transforms/Combine.cpp @@ -196,7 +196,7 @@ public: patterns.add(context); // %} patterns.add(context); - patterns.add(context); + // patterns.add(context); patterns.add(context); if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) diff --git a/lib/Dialect/Triton/Transforms/Combine.td 
b/lib/Dialect/Triton/Transforms/Combine.td index 1b84cc7f3..14f286b26 100644 --- a/lib/Dialect/Triton/Transforms/Combine.td +++ b/lib/Dialect/Triton/Transforms/Combine.td @@ -29,13 +29,14 @@ def CombineDotAddFRevPattern : Pat< (TT_DotOp $a, $b, $d, $allowTF32), [(Constraint> $c)]>; - +// TODO: this fails for addptr(addptr(ptr, i32), i64) +// Commented out until fixed // addptr(addptr(%ptr, %idx0), %idx1) => addptr(%ptr, AddI(%idx0, %idx1)) // Note: leave (sub %c0, %c0) canceling to ArithmeticDialect // (ref: ArithmeticCanonicalization.td) -def CombineAddPtrPattern : Pat< - (TT_AddPtrOp (TT_AddPtrOp $ptr, $idx0), $idx1), - (TT_AddPtrOp $ptr, (Arith_AddIOp $idx0, $idx1))>; +// def CombineAddPtrPattern : Pat< +// (TT_AddPtrOp (TT_AddPtrOp $ptr, $idx0), $idx1), +// (TT_AddPtrOp $ptr, (Arith_AddIOp $idx0, $idx1))>; // broadcast(cst) => cst def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">; diff --git a/python/tests/test_backend.py b/python/tests/test_backend.py index 06f36e43b..f99947561 100644 --- a/python/tests/test_backend.py +++ b/python/tests/test_backend.py @@ -64,12 +64,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { %7 = tt.broadcast %6 : (tensor<1x128xi32, #src>) -> tensor<128x128xi32, #src> %8 = tt.broadcast %5 : (tensor<128x1xi32, #src>) -> tensor<128x128xi32, #src> %9 = arith.addi %8, %7 : tensor<128x128xi32, #src> - %10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr, #src> + %10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr, #src>, tensor<128x128xi32, #src> %11 = tt.load %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #src> %3 = tt.splat %arg1 : (!tt.ptr) -> tensor<128x128x!tt.ptr, #dst> %12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst> %13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst> - %14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr, #dst> + %14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr, #dst>, tensor<128x128xi32, #dst> tt.store %14, %13 : tensor<128x128xf16, #dst> return } diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 25133bff1..d66dbfd50 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -371,6 +371,7 @@ class CodeGenerator(ast.NodeVisitor): # 1. we have an orelse node # or # 2. 
the then block defines new variable + else_defs = {} if then_defs or node.orelse: if node.orelse: self.lscope = liveins @@ -381,7 +382,6 @@ class CodeGenerator(ast.NodeVisitor): else_defs = self.local_defs.copy() else: # collect else_defs - else_defs = {} for name in then_defs: if name in liveins: assert self.is_triton_tensor(then_defs[name]) diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 4b7df9515..881ad01b1 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -55,6 +55,7 @@ from .core import ( printf, program_id, ravel, + reshape, sigmoid, sin, softmax, @@ -70,6 +71,7 @@ from .core import ( uint64, uint8, umulhi, + view, void, where, xor_sum, @@ -149,6 +151,7 @@ __all__ = [ "randn", "randn4x", "ravel", + "reshape", "sigmoid", "sin", "softmax", @@ -165,6 +168,7 @@ __all__ = [ "uint64", "uint8", "umulhi", + "view", "void", "where", "xor_sum", diff --git a/python/triton/language/core.py b/python/triton/language/core.py index b9542906c..a099139fc 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -731,7 +731,7 @@ def trans(input, _builder=None): return semantic.trans(input, _builder) @builtin -def cat(input, other, _builder=None): +def cat(input, other, can_reorder=False, _builder=None): """ Concatenate the given blocks @@ -739,8 +739,12 @@ def cat(input, other, _builder=None): :type input: :param other: The second input tensor. :type other: + :param reorder: Compiler hint. If true, the compiler is + allowed to reorder elements while concatenating inputs. + Only use if the order does not matter (e.g., result is + only used in reduction ops) """ - return semantic.cat(input, other, _builder) + return semantic.cat(input, other, can_reorder, _builder) @builtin @@ -761,7 +765,8 @@ def view(input, shape, _builder=None): @builtin def reshape(input, shape, _builder=None): # TODO: should be more than just a view - return view(input, shape, _builder) + shape = [x.value for x in shape] + return semantic.view(input, shape, _builder) # ----------------------- # Linear Algebra diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 104d5ac54..ecacf47ec 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -498,9 +498,11 @@ def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty) -def cat(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor: - # TODO: check types - return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), lhs.type) +def cat(lhs: tl.tensor, rhs: tl.tensor, can_reorder: bool, builder: ir.builder) -> tl.tensor: + assert can_reorder, "current implementation of `cat` always may reorder elements" + assert len(lhs.shape) == 1 + ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]]) + return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), ret_type) def trans(input: tl.tensor, builder: ir.builder) -> tl.tensor: if len(input.shape) != 2: diff --git a/test/Analysis/test-alias.mlir b/test/Analysis/test-alias.mlir index e81480297..d7bde0303 100644 --- a/test/Analysis/test-alias.mlir +++ b/test/Analysis/test-alias.mlir @@ -27,8 +27,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B_DOT> %c = tt.dot %a, %b, %prev_c {transA = false, transB = false, allowTF32 = true} : 
tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> - %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL> - %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL> + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> } return diff --git a/test/Analysis/test-alignment.mlir b/test/Analysis/test-alignment.mlir index 0f2a45f03..312194f21 100644 --- a/test/Analysis/test-alignment.mlir +++ b/test/Analysis/test-alignment.mlir @@ -18,7 +18,7 @@ func @permute_2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {t // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] %5 = tt.splat %arg0 : (!tt.ptr) -> tensor<128x1x!tt.ptr> // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 1] - %6 = tt.addptr %5, %4 : tensor<128x1x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1] %7 = tt.expand_dims %1 {axis = 0 : i32}: (tensor<128xi32>) -> tensor<1x128xi32> // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128] @@ -26,13 +26,13 @@ func @permute_2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {t // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [128, 1] %9 = tt.broadcast %7 : (tensor<1x128xi32>) -> tensor<128x128xi32> // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 16] ; Constancy: [1, 1] - %10 = tt.addptr %8, %9 : tensor<128x128x!tt.ptr> + %10 = tt.addptr %8, %9 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1] %11 = tt.expand_dims %0 {axis = 1 : i32}: (tensor<128xi32>) -> tensor<128x1xi32> // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] %12 = tt.splat %arg2 : (!tt.ptr) -> tensor<128x1x!tt.ptr> // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1] - %13 = tt.addptr %12, %11 : tensor<128x1x!tt.ptr> + %13 = tt.addptr %12, %11 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1] %14 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32> // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128] @@ -44,7 +44,7 @@ func @permute_2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {t // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 1048576] ; Constancy: [128, 1] %18 = tt.broadcast %16 : (tensor<1x128xi32>) -> tensor<128x128xi32> // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1] - %19 = tt.addptr %17, %18 : tensor<128x128x!tt.ptr> + %19 = tt.addptr %17, %18 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1] %20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32> tt.store %19, %20, %cst : tensor<128x128xf32> @@ -72,7 +72,7 @@ func @store_constant_align(%addr: !tt.ptr {tt.divisibility = 16 : i32}, %n: // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128] %5 = tt.splat %addr : (!tt.ptr) -> tensor<128x!tt.ptr> // CHECK-NEXT: Contiguity: 
[128] ; Divisibility: [16] ; Constancy: [1] - %6 = tt.addptr %5, %4 : tensor<128x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<128x!tt.ptr>, tensor<128xi32> // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128] %9 = tt.splat %n : (i32) -> tensor<128xi32> // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [16] @@ -97,9 +97,9 @@ func @vecadd_mask_align_16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %ar %3 = tt.splat %1 : (i32) -> tensor<64xi32> %4 = arith.addi %3, %2 : tensor<64xi32> %5 = tt.splat %arg0 : (!tt.ptr) -> tensor<64x!tt.ptr> - %6 = tt.addptr %5, %4 : tensor<64x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<64x!tt.ptr>, tensor<64xi32> %7 = tt.splat %arg1 : (!tt.ptr) -> tensor<64x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<64x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<64x!tt.ptr>, tensor<64xi32> %9 = tt.splat %n_elements : (i32) -> tensor<64xi32> // CHECK: Contiguity: [1] ; Divisibility: [64] ; Constancy: [16] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> ) %mask = arith.cmpi slt, %4, %9 : tensor<64xi32> @@ -107,8 +107,8 @@ func @vecadd_mask_align_16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %ar %12 = tt.load %8, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32> %13 = arith.addf %11, %12 : tensor<64xf32> %14 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x!tt.ptr> - // CHECK: Contiguity: [64] ; Divisibility: [16] ; Constancy: [1] ( %{{.*}} = tt.addptr %{{.*}}, %{{.*}} : tensor<64x!tt.ptr> ) - %15 = tt.addptr %14, %4 : tensor<64x!tt.ptr> + // CHECK: Contiguity: [64] ; Divisibility: [16] ; Constancy: [1] ( %{{.*}} = tt.addptr %{{.*}}, %{{.*}} : tensor<64x!tt.ptr>, tensor<64xi32> ) + %15 = tt.addptr %14, %4 : tensor<64x!tt.ptr>, tensor<64xi32> tt.store %15, %13, %mask : tensor<64xf32> return } @@ -125,9 +125,9 @@ func @vecadd_mask_align_1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %3 = tt.splat %1 : (i32) -> tensor<64xi32> %4 = arith.addi %3, %2 : tensor<64xi32> %5 = tt.splat %arg0 : (!tt.ptr) -> tensor<64x!tt.ptr> - %6 = tt.addptr %5, %4 : tensor<64x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<64x!tt.ptr>, tensor<64xi32> %7 = tt.splat %arg1 : (!tt.ptr) -> tensor<64x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<64x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<64x!tt.ptr>, tensor<64xi32> %9 = tt.splat %n_elements : (i32) -> tensor<64xi32> // CHECK: Contiguity: [1] ; Divisibility: [64] ; Constancy: [1] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> ) %10 = arith.cmpi slt, %4, %9 : tensor<64xi32> @@ -135,7 +135,7 @@ func @vecadd_mask_align_1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %12 = tt.load %8, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32> %13 = arith.addf %11, %12 : tensor<64xf32> %14 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x!tt.ptr> - %15 = tt.addptr %14, %4 : tensor<64x!tt.ptr> + %15 = tt.addptr %14, %4 : tensor<64x!tt.ptr>, tensor<64xi32> tt.store %15, %13, %10 : tensor<64xf32> return } diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index 28c73b45f..b69041f09 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -35,8 +35,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> - %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL> - %next_b_ptr = tt.addptr %b_ptr, %b_off : 
tensor<32x128x!tt.ptr, #BL> + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> } return diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index 42ff6c9c3..a8f48c472 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -33,8 +33,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B_DOT> %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> - %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL> - %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL> + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> } return diff --git a/test/Conversion/triton_ops.mlir b/test/Conversion/triton_ops.mlir index e9d484887..e9ee50243 100644 --- a/test/Conversion/triton_ops.mlir +++ b/test/Conversion/triton_ops.mlir @@ -38,19 +38,19 @@ func @cast_ops(%scalar_ptr: !tt.ptr, %scalar_f32: f32, %scalar_i64: i64) { func @addptr_ops(%scalar_ptr: !tt.ptr, %scalar_i32: i32) { // scalar -> scalar // CHECK: !tt.ptr - %0 = tt.addptr %scalar_ptr, %scalar_i32 : !tt.ptr + %0 = tt.addptr %scalar_ptr, %scalar_i32 : !tt.ptr, i32 // 0D tensor -> 0D tensor %tensor_ptr_0d = tt.splat %scalar_ptr : (!tt.ptr) -> tensor> %tensor_i32_0d = tt.splat %scalar_i32 : (i32) -> tensor // CHECK: tensor> - %1 = tt.addptr %tensor_ptr_0d, %tensor_i32_0d : tensor> + %1 = tt.addptr %tensor_ptr_0d, %tensor_i32_0d : tensor>, tensor // 1D tensor -> 1D tensor %tensor_ptr_1d = tt.splat %scalar_ptr : (!tt.ptr) -> tensor<16x!tt.ptr> %tensor_i32_1d = tt.splat %scalar_i32 : (i32) -> tensor<16xi32> // CHECK: tensor<16x!tt.ptr> - %2 = tt.addptr %tensor_ptr_1d, %tensor_i32_1d : tensor<16x!tt.ptr> + %2 = tt.addptr %tensor_ptr_1d, %tensor_i32_1d : tensor<16x!tt.ptr>, tensor<16xi32> return } diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index e2bd31df0..e56b20e31 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -92,9 +92,9 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { %3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0> %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0> %5 = tt.splat %arg0 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked0> - %6 = tt.addptr %5, %4 : tensor<256x!tt.ptr, #blocked0> + %6 = tt.addptr %5, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> %7 = tt.splat %arg1 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked0> - %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr, #blocked0> + %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> // Load 4 elements from vector0 // CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ]; @@ -111,7 +111,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { %10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : 
tensor<256xf32, #blocked0> %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> %12 = tt.splat %arg2 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked0> - %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr, #blocked0> + %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> // Store 4 elements to global // CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} }; @@ -136,9 +136,9 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { %3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0> %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0> %5 = tt.splat %arg0 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked0> - %6 = tt.addptr %5, %4 : tensor<256x!tt.ptr, #blocked0> + %6 = tt.addptr %5, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> %7 = tt.splat %arg1 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked0> - %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr, #blocked0> + %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> // Load 4 elements from A with single one vectorized load instruction // CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; @@ -150,7 +150,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { %10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> %12 = tt.splat %arg2 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked0> - %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr, #blocked0> + %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> // Store 4 elements to global with single one vectorized store instruction // CHECK: @$5 st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }; @@ -173,9 +173,9 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { %3 = tt.splat %1 : (i32) -> tensor<64xi32, #blocked> %4 = arith.addi %3, %2 : tensor<64xi32, #blocked> %5 = tt.splat %arg0 : (!tt.ptr) -> tensor<64x!tt.ptr, #blocked> - %6 = tt.addptr %5, %4 : tensor<64x!tt.ptr, #blocked> + %6 = tt.addptr %5, %4 : tensor<64x!tt.ptr, #blocked>, tensor<64xi32, #blocked> %7 = tt.splat %arg1 : (!tt.ptr) -> tensor<64x!tt.ptr, #blocked> - %8 = tt.addptr %7, %4 : tensor<64x!tt.ptr, #blocked> + %8 = tt.addptr %7, %4 : tensor<64x!tt.ptr, #blocked>, tensor<64xi32, #blocked> %9 = tt.splat %n_elements : (i32) -> tensor<64xi32, #blocked> %10 = "triton_gpu.cmpi"(%4, %9) {predicate = 2 : i64} : (tensor<64xi32, #blocked>, tensor<64xi32, #blocked>) -> tensor<64xi1, #blocked> // load op has a vector width = 1 due to the %mask's alignment @@ -184,7 +184,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { %12 = tt.load %8, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32, #blocked> %13 = arith.addf %11, %12 : tensor<64xf32, #blocked> %14 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x!tt.ptr, #blocked> - %15 = tt.addptr %14, %4 : tensor<64x!tt.ptr, #blocked> + %15 = tt.addptr %14, %4 : tensor<64x!tt.ptr, #blocked>, tensor<64xi32, #blocked> tt.store %15, %13, %10 : tensor<64xf32, #blocked> return } @@ -203,9 +203,9 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { %3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0> %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0> %5 = tt.splat %arg0 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked0> - %6 = tt.addptr %5, %4 : tensor<256x!tt.ptr, #blocked0> + %6 = tt.addptr %5, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> %7 = tt.splat %arg1 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked0> - %8 = 
tt.addptr %7, %4 : tensor<256x!tt.ptr, #blocked0> + %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> // Load 8 elements from A with two vectorized load instruction // CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; @@ -219,7 +219,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { %10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> %12 = tt.splat %arg2 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked0> - %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr, #blocked0> + %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> // Store 8 elements to global with two vectorized store instruction // CHECK: @$5 st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }; @@ -317,7 +317,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { func @basic_addptr(%arg0 : tensor<256x!tt.ptr,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) { // CHECK: llvm.getelementptr // CHECK: llvm.getelementptr - %0 = tt.addptr %arg0, %arg1 : tensor<256x!tt.ptr, #blocked0> + %0 = tt.addptr %arg0, %arg1 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> return } } @@ -411,7 +411,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { %broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<16x64xi32, #block3>) -> tensor<16x64xi32, #AL> %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x64xi32, #AL> %a_init = tt.splat %arg0 : (!tt.ptr) -> tensor<16x64x!tt.ptr, #AL> - %a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr, #AL> + %a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr, #AL>, tensor<16x64xi32, #AL> %tensor = triton_gpu.alloc_tensor : tensor<2x16x64xf16, #A> %index = arith.constant 1 : i32 @@ -450,7 +450,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { %broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<16x64xi32, #block3>) -> tensor<16x64xi32, #AL> %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x64xi32, #AL> %a_init = tt.splat %arg0 : (!tt.ptr) -> tensor<16x64x!tt.ptr, #AL> - %a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr, #AL> + %a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr, #AL>, tensor<16x64xi32, #AL> %tensor = triton_gpu.alloc_tensor : tensor<2x16x64xf32, #A> %index = arith.constant 1 : i32 @@ -491,7 +491,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { %broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<16x32xi32, #block3>) -> tensor<16x32xi32, #AL> %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x32xi32, #AL> %a_init = tt.splat %arg0 : (!tt.ptr) -> tensor<16x32x!tt.ptr, #AL> - %a_ptr = tt.addptr %a_init, %off : tensor<16x32x!tt.ptr, #AL> + %a_ptr = tt.addptr %a_init, %off : tensor<16x32x!tt.ptr, #AL>, tensor<16x32xi32, #AL> %tensor = triton_gpu.alloc_tensor : tensor<2x16x32xf32, #A> %index = arith.constant 1 : i32 @@ -535,7 +535,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { %broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<32x32xi32, #block3>) -> tensor<32x32xi32, #AL> %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<32x32xi32, #AL> %a_init = tt.splat %arg0 : (!tt.ptr) -> tensor<32x32x!tt.ptr, #AL> - %a_ptr = tt.addptr %a_init, %off : tensor<32x32x!tt.ptr, #AL> + %a_ptr = tt.addptr %a_init, %off : tensor<32x32x!tt.ptr, #AL>, tensor<32x32xi32, #AL> %tensor = triton_gpu.alloc_tensor : 
tensor<2x32x32xf32, #A> %index = arith.constant 1 : i32 diff --git a/test/Triton/combine.mlir b/test/Triton/combine.mlir index b847f3163..c8c1f2962 100644 --- a/test/Triton/combine.mlir +++ b/test/Triton/combine.mlir @@ -22,28 +22,30 @@ func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32 return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32> } -// CHECK-LABEL: @test_combine_addptr_pattern + +// COM: CHECK-LABEL: @test_combine_addptr_pattern func @test_combine_addptr_pattern(%base: !tt.ptr) -> tensor<8x!tt.ptr> { %off0 = arith.constant 10 : i32 %off1 = arith.constant 15 : i32 // 10 + 15 = 25 - // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<25> : tensor<8xi32> + // COM: CHECK-NEXT: %[[cst:.*]] = arith.constant dense<25> : tensor<8xi32> %base_ = tt.broadcast %base : (!tt.ptr) -> tensor<8x!tt.ptr> - // CHECK-NEXT: %[[tmp0:.*]] = tt.broadcast %{{.*}} : (!tt.ptr) -> tensor<8x!tt.ptr> + // COM: CHECK-NEXT: %[[tmp0:.*]] = tt.broadcast %{{.*}} : (!tt.ptr) -> tensor<8x!tt.ptr> %idx0 = tt.broadcast %off0 : (i32) -> tensor<8xi32> %idx1 = tt.broadcast %off1 : (i32) -> tensor<8xi32> - - // CHECK-NEXT: %1 = tt.addptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr> - %ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr> - %ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr> + + // COM: CHECK-NEXT: %1 = tt.addptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr>, tensor<8xi32> + %ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr>, tensor<8xi32> + %ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr>, tensor<8xi32> return %ptr1 : tensor<8x!tt.ptr> } + // CHECK-LABEL: @test_combine_select_masked_load_pattern func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) { %mask = tt.broadcast %cond : (i1) -> tensor<8xi1> diff --git a/test/Triton/vecadd.mlir b/test/Triton/vecadd.mlir index 391e80650..0b69ef305 100644 --- a/test/Triton/vecadd.mlir +++ b/test/Triton/vecadd.mlir @@ -11,9 +11,9 @@ module { %5 = tt.broadcast %arg3 : (i32) -> tensor<256xi32> %6 = arith.cmpi slt, %4, %5 : tensor<256xi32> %7 = tt.broadcast %arg0 : (!tt.ptr) -> tensor<256x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr>, tensor<256xi32> %9 = tt.broadcast %arg1 : (!tt.ptr) -> tensor<256x!tt.ptr> - %10 = tt.addptr %9, %4 : tensor<256x!tt.ptr> + %10 = tt.addptr %9, %4 : tensor<256x!tt.ptr>, tensor<256xi32> %cst = arith.constant 0.000000e+00 : f32 %11 = tt.broadcast %cst : (f32) -> tensor<256xf32> %c0_i32 = arith.constant 0 : i32 @@ -31,13 +31,13 @@ module { %22 = arith.addf %19, %21 : tensor<256xf32> %23 = arith.addf %arg7, %22 : tensor<256xf32> %24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32> - %25 = tt.addptr %arg8, %24 : tensor<256x!tt.ptr> + %25 = tt.addptr %arg8, %24 : tensor<256x!tt.ptr>, tensor<256xi32> %26 = tt.broadcast %arg5 : (i32) -> tensor<256xi32> - %27 = tt.addptr %arg9, %26 : tensor<256x!tt.ptr> + %27 = tt.addptr %arg9, %26 : tensor<256x!tt.ptr>, tensor<256xi32> scf.yield %23, %25, %27 : tensor<256xf32>, tensor<256x!tt.ptr>, tensor<256x!tt.ptr> } %16 = tt.broadcast %arg2 : (!tt.ptr) -> tensor<256x!tt.ptr> - %17 = tt.addptr %16, %4 : tensor<256x!tt.ptr> + %17 = tt.addptr %16, %4 : tensor<256x!tt.ptr>, tensor<256xi32> tt.store %17, %15#0, %6 : tensor<256xf32> return } @@ -57,9 +57,9 @@ module { // %5 = tt.broadcast %arg3 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">> // %6 = "triton_gpu.cmpi"(%4, %5) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu<"coalesced encoding">>, 
tensor<256xi32, #triton_gpu<"coalesced encoding">>) -> tensor<256xi1, #triton_gpu<"coalesced encoding">> // %7 = tt.broadcast %arg0 : (!tt.ptr) -> tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -// %8 = tt.addptr %7, %4, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> +// %8 = tt.addptr %7, %4, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> // %9 = tt.broadcast %arg1 : (!tt.ptr) -> tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -// %10 = tt.addptr %9, %4, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> +// %10 = tt.addptr %9, %4, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> // %11 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding">> // %12 = arith.index_cast %arg4 : i32 to index // %13 = arith.cmpi slt, %c0, %12 : index @@ -72,9 +72,9 @@ module { // %20 = arith.andi %6, %19 : tensor<256xi1, #triton_gpu<"coalesced encoding">> // %21 = triton_gpu.copy_async %10, %20, %18 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">> // %22 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %23 = tt.addptr %8, %22, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> +// %23 = tt.addptr %8, %22, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> // %24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %25 = tt.addptr %10, %24, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> +// %25 = tt.addptr %10, %24, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> // %26 = arith.cmpi slt, %c32, %12 : index // %27 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding">> // %28 = tt.broadcast %26 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding">> @@ -85,9 +85,9 @@ module { // %33 = arith.andi %6, %32 : tensor<256xi1, #triton_gpu<"coalesced encoding">> // %34 = triton_gpu.copy_async %25, %33, %31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">> // %35 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %36 = tt.addptr %23, %35, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> +// %36 = tt.addptr %23, %35, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> // %37 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %38 = tt.addptr %25, %37, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> +// %38 = tt.addptr %25, %37, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> // %39 = arith.cmpi slt, %c64, %12 : index // %40 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding">> // %41 = tt.broadcast %39 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding">> @@ -98,16 +98,16 @@ module { // %46 = arith.andi %6, %45 : tensor<256xi1, #triton_gpu<"coalesced encoding">> // %47 = triton_gpu.copy_async %38, %46, %44 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">> // %48 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %49 = tt.addptr %36, %48, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> +// %49 = 
tt.addptr %36, %48, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> // %50 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %51 = tt.addptr %38, %50, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> +// %51 = tt.addptr %38, %50, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> // %52:12 = scf.for %arg6 = %c0 to %12 step %c32 iter_args(%arg7 = %11, %arg8 = %8, %arg9 = %10, %arg10 = %17, %arg11 = %30, %arg12 = %43, %arg13 = %21, %arg14 = %34, %arg15 = %47, %arg16 = %51, %arg17 = %49, %arg18 = %c64) -> (tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, index) { // %55 = arith.addf %arg10, %arg13 : tensor<256xf32, #triton_gpu<"coalesced encoding">> // %56 = arith.addf %arg7, %55 : tensor<256xf32, #triton_gpu<"coalesced encoding">> // %57 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %58 = tt.addptr %arg8, %57, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> +// %58 = tt.addptr %arg8, %57, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> // %59 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %60 = tt.addptr %arg9, %59, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> +// %60 = tt.addptr %arg9, %59, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> // %61 = arith.addi %arg18, %c32 : index // %62 = arith.cmpi slt, %61, %12 : index // %63 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding">> @@ -117,13 +117,13 @@ module { // %67 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding">> // %68 = triton_gpu.copy_async %arg16, %65, %67 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">> // %69 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %70 = tt.addptr %arg17, %69, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> +// %70 = tt.addptr %arg17, %69, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> // %71 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %72 = tt.addptr %arg16, %71, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> +// %72 = tt.addptr %arg16, %71, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> // scf.yield %56, %58, %60, %arg11, %arg12, %66, %arg14, %arg15, %68, %72, %70, %61 : tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, 
#triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, index // } // %53 = tt.broadcast %arg2 : (!tt.ptr) -> tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -// %54 = tt.addptr %53, %4, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> +// %54 = tt.addptr %53, %4, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> // tt.store %54, %52#0, %6 : tensor<256xf32, #triton_gpu<"coalesced encoding">> // return // } diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir index 23b083ec8..60e359f52 100644 --- a/test/TritonGPU/coalesce.mlir +++ b/test/TritonGPU/coalesce.mlir @@ -31,20 +31,20 @@ func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1> %3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1> %4 = tt.splat %arg0 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> - %5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr, #blocked1> + %5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> %6 = tt.expand_dims %01 {axis = 0 : i32} : (tensor<64xi32, #slice2dim0>) -> tensor<1x64xi32, #blocked2> %7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked1> %8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2> %9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1> - %10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr, #blocked1> + %10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> %11 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> - %12 = tt.addptr %11, %1 : tensor<64x1x!tt.ptr, #blocked1> + %12 = tt.addptr %11, %1 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> %13 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2> %14 = arith.muli %6, %13 : tensor<1x64xi32, #blocked2> %15 = tt.broadcast %12 : (tensor<64x1x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked1> %16 = tt.broadcast %14 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2> %17 = triton_gpu.convert_layout %16 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1> - %18 = tt.addptr %15, %17 : tensor<64x64x!tt.ptr, #blocked1> + %18 = tt.addptr %15, %17 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> %19 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked1> tt.store %18, %19, %cst : tensor<64x64xf32, #blocked1> return diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 74c4067b6..b4d2da376 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -74,20 +74,20 @@ func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt %2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1> %3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1> %4 = tt.splat %arg0 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> - %5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr, #blocked1> + %5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> %6 = tt.expand_dims %01 {axis = 0 : i32} : (tensor<64xi32, #slice2dim0>) -> tensor<1x64xi32, #blocked2> %7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked1> %8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2> %9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> 
tensor<64x64xi32, #blocked1> - %10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr, #blocked1> + %10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> %11 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> - %12 = tt.addptr %11, %1 : tensor<64x1x!tt.ptr, #blocked1> + %12 = tt.addptr %11, %1 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> %13 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2> %14 = arith.muli %6, %13 : tensor<1x64xi32, #blocked2> %15 = tt.broadcast %12 : (tensor<64x1x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked1> %16 = tt.broadcast %14 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2> %17 = triton_gpu.convert_layout %16 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1> - %18 = tt.addptr %15, %17 : tensor<64x64x!tt.ptr, #blocked1> + %18 = tt.addptr %15, %17 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> %19 = triton_gpu.convert_layout %10 : (tensor<64x64x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked3> %20 = triton_gpu.convert_layout %cst_0 : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3> %21 = triton_gpu.convert_layout %cst : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked3> @@ -106,7 +106,7 @@ func @loop(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %ar // CHECK: [[loop_ret:%.*]]:2 = scf.for {{.*}} -> (tensor<64x64xf32, [[row_layout]]>, tensor<64x64x!tt.ptr, [[row_layout]]>) // CHECK-NEXT: {{.*}} = tt.load {{.*}} : tensor<64x64xf32, [[row_layout]]> // CHECK-NEXT: {{.*}} = arith.addf {{.*}} : tensor<64x64xf32, [[row_layout]]> - // CHECK-NEXT: {{.*}} = tt.addptr {{.*}} : tensor<64x64x!tt.ptr, [[row_layout]]> + // CHECK-NEXT: {{.*}} = tt.addptr {{.*}} : tensor<64x64x!tt.ptr, [[row_layout]]>, tensor<64x64xi32, [[row_layout]]> // CHECK-NEXT: scf.yield {{.*}} : tensor<64x64xf32, [[row_layout]]>, tensor<64x64x!tt.ptr, [[row_layout]]> // CHECK-NEXT: } // CHECK-NEXT: {{.*}} = triton_gpu.convert_layout [[loop_ret]]#0 : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout_novec]]> @@ -123,12 +123,12 @@ func @loop(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %ar %2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1> %3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1> %4 = tt.splat %arg0 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> - %5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr, #blocked1> + %5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> %6 = tt.expand_dims %01 {axis = 0 : i32} : (tensor<64xi32, #slice2dim0>) -> tensor<1x64xi32, #blocked2> %7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked1> %8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2> %9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1> - %10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr, #blocked1> + %10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> %11:2 = scf.for %arg5 = %c0 to %c32 step %c1 iter_args(%arg6 = %cst_1, %arg7 = %10) -> (tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr, #blocked1>) { %23 = triton_gpu.convert_layout %arg7 : (tensor<64x64x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked3> %24 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3> @@ -136,17 +136,17 @@ func @loop(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %ar %26 = tt.load %23, %24, %25 {cache = 1 : i32, 
evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked3> %27 = triton_gpu.convert_layout %26 : (tensor<64x64xf32, #blocked3>) -> tensor<64x64xf32, #blocked1> %28 = arith.addf %arg6, %27 : tensor<64x64xf32, #blocked1> - %29 = tt.addptr %arg7, %cst_0 : tensor<64x64x!tt.ptr, #blocked1> + %29 = tt.addptr %arg7, %cst_0 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> scf.yield %28, %29 : tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr, #blocked1> } %12 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> - %13 = tt.addptr %12, %1 : tensor<64x1x!tt.ptr, #blocked1> + %13 = tt.addptr %12, %1 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> %14 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2> %15 = arith.muli %6, %14 : tensor<1x64xi32, #blocked2> %16 = tt.broadcast %13 : (tensor<64x1x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked1> %17 = tt.broadcast %15 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2> %18 = triton_gpu.convert_layout %17 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1> - %19 = tt.addptr %16, %18 : tensor<64x64x!tt.ptr, #blocked1> + %19 = tt.addptr %16, %18 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> %20 = triton_gpu.convert_layout %19 : (tensor<64x64x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked1> %21 = triton_gpu.convert_layout %11#0 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked1> %22 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked1> @@ -160,27 +160,27 @@ func @vecadd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> - %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> - %4 = tt.splat %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> - %5 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> - %6 = tt.splat %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> - %7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> - %8 = tt.splat %arg0 : (!tt.ptr) -> tensor<256x!tt.ptr, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> - %9 = arith.addi %6, %7 : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> - %10 = tt.splat %arg1 : (!tt.ptr) -> tensor<256x!tt.ptr, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> - %11 = arith.addi %4, %5 : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> - %12 = tt.addptr %8, %9 : tensor<256x!tt.ptr, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> - %13 = tt.load %12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> - %14 = triton_gpu.convert_layout %13 : 
(tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>> - %15 = tt.addptr %10, %11 : tensor<256x!tt.ptr, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> - %16 = tt.load %15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> - %17 = triton_gpu.convert_layout %16 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>> + %2 = tt.splat %1 : (i32) -> tensor<256xi32, #layout1> + %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #layout1> + %4 = tt.splat %1 : (i32) -> tensor<256xi32, #layout1> + %5 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #layout1> + %6 = tt.splat %1 : (i32) -> tensor<256xi32, #layout1> + %7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #layout1> + %8 = tt.splat %arg0 : (!tt.ptr) -> tensor<256x!tt.ptr, #layout1> + %9 = arith.addi %6, %7 : tensor<256xi32, #layout1> + %10 = tt.splat %arg1 : (!tt.ptr) -> tensor<256x!tt.ptr, #layout1> + %11 = arith.addi %4, %5 : tensor<256xi32, #layout1> + %12 = tt.addptr %8, %9 : tensor<256x!tt.ptr, #layout1>, tensor<256xi32, #layout1> + %13 = tt.load %12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #layout1> + %14 = triton_gpu.convert_layout %13 : (tensor<256xf32, #layout1>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>> + %15 = tt.addptr %10, %11 : tensor<256x!tt.ptr, #layout1>, tensor<256xi32, #layout1> + %16 = tt.load %15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #layout1> + %17 = triton_gpu.convert_layout %16 : (tensor<256xf32, #layout1>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>> %18 = arith.addf %14, %17 : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>> - %19 = tt.splat %arg2 : (!tt.ptr) -> tensor<256x!tt.ptr, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> - %20 = arith.addi %2, %3 : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> - %21 = tt.addptr %19, %20 : tensor<256x!tt.ptr, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> - %22 = triton_gpu.convert_layout %18 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> - tt.store %21, %22 : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %19 = tt.splat %arg2 : (!tt.ptr) -> tensor<256x!tt.ptr, #layout1> + %20 = arith.addi %2, %3 : tensor<256xi32, #layout1> + %21 = tt.addptr %19, %20 : tensor<256x!tt.ptr, #layout1>, tensor<256xi32, #layout1> + %22 = triton_gpu.convert_layout %18 : (tensor<256xf32, 
#triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>) -> tensor<256xf32, #layout1> + tt.store %21, %22 : tensor<256xf32, #layout1> return } diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index a8d0fef14..f8bb34798 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -65,8 +65,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> - %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL> - %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL> + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> } return @@ -125,8 +125,8 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> - %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL> - %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL> + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> } } @@ -176,7 +176,7 @@ func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A : %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> - %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> } return diff --git a/test/TritonGPU/matmul.mlir b/test/TritonGPU/matmul.mlir index a3573c00a..9bd5318e1 100644 --- a/test/TritonGPU/matmul.mlir +++ b/test/TritonGPU/matmul.mlir @@ -46,7 +46,7 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6 %31 = tt.broadcast %29 : (tensor<1x64xi32>) -> tensor<64x64xi32> %32 = arith.addi %30, %31 : tensor<64x64xi32> %33 = tt.splat %arg0 : (!tt.ptr) -> tensor<64x64x!tt.ptr> - %34 = tt.addptr %33, %32 : tensor<64x64x!tt.ptr> + %34 = tt.addptr %33, %32 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> %35 = tt.expand_dims %23 {axis = 1 : i32} : (tensor<64xi32>) -> tensor<64x1xi32> %36 = tt.splat %arg8 : (i32) -> tensor<64x1xi32> %37 = arith.muli %35, %36 : tensor<64x1xi32> @@ -57,7 +57,7 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6 %42 = tt.broadcast %40 : (tensor<1x64xi32>) -> tensor<64x64xi32> %43 = arith.addi %41, %42 : tensor<64x64xi32> %44 = tt.splat %arg1 : (!tt.ptr) -> tensor<64x64x!tt.ptr> - %45 = tt.addptr %44, %43 : tensor<64x64x!tt.ptr> + %45 = tt.addptr %44, %43 : tensor<64x64x!tt.ptr>, 
tensor<64x64xi32> %46 = arith.index_cast %arg5 : i32 to index %47:3 = scf.for %arg12 = %c0 to %46 step %c64 iter_args(%arg13 = %cst_0, %arg14 = %34, %arg15 = %45) -> (tensor<64x64xf32>, tensor<64x64x!tt.ptr>, tensor<64x64x!tt.ptr>) { %76 = tt.load %arg14, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false, transA=false, transB=false} : tensor<64x64xf32> @@ -66,10 +66,10 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6 %79 = arith.addf %arg13, %78 : tensor<64x64xf32> %80 = arith.muli %arg7, %c64_i32 : i32 %81 = tt.splat %80 : (i32) -> tensor<64x64xi32> - %82 = tt.addptr %arg14, %81 : tensor<64x64x!tt.ptr> + %82 = tt.addptr %arg14, %81 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> %83 = arith.muli %arg8, %c64_i32 : i32 %84 = tt.splat %83 : (i32) -> tensor<64x64xi32> - %85 = tt.addptr %arg15, %84 : tensor<64x64x!tt.ptr> + %85 = tt.addptr %arg15, %84 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> scf.yield %79, %82, %85 : tensor<64x64xf32>, tensor<64x64x!tt.ptr>, tensor<64x64x!tt.ptr> } %48 = arith.muli %12, %c64_i32 : i32 @@ -90,7 +90,7 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6 %63 = tt.broadcast %61 : (tensor<1x64xi32>) -> tensor<64x64xi32> %64 = arith.addi %62, %63 : tensor<64x64xi32> %65 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x64x!tt.ptr> - %66 = tt.addptr %65, %64 : tensor<64x64x!tt.ptr> + %66 = tt.addptr %65, %64 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> %67 = tt.expand_dims %51 {axis = 1 : i32} : (tensor<64xi32>) -> tensor<64x1xi32> %68 = tt.splat %arg3 : (i32) -> tensor<64x1xi32> %69 = arith.cmpi slt, %67, %68 : tensor<64x1xi32> diff --git a/test/TritonGPU/prefetch.mlir b/test/TritonGPU/prefetch.mlir index efba86b90..5a8cd860b 100644 --- a/test/TritonGPU/prefetch.mlir +++ b/test/TritonGPU/prefetch.mlir @@ -51,8 +51,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B %b_op = triton_gpu.convert_layout %b : (tensor<32x128xf16, #B>) -> tensor<32x128xf16, #B_OP> %c = tt.dot %a_op, %b_op, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_OP> * tensor<32x128xf16, #B_OP> -> tensor<128x128xf32, #C> - %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL> - %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL> + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> %next_a_ = tt.load %next_a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> %next_a = triton_gpu.convert_layout %next_a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> %next_b_ = tt.load %next_b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
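
The snippet below is a usage sketch appended for illustration only; it is not part of the patch. It assumes a GPU build of Triton that contains this change, and the kernel and buffer names (cat_sum_kernel, x, y, out) are hypothetical. Following the new `cat` docstring and the `can_reorder` assertion added in python/triton/language/semantic.py, the concatenated result may be reordered by the compiler, so it only feeds a reduction here instead of being stored element-wise.

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def cat_sum_kernel(x_ptr, y_ptr, out_ptr, N: tl.constexpr):
        # Load two 1D blocks of N elements each.
        offs = tl.arange(0, N)
        x = tl.load(x_ptr + offs)
        y = tl.load(y_ptr + offs)
        # can_reorder=True is required by the current lowering, which is free to
        # permute elements, so the result should only be used where order does
        # not matter (e.g. in reductions).
        z = tl.cat(x, y, can_reorder=True)
        tl.store(out_ptr, tl.sum(z, axis=0))

    x = torch.randn(128, device="cuda")
    y = torch.randn(128, device="cuda")
    out = torch.empty(1, device="cuda")
    cat_sum_kernel[(1,)](x, y, out, N=128)

At the IR level, the pointer arithmetic generated for such a kernel now spells out the offset type in the `tt.addptr` assembly (e.g. `tt.addptr %ptr, %off : tensor<256x!tt.ptr<f32>>, tensor<256xi32>`), which is what allows `addptr` to accept offsets of arbitrary integer width.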