diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td
index 69a301c31..ce74c8f5e 100644
--- a/include/triton/Dialect/Triton/IR/TritonOps.td
+++ b/include/triton/Dialect/Triton/IR/TritonOps.td
@@ -29,22 +29,26 @@ def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, NoSideEffec
                           /*DeclareOpInterfaceMethods*/]> {
     let summary = "Cast int64 to pointer";

-    let arguments = (ins I64Tensor:$from);
+    let arguments = (ins TT_I64Like:$from);

-    let results = (outs TT_PtrTensor:$result);
+    let results = (outs TT_PtrLike:$result);
+
+    let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
 }

 def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape, NoSideEffect,
-                          /*DeclareOpInterfaceMethods*/]> {
+                          /*DeclareOpInterfaceMethods*/]> {
     let summary = "Cast pointer to int64";

-    let arguments = (ins TT_PtrTensor:$from);
+    let arguments = (ins TT_PtrLike:$from);

-    let results = (outs I64Tensor:$result);
+    let results = (outs TT_I64Like:$result);
+
+    let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
 }

 def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
-                      /*DeclareOpInterfaceMethods*/]> {
+                      /*DeclareOpInterfaceMethods*/]> {
     let summary = "Floating point casting for custom types";

     let description = [{
@@ -54,9 +58,11 @@ def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
         BF8 <-> F8, FP16, FP32
     }];

-    let arguments = (ins TT_FloatTensor:$from);
+    let arguments = (ins TT_FloatLike:$from);

-    let results = (outs TT_FloatTensor:$result);
+    let results = (outs TT_FloatLike:$result);
+
+    let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";

     // TODO: We need a verifier here.
 }
@@ -127,16 +133,16 @@ def TT_StoreOp : TT_Op<"store",
     let hasCanonicalizer = 1;
 }

-def TT_GEPOp : TT_Op<"getelementptr",
+def TT_AddPtrOp : TT_Op<"addptr",
                 [NoSideEffect,
                  SameOperandsAndResultShape,
                  TypesMatchWith<"result type matches ptr type",
                                 "result", "ptr", "$_self">,
                  TypesMatchWith<"result shape matches offset shape",
                                 "result", "offset", "getI32SameShape($_self)">]> {
-    let arguments = (ins TT_PtrTensor:$ptr, I32Tensor:$offset);
+    let arguments = (ins TT_PtrLike:$ptr, TT_I32Like:$offset);

-    let results = (outs TT_PtrTensor:$result);
+    let results = (outs TT_PtrLike:$result);

     let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result)";
 }
@@ -278,7 +284,7 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas"> {
         return $old
     }];

-    let arguments = (ins TT_Pointer:$ptr, TT_Type:$cmp, TT_Type:$val);
+    let arguments = (ins TT_Ptr:$ptr, TT_Type:$cmp, TT_Type:$val);

     let results = (outs TT_Type:$result);
 }
@@ -318,7 +324,7 @@ def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {

     let arguments = (ins I32Attr:$start, I32Attr:$end);

-    let results = (outs TT_IntegerTensor:$result);
+    let results = (outs TT_IntTensor:$result);

     let assemblyFormat = "attr-dict `:` type($result)";
 }
diff --git a/include/triton/Dialect/Triton/IR/TritonTypes.td b/include/triton/Dialect/Triton/IR/TritonTypes.td
index b5238996e..81184c91c 100644
--- a/include/triton/Dialect/Triton/IR/TritonTypes.td
+++ b/include/triton/Dialect/Triton/IR/TritonTypes.td
@@ -12,18 +12,36 @@ class TritonTypeDef
   let mnemonic = _mnemonic;
 }

+// Floating-point Type
 def F8 : TritonTypeDef<"Float8", "f8">;
 def BF8 : TritonTypeDef<"BFloat8", "bf8">;

 def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">;
 def TT_FloatTensor : TensorOf<[TT_Float]>;
+def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;

-// IntegerType
+// Boolean Type
+// TT_Bool -> I1
+def TT_BoolTensor : TensorOf<[I1]>;
+def TT_BoolLike : AnyTypeOf<[I1, TT_BoolTensor]>;
+
+// Integer Type
 def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">;
-def TT_IntegerTensor : TensorOf<[TT_Int]>;
+def TT_IntTensor : TensorOf<[TT_Int]>;
+def TT_IntLike : AnyTypeOf<[TT_Int, TT_IntTensor]>;

-// PointerType
-def TT_Pointer : TritonTypeDef<"Pointer", "ptr"> {
+// I32 Type
+// TT_I32 -> I32
+// TT_I32Tensor -> I32Tensor
+def TT_I32Like: AnyTypeOf<[I32, I32Tensor]>;
+
+// I64 Type
+// TT_I64 -> I64
+// TT_I64Tensor -> I64Tensor
+def TT_I64Like: AnyTypeOf<[I64, I64Tensor]>;
+
+// Pointer Type
+def TT_Ptr : TritonTypeDef<"Pointer", "ptr"> {
   let summary = "pointer type";

   let description = [{
@@ -43,12 +61,12 @@ def TT_Pointer : TritonTypeDef<"Pointer", "ptr"> {
   let skipDefaultBuilders = 1;
 }

-def TT_PtrTensor : TensorOf<[TT_Pointer]>;
+def TT_PtrTensor : TensorOf<[TT_Ptr]>;
+def TT_PtrLike : AnyTypeOf<[TT_Ptr, TT_PtrTensor]>;

-def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntegerTensor]>;
+def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntTensor]>;
 def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>;

-def TT_Type : AnyTypeOf<[TT_Float, TT_FloatTensor, TT_Int, TT_IntegerTensor,
-                         TT_Pointer, TT_PtrTensor]>;
+def TT_Type : AnyTypeOf<[TT_FloatLike, TT_IntLike, TT_PtrLike]>;

 #endif
diff --git a/include/triton/Dialect/Triton/Transforms/Passes.td b/include/triton/Dialect/Triton/Transforms/Passes.td
index 2515057fe..8f77aed77 100644
--- a/include/triton/Dialect/Triton/Transforms/Passes.td
+++ b/include/triton/Dialect/Triton/Transforms/Passes.td
@@ -8,7 +8,7 @@ def TritonCombineOps : Pass
   let description = [{
     dot(a, b, 0) + c => dot(a, b, c)

-    gep(gep(ptr, idx0), idx1) => gep(ptr, AddI(idx0, idx1))
+    addptr(addptr(ptr, idx0), idx1) => addptr(ptr, AddI(idx0, idx1))

    select(cond, load(ptrs, broadcast(cond), ???), other) =>
      load(ptrs, broadcast(cond), other)
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
index 8f9e4b513..04708f639 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -10,12 +10,6 @@
 include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
 include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType

-def TT_BoolTensor : TensorOf<[I1]>;
-
-def TT_BoolLike : AnyTypeOf<[I1, TT_BoolTensor]>;
-def TT_IntegerLike : AnyTypeOf<[TT_Int, TT_IntegerTensor]>;
-def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;
-
 class TTG_Op traits = []> : Op;

@@ -48,8 +42,8 @@ def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect]> {
   let description = [{}];

   let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
-                       TT_IntegerLike:$lhs,
-                       TT_IntegerLike:$rhs);
+                       TT_IntLike:$lhs,
+                       TT_IntLike:$rhs);

   let results = (outs TT_BoolLike:$result);
 }
@@ -66,7 +60,7 @@ def TTG_CmpFOp : TTG_Op<"cmpf"> {
   let results = (outs TT_BoolLike:$result);
 }

-def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
+def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
                              [SameVariadicOperandSize,
                               MemoryEffects<[MemRead, MemWrite]>,
                               TypesMatchWith<"infer mask type from src type",
@@ -94,7 +88,7 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
    * other: optional tensor-rank number of other tensors which specify what
        values are inserted into the `$dst` tensor if the corresponding
        element of the `$mask` tensor is false.
- + In the future, we may decompose this operation into a sequence of: * `async` operation to specify a sequence of asynchronous operations @@ -191,7 +185,7 @@ def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [NoSideEffect]> { Note: This op can be repalced to a `bufferization.alloc_tensor` in LLVM 16. }]; - let assemblyFormat = [{attr-dict `:` type($result)}]; + let assemblyFormat = [{attr-dict `:` type($result)}]; let results = (outs TT_Tensor:$result); diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index a6b280cc1..441530d25 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -126,7 +126,7 @@ ChangeResult AxisInfoAnalysis::visitOperation( } } // Addition - if (llvm::isa(op)) { + if (llvm::isa(op)) { auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d) { return std::max(gcd(lhs.getContiguity(d), rhs.getConstancy(d)), gcd(lhs.getConstancy(d), rhs.getContiguity(d))); diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 76ce32091..c6f0efbd9 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -701,7 +701,7 @@ struct StoreOpConversion const int numVecs = numElems / vec; for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { - // TODO: optimization when ptr is GEP with constant offset + // TODO: optimization when ptr is AddPtr with constant offset size_t in_off = 0; const int maxWordWidth = std::max(32, valueElemNbits); @@ -1173,12 +1173,13 @@ struct GetProgramIdOpConversion } }; -struct GEPOpConversion : public ConvertTritonGPUOpToLLVMPattern { +struct AddPtrOpConversion + : public ConvertTritonGPUOpToLLVMPattern { using ConvertTritonGPUOpToLLVMPattern< - triton::GEPOp>::ConvertTritonGPUOpToLLVMPattern; + triton::AddPtrOp>::ConvertTritonGPUOpToLLVMPattern; LogicalResult - matchAndRewrite(triton::GEPOp op, OpAdaptor adaptor, + matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); auto resultTy = op.getType().dyn_cast(); @@ -1298,7 +1299,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, patterns.add(typeConverter, benefit); patterns.add(typeConverter, numWarps, benefit); - patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, analysis, benefit); patterns.add(typeConverter, benefit); diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index 19cefee86..de0c409e4 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -323,7 +323,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, patterns.add< // TODO: view should have custom pattern that views the layout TritonGenericPattern, TritonGenericPattern, TritonBroadcastPattern, - TritonGenericPattern, TritonReducePattern, + TritonGenericPattern, TritonReducePattern, TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern>( typeConverter, context); diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 501f19cf8..4628f5d1e 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -15,7 +15,7 @@ static Type getI1SameShape(Type type) { if (auto tensorType = 
type.dyn_cast()) return RankedTensorType::get(tensorType.getShape(), i1Type, tensorType.getEncoding()); - return Type(); + return i1Type; } static Type getI32SameShape(Type type) { @@ -23,7 +23,7 @@ static Type getI32SameShape(Type type) { if (auto tensorType = type.dyn_cast()) return RankedTensorType::get(tensorType.getShape(), i32Type, tensorType.getEncoding()); - return Type(); + return i32Type; } static Type getPointerTypeFromTensor(Type type) { diff --git a/lib/Dialect/Triton/Transforms/Combine.cpp b/lib/Dialect/Triton/Transforms/Combine.cpp index 58233e015..249a3c075 100644 --- a/lib/Dialect/Triton/Transforms/Combine.cpp +++ b/lib/Dialect/Triton/Transforms/Combine.cpp @@ -194,7 +194,7 @@ public: patterns.add(context); // %} patterns.add(context); - patterns.add(context); + patterns.add(context); patterns.add(context); if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) diff --git a/lib/Dialect/Triton/Transforms/Combine.td b/lib/Dialect/Triton/Transforms/Combine.td index 8881b33c2..1b84cc7f3 100644 --- a/lib/Dialect/Triton/Transforms/Combine.td +++ b/lib/Dialect/Triton/Transforms/Combine.td @@ -30,12 +30,12 @@ def CombineDotAddFRevPattern : Pat< [(Constraint> $c)]>; -// gep(gep(%ptr, %idx0), %idx1) => gep(%ptr, AddI(%idx0, %idx1)) +// addptr(addptr(%ptr, %idx0), %idx1) => addptr(%ptr, AddI(%idx0, %idx1)) // Note: leave (sub %c0, %c0) canceling to ArithmeticDialect // (ref: ArithmeticCanonicalization.td) -def CombineGEPPattern : Pat< - (TT_GEPOp (TT_GEPOp $ptr, $idx0), $idx1), - (TT_GEPOp $ptr, (Arith_AddIOp $idx0, $idx1))>; +def CombineAddPtrPattern : Pat< + (TT_AddPtrOp (TT_AddPtrOp $ptr, $idx0), $idx1), + (TT_AddPtrOp $ptr, (Arith_AddIOp $idx0, $idx1))>; // broadcast(cst) => cst def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">; diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index 42a8f61dd..d879f451e 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -203,7 +203,7 @@ bool tryLegalizeOp(Operation *op, DenseSet toPreserve, targetType.getEncoding()); }; bool hasSameTypes = op->getDialect()->getNamespace() == "arith" || - isa(op); + isa(op); if (hasSameTypes) { // replace argument types for (auto arg : llvm::enumerate(op->getOperands())) { @@ -440,4 +440,4 @@ public: std::unique_ptr mlir::createTritonGPUCombineOpsPass() { return std::make_unique(); -} \ No newline at end of file +} diff --git a/lib/Dialect/TritonGPU/Transforms/Verifier.cpp b/lib/Dialect/TritonGPU/Transforms/Verifier.cpp index 7bf4143f6..1acdf915c 100644 --- a/lib/Dialect/TritonGPU/Transforms/Verifier.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Verifier.cpp @@ -70,7 +70,7 @@ private: if (auto storeOp = llvm::dyn_cast(op)) { // TODO: fill this } - if (auto gepOp = llvm::dyn_cast(op)) { + if (auto addptrOp = llvm::dyn_cast(op)) { // TODO: fill this } // Triton builtin Ops diff --git a/python/src/triton.cc b/python/src/triton.cc index 9e3485df8..739dcdc69 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1245,13 +1245,13 @@ void init_triton_ir(py::module &&m) { return mlir::Value( self.create(loc, lhs, rhs)); }) - // GEP - .def("create_gep", + // AddPtr (similar to GEP) + .def("create_addptr", [](mlir::OpBuilder &self, mlir::Value &ptr, mlir::Value &offset) -> mlir::Value { auto loc = self.getUnknownLoc(); - return self.create(loc, ptr.getType(), ptr, - offset); + return self.create(loc, ptr.getType(), ptr, + offset); }) // Comparison (int) 
.def("create_icmpSLE", diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index b2a714be0..dd3dc0aa5 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -121,7 +121,7 @@ def add(input: tl.tensor, if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr(): input, other = other, input if input_scalar_ty.is_ptr(): - return tl.tensor(builder.create_gep(input.handle, other.handle), input.type) + return tl.tensor(builder.create_addptr(input.handle, other.handle), input.type) # float + float elif input_scalar_ty.is_floating(): return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type) @@ -138,7 +138,7 @@ def sub(input: tl.tensor, scalar_ty = input.type.scalar # ptr - offset if scalar_ty.is_ptr(): - return tl.tensor(builder.create_gep(input.handle, minus(other, builder).handle), + return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle), input.type) # float - float if scalar_ty.is_floating(): diff --git a/test/Analysis/test-alias.mlir b/test/Analysis/test-alias.mlir index 10f5c1e42..550455ccc 100644 --- a/test/Analysis/test-alias.mlir +++ b/test/Analysis/test-alias.mlir @@ -26,8 +26,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> %c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> - %next_a_ptr = tt.getelementptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL> - %next_b_ptr = tt.getelementptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL> + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL> scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> } return diff --git a/test/Analysis/test-alignment.mlir b/test/Analysis/test-alignment.mlir index c70a81c8e..351866f7d 100644 --- a/test/Analysis/test-alignment.mlir +++ b/test/Analysis/test-alignment.mlir @@ -18,7 +18,7 @@ func @permute_2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {t // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] %5 = tt.splat %arg0 : (!tt.ptr) -> tensor<128x1x!tt.ptr> // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 1] - %6 = tt.getelementptr %5, %4 : tensor<128x1x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<128x1x!tt.ptr> // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1] %7 = tt.expand_dims %1 {axis = 0 : i32}: (tensor<128xi32>) -> tensor<1x128xi32> // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128] @@ -26,13 +26,13 @@ func @permute_2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {t // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [128, 1] %9 = tt.broadcast %7 : (tensor<1x128xi32>) -> tensor<128x128xi32> // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 16] ; Constancy: [1, 1] - %10 = tt.getelementptr %8, %9 : tensor<128x128x!tt.ptr> + %10 = tt.addptr %8, %9 : tensor<128x128x!tt.ptr> // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1] %11 = tt.expand_dims %0 {axis = 1 : i32}: (tensor<128xi32>) -> tensor<128x1xi32> // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] %12 = tt.splat %arg2 : (!tt.ptr) -> tensor<128x1x!tt.ptr> // CHECK-NEXT: 
Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1] - %13 = tt.getelementptr %12, %11 : tensor<128x1x!tt.ptr> + %13 = tt.addptr %12, %11 : tensor<128x1x!tt.ptr> // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1] %14 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32> // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128] @@ -44,7 +44,7 @@ func @permute_2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {t // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 1048576] ; Constancy: [128, 1] %18 = tt.broadcast %16 : (tensor<1x128xi32>) -> tensor<128x128xi32> // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1] - %19 = tt.getelementptr %17, %18 : tensor<128x128x!tt.ptr> + %19 = tt.addptr %17, %18 : tensor<128x128x!tt.ptr> // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1] %20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32> tt.store %19, %20, %cst : tensor<128x128xf32> diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index 1daaceec7..7cfecade2 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -30,8 +30,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B %c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> - %next_a_ptr = tt.getelementptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL> - %next_b_ptr = tt.getelementptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL> + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL> scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> } return diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index f611b4824..1f7449c78 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -28,8 +28,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B // CHECK: Membar 13 %c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> - %next_a_ptr = tt.getelementptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL> - %next_b_ptr = tt.getelementptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL> + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL> scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> } return diff --git a/test/Conversion/triton_ops.mlir b/test/Conversion/triton_ops.mlir new file mode 100644 index 000000000..20744ecaa --- /dev/null +++ b/test/Conversion/triton_ops.mlir @@ -0,0 +1,55 @@ +// RUN: triton-opt %s | FileCheck %s + +func @cast_ops(%scalar_ptr: !tt.ptr, %scalar_f32: f32, %scalar_i64: i64) { + // scalar -> scalar + // CHECK: i64 -> !tt.ptr + %0 = tt.int_to_ptr %scalar_i64 : i64 -> !tt.ptr + // CHECK: !tt.ptr -> i64 + %1 = tt.ptr_to_int %scalar_ptr : !tt.ptr -> i64 + // CHECK: f32 -> f16 + %2 = tt.fp_to_fp %scalar_f32 : f32 -> f16 + + // 0D tensor -> 0D tensor + %tensor_ptr_0d = tt.splat %scalar_ptr : (!tt.ptr) -> tensor> + %tensor_f32_0d = tt.splat %scalar_f32 : (f32) -> tensor + %tensor_i64_0d = tt.splat %scalar_i64 : (i64) -> 
tensor + + // CHECK: tensor -> tensor> + %3 = tt.int_to_ptr %tensor_i64_0d : tensor -> tensor> + // CHECK: tensor> -> tensor + %4 = tt.ptr_to_int %tensor_ptr_0d : tensor> -> tensor + // CHECK: tensor -> tensor + %5 = tt.fp_to_fp %tensor_f32_0d : tensor -> tensor + + // 1D tensor -> 1D tensor + %tensor_ptr_1d = tt.splat %scalar_ptr : (!tt.ptr) -> tensor<16x!tt.ptr> + %tensor_f32_1d = tt.splat %scalar_f32 : (f32) -> tensor<16xf32> + %tensor_i64_1d = tt.splat %scalar_i64 : (i64) -> tensor<16xi64> + + // CHECK: tensor<16xi64> -> tensor<16x!tt.ptr> + %6 = tt.int_to_ptr %tensor_i64_1d : tensor<16xi64> -> tensor<16x!tt.ptr> + // CHECK: tensor<16x!tt.ptr> -> tensor<16xi64> + %7 = tt.ptr_to_int %tensor_ptr_1d : tensor<16x!tt.ptr> -> tensor<16xi64> + // CHECK: tensor<16xf32> -> tensor<16xf16> + %8 = tt.fp_to_fp %tensor_f32_1d : tensor<16xf32> -> tensor<16xf16> + return +} + +func @addptr_ops(%scalar_ptr: !tt.ptr, %scalar_i32: i32) { + // scalar -> scalar + // CHECK: !tt.ptr + %0 = tt.addptr %scalar_ptr, %scalar_i32 : !tt.ptr + + // 0D tensor -> 0D tensor + %tensor_ptr_0d = tt.splat %scalar_ptr : (!tt.ptr) -> tensor> + %tensor_i32_0d = tt.splat %scalar_i32 : (i32) -> tensor + // CHECK: tensor> + %1 = tt.addptr %tensor_ptr_0d, %tensor_i32_0d : tensor> + + // 1D tensor -> 1D tensor + %tensor_ptr_1d = tt.splat %scalar_ptr : (!tt.ptr) -> tensor<16x!tt.ptr> + %tensor_i32_1d = tt.splat %scalar_i32 : (i32) -> tensor<16xi32> + // CHECK: tensor<16x!tt.ptr> + %2 = tt.addptr %tensor_ptr_1d, %tensor_i32_1d : tensor<16x!tt.ptr> + return +} diff --git a/test/Conversion/ops.mlir b/test/Conversion/triton_to_tritongpu.mlir similarity index 100% rename from test/Conversion/ops.mlir rename to test/Conversion/triton_to_tritongpu.mlir diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 9b67f262a..1d84f6275 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -82,9 +82,9 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { %3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0> %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0> %5 = tt.splat %arg0 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked0> - %6 = tt.getelementptr %5, %4 : tensor<256x!tt.ptr, #blocked0> + %6 = tt.addptr %5, %4 : tensor<256x!tt.ptr, #blocked0> %7 = tt.splat %arg1 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked0> - %8 = tt.getelementptr %7, %4 : tensor<256x!tt.ptr, #blocked0> + %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr, #blocked0> // CHECK: ld.global.v4.b32 %9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> @@ -92,7 +92,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { %10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> %12 = tt.splat %arg2 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked0> - %13 = tt.getelementptr %12, %4 : tensor<256x!tt.ptr, #blocked0> + %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr, #blocked0> // Store 4 elements to global // CHECK: st.global.b32.v4 @@ -104,7 +104,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { // TODO: Add a testcase to verify the optimization when ptr of the LoadOp -// is from a GEP with const idx +// is from an addptr with const idx // ----- @@ -187,11 +187,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes 
{"triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: basic_gep - func @basic_gep(%arg0 : tensor<256x!tt.ptr,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) { + // CHECK-LABEL: basic_addptr + func @basic_addptr(%arg0 : tensor<256x!tt.ptr,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) { // CHECK: llvm.getelementptr // CHECK: llvm.getelementptr - %0 = tt.getelementptr %arg0, %arg1 : tensor<256x!tt.ptr, #blocked0> + %0 = tt.addptr %arg0, %arg1 : tensor<256x!tt.ptr, #blocked0> return } } diff --git a/test/Triton/combine.mlir b/test/Triton/combine.mlir index 3b84966ba..503cc9a26 100644 --- a/test/Triton/combine.mlir +++ b/test/Triton/combine.mlir @@ -22,8 +22,8 @@ func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32 return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32> } -// CHECK-LABEL: @test_combine_gep_pattern -func @test_combine_gep_pattern(%base: !tt.ptr) -> tensor<8x!tt.ptr> { +// CHECK-LABEL: @test_combine_addptr_pattern +func @test_combine_addptr_pattern(%base: !tt.ptr) -> tensor<8x!tt.ptr> { %off0 = arith.constant 10 : i32 %off1 = arith.constant 15 : i32 @@ -37,9 +37,9 @@ func @test_combine_gep_pattern(%base: !tt.ptr) -> tensor<8x!tt.ptr> { %idx0 = tt.broadcast %off0 : (i32) -> tensor<8xi32> %idx1 = tt.broadcast %off1 : (i32) -> tensor<8xi32> - // CHECK-NEXT: %1 = tt.getelementptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr> - %ptr0 = tt.getelementptr %base_, %idx0 : tensor<8x!tt.ptr> - %ptr1 = tt.getelementptr %ptr0, %idx1 : tensor<8x!tt.ptr> + // CHECK-NEXT: %1 = tt.addptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr> + %ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr> + %ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr> return %ptr1 : tensor<8x!tt.ptr> } diff --git a/test/Triton/vecadd.mlir b/test/Triton/vecadd.mlir index 8f3642373..391e80650 100644 --- a/test/Triton/vecadd.mlir +++ b/test/Triton/vecadd.mlir @@ -11,9 +11,9 @@ module { %5 = tt.broadcast %arg3 : (i32) -> tensor<256xi32> %6 = arith.cmpi slt, %4, %5 : tensor<256xi32> %7 = tt.broadcast %arg0 : (!tt.ptr) -> tensor<256x!tt.ptr> - %8 = tt.getelementptr %7, %4 : tensor<256x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr> %9 = tt.broadcast %arg1 : (!tt.ptr) -> tensor<256x!tt.ptr> - %10 = tt.getelementptr %9, %4 : tensor<256x!tt.ptr> + %10 = tt.addptr %9, %4 : tensor<256x!tt.ptr> %cst = arith.constant 0.000000e+00 : f32 %11 = tt.broadcast %cst : (f32) -> tensor<256xf32> %c0_i32 = arith.constant 0 : i32 @@ -31,13 +31,13 @@ module { %22 = arith.addf %19, %21 : tensor<256xf32> %23 = arith.addf %arg7, %22 : tensor<256xf32> %24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32> - %25 = tt.getelementptr %arg8, %24 : tensor<256x!tt.ptr> + %25 = tt.addptr %arg8, %24 : tensor<256x!tt.ptr> %26 = tt.broadcast %arg5 : (i32) -> tensor<256xi32> - %27 = tt.getelementptr %arg9, %26 : tensor<256x!tt.ptr> + %27 = tt.addptr %arg9, %26 : tensor<256x!tt.ptr> scf.yield %23, %25, %27 : tensor<256xf32>, tensor<256x!tt.ptr>, tensor<256x!tt.ptr> } %16 = tt.broadcast %arg2 : (!tt.ptr) -> tensor<256x!tt.ptr> - %17 = tt.getelementptr %16, %4 : tensor<256x!tt.ptr> + %17 = tt.addptr %16, %4 : tensor<256x!tt.ptr> tt.store %17, %15#0, %6 : tensor<256xf32> return } @@ -57,9 +57,9 @@ module { // %5 = tt.broadcast %arg3 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">> // %6 = "triton_gpu.cmpi"(%4, %5) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu<"coalesced encoding">>, tensor<256xi32, #triton_gpu<"coalesced encoding">>) -> tensor<256xi1, #triton_gpu<"coalesced encoding">> // %7 = tt.broadcast %arg0 : (!tt.ptr) 
-> tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -// %8 = tt.getelementptr %7, %4, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> +// %8 = tt.addptr %7, %4, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> // %9 = tt.broadcast %arg1 : (!tt.ptr) -> tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -// %10 = tt.getelementptr %9, %4, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> +// %10 = tt.addptr %9, %4, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> // %11 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding">> // %12 = arith.index_cast %arg4 : i32 to index // %13 = arith.cmpi slt, %c0, %12 : index @@ -72,9 +72,9 @@ module { // %20 = arith.andi %6, %19 : tensor<256xi1, #triton_gpu<"coalesced encoding">> // %21 = triton_gpu.copy_async %10, %20, %18 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">> // %22 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %23 = tt.getelementptr %8, %22, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> +// %23 = tt.addptr %8, %22, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> // %24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %25 = tt.getelementptr %10, %24, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> +// %25 = tt.addptr %10, %24, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> // %26 = arith.cmpi slt, %c32, %12 : index // %27 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding">> // %28 = tt.broadcast %26 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding">> @@ -85,9 +85,9 @@ module { // %33 = arith.andi %6, %32 : tensor<256xi1, #triton_gpu<"coalesced encoding">> // %34 = triton_gpu.copy_async %25, %33, %31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">> // %35 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %36 = tt.getelementptr %23, %35, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> +// %36 = tt.addptr %23, %35, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> // %37 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %38 = tt.getelementptr %25, %37, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> +// %38 = tt.addptr %25, %37, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> // %39 = arith.cmpi slt, %c64, %12 : index // %40 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding">> // %41 = tt.broadcast %39 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding">> @@ -98,16 +98,16 @@ module { // %46 = arith.andi %6, %45 : tensor<256xi1, #triton_gpu<"coalesced encoding">> // %47 = triton_gpu.copy_async %38, %46, %44 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">> // %48 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %49 = tt.getelementptr %36, %48, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> +// %49 = tt.addptr %36, %48, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> // %50 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %51 = tt.getelementptr %38, 
%50, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> +// %51 = tt.addptr %38, %50, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> // %52:12 = scf.for %arg6 = %c0 to %12 step %c32 iter_args(%arg7 = %11, %arg8 = %8, %arg9 = %10, %arg10 = %17, %arg11 = %30, %arg12 = %43, %arg13 = %21, %arg14 = %34, %arg15 = %47, %arg16 = %51, %arg17 = %49, %arg18 = %c64) -> (tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, index) { // %55 = arith.addf %arg10, %arg13 : tensor<256xf32, #triton_gpu<"coalesced encoding">> // %56 = arith.addf %arg7, %55 : tensor<256xf32, #triton_gpu<"coalesced encoding">> // %57 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %58 = tt.getelementptr %arg8, %57, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> +// %58 = tt.addptr %arg8, %57, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> // %59 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %60 = tt.getelementptr %arg9, %59, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> +// %60 = tt.addptr %arg9, %59, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> // %61 = arith.addi %arg18, %c32 : index // %62 = arith.cmpi slt, %61, %12 : index // %63 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding">> @@ -117,13 +117,13 @@ module { // %67 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding">> // %68 = triton_gpu.copy_async %arg16, %65, %67 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">> // %69 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %70 = tt.getelementptr %arg17, %69, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> +// %70 = tt.addptr %arg17, %69, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> // %71 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %72 = tt.getelementptr %arg16, %71, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> +// %72 = tt.addptr %arg16, %71, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> // scf.yield %56, %58, %60, %arg11, %arg12, %66, %arg14, %arg15, %68, %72, %70, %61 : tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, index // } // %53 = tt.broadcast %arg2 : (!tt.ptr) -> tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -// %54 = 
tt.getelementptr %53, %4, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> +// %54 = tt.addptr %53, %4, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> // tt.store %54, %52#0, %6 : tensor<256xf32, #triton_gpu<"coalesced encoding">> // return // } diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir index d6f5c8527..1dc46a1a6 100644 --- a/test/TritonGPU/coalesce.mlir +++ b/test/TritonGPU/coalesce.mlir @@ -28,20 +28,20 @@ func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1> %3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1> %4 = tt.splat %arg0 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> - %5 = tt.getelementptr %4, %3 : tensor<64x1x!tt.ptr, #blocked1> + %5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr, #blocked1> %6 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<64xi32, #blocked0>) -> tensor<1x64xi32, #blocked2> %7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked1> %8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2> %9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1> - %10 = tt.getelementptr %7, %9 : tensor<64x64x!tt.ptr, #blocked1> + %10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr, #blocked1> %11 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> - %12 = tt.getelementptr %11, %1 : tensor<64x1x!tt.ptr, #blocked1> + %12 = tt.addptr %11, %1 : tensor<64x1x!tt.ptr, #blocked1> %13 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2> %14 = arith.muli %6, %13 : tensor<1x64xi32, #blocked2> %15 = tt.broadcast %12 : (tensor<64x1x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked1> %16 = tt.broadcast %14 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2> %17 = triton_gpu.convert_layout %16 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1> - %18 = tt.getelementptr %15, %17 : tensor<64x64x!tt.ptr, #blocked1> + %18 = tt.addptr %15, %17 : tensor<64x64x!tt.ptr, #blocked1> %19 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked1> tt.store %18, %19, %cst : tensor<64x64xf32, #blocked1> return diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 4f1b20bf3..8a0aabf67 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -70,7 +70,7 @@ func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt // CHECK: %5 = arith.muli %2, %3 : tensor<64x1xi32, [[row_layout]]> // CHECK: %6 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[col_layout]]}>> // CHECK: %7 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[row_layout]]}>> - // CHECK: %8 = tt.getelementptr %4, %5 : tensor<64x1x!tt.ptr, [[row_layout]]> + // CHECK: %8 = tt.addptr %4, %5 : tensor<64x1x!tt.ptr, [[row_layout]]> // CHECK: %9 = tt.expand_dims %7 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[row_layout]]}>>) -> tensor<1x64xi32, [[row_layout]]> // CHECK: %10 = tt.broadcast %8 : (tensor<64x1x!tt.ptr, [[row_layout]]>) -> tensor<64x64x!tt.ptr, [[row_layout]]> // CHECK: %11 = tt.broadcast %9 : (tensor<1x64xi32, [[row_layout]]>) -> tensor<64x64xi32, [[row_layout]]> @@ -78,13 +78,13 @@ func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt // CHECK: %13 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, 
#triton_gpu.slice<{dim = 1, parent = [[col_layout]]}>>) -> tensor<64x1xi32, [[col_layout]]> // CHECK: %14 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[col_layout]]}>>) -> tensor<1x64xi32, [[col_layout]]> // CHECK: %15 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, [[col_layout]]> - // CHECK: %16 = tt.getelementptr %12, %13 : tensor<64x1x!tt.ptr, [[col_layout]]> + // CHECK: %16 = tt.addptr %12, %13 : tensor<64x1x!tt.ptr, [[col_layout]]> // CHECK: %17 = arith.muli %14, %15 : tensor<1x64xi32, [[col_layout]]> // CHECK: %18 = tt.broadcast %16 : (tensor<64x1x!tt.ptr, [[col_layout]]>) -> tensor<64x64x!tt.ptr, [[col_layout]]> // CHECK: %19 = tt.broadcast %17 : (tensor<1x64xi32, [[col_layout]]>) -> tensor<64x64xi32, [[col_layout]]> - // CHECK: %20 = tt.getelementptr %10, %11 : tensor<64x64x!tt.ptr, [[row_layout]]> + // CHECK: %20 = tt.addptr %10, %11 : tensor<64x64x!tt.ptr, [[row_layout]]> // CHECK: %21 = tt.load %20, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, [[row_layout]]> - // CHECK: %22 = tt.getelementptr %18, %19 : tensor<64x64x!tt.ptr, [[col_layout]]> + // CHECK: %22 = tt.addptr %18, %19 : tensor<64x64x!tt.ptr, [[col_layout]]> // CHECK: %23 = triton_gpu.convert_layout %21 : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]> // CHECK: tt.store %22, %23, %cst_1 : tensor<64x64xf32, [[col_layout]]> // CHECK: return @@ -95,20 +95,20 @@ func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt %2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1> %3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1> %4 = tt.splat %arg0 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> - %5 = tt.getelementptr %4, %3 : tensor<64x1x!tt.ptr, #blocked1> + %5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr, #blocked1> %6 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<64xi32, #blocked0>) -> tensor<1x64xi32, #blocked2> %7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked1> %8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2> %9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1> - %10 = tt.getelementptr %7, %9 : tensor<64x64x!tt.ptr, #blocked1> + %10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr, #blocked1> %11 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> - %12 = tt.getelementptr %11, %1 : tensor<64x1x!tt.ptr, #blocked1> + %12 = tt.addptr %11, %1 : tensor<64x1x!tt.ptr, #blocked1> %13 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2> %14 = arith.muli %6, %13 : tensor<1x64xi32, #blocked2> %15 = tt.broadcast %12 : (tensor<64x1x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked1> %16 = tt.broadcast %14 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2> %17 = triton_gpu.convert_layout %16 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1> - %18 = tt.getelementptr %15, %17 : tensor<64x64x!tt.ptr, #blocked1> + %18 = tt.addptr %15, %17 : tensor<64x64x!tt.ptr, #blocked1> %19 = triton_gpu.convert_layout %10 : (tensor<64x64x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked3> %20 = triton_gpu.convert_layout %cst_0 : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3> %21 = triton_gpu.convert_layout %cst : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked3> @@ -127,7 +127,7 @@ func @loop(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %ar // CHECK: [[loop_ret:%.*]]:2 = scf.for {{.*}} -> 
(tensor<64x64xf32, [[row_layout]]>, tensor<64x64x!tt.ptr, [[row_layout]]>) // CHECK-NEXT: {{.*}} = tt.load {{.*}} : tensor<64x64xf32, [[row_layout]]> // CHECK-NEXT: {{.*}} = arith.addf {{.*}} : tensor<64x64xf32, [[row_layout]]> - // CHECK-NEXT: {{.*}} = tt.getelementptr {{.*}} : tensor<64x64x!tt.ptr, [[row_layout]]> + // CHECK-NEXT: {{.*}} = tt.addptr {{.*}} : tensor<64x64x!tt.ptr, [[row_layout]]> // CHECK-NEXT: scf.yield {{.*}} : tensor<64x64xf32, [[row_layout]]>, tensor<64x64x!tt.ptr, [[row_layout]]> // CHECK-NEXT: } // CHECK-NEXT: {{.*}} = triton_gpu.convert_layout [[loop_ret]]#0 : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout_novec]]> @@ -143,12 +143,12 @@ func @loop(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %ar %2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1> %3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1> %4 = tt.splat %arg0 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> - %5 = tt.getelementptr %4, %3 : tensor<64x1x!tt.ptr, #blocked1> + %5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr, #blocked1> %6 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<64xi32, #blocked0>) -> tensor<1x64xi32, #blocked2> %7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked1> %8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2> %9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1> - %10 = tt.getelementptr %7, %9 : tensor<64x64x!tt.ptr, #blocked1> + %10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr, #blocked1> %11:2 = scf.for %arg5 = %c0 to %c32 step %c1 iter_args(%arg6 = %cst_1, %arg7 = %10) -> (tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr, #blocked1>) { %23 = triton_gpu.convert_layout %arg7 : (tensor<64x64x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked3> %24 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3> @@ -156,17 +156,17 @@ func @loop(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %ar %26 = tt.load %23, %24, %25 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, #blocked3> %27 = triton_gpu.convert_layout %26 : (tensor<64x64xf32, #blocked3>) -> tensor<64x64xf32, #blocked1> %28 = arith.addf %arg6, %27 : tensor<64x64xf32, #blocked1> - %29 = tt.getelementptr %arg7, %cst_0 : tensor<64x64x!tt.ptr, #blocked1> + %29 = tt.addptr %arg7, %cst_0 : tensor<64x64x!tt.ptr, #blocked1> scf.yield %28, %29 : tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr, #blocked1> } %12 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> - %13 = tt.getelementptr %12, %1 : tensor<64x1x!tt.ptr, #blocked1> + %13 = tt.addptr %12, %1 : tensor<64x1x!tt.ptr, #blocked1> %14 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2> %15 = arith.muli %6, %14 : tensor<1x64xi32, #blocked2> %16 = tt.broadcast %13 : (tensor<64x1x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked1> %17 = tt.broadcast %15 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2> %18 = triton_gpu.convert_layout %17 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1> - %19 = tt.getelementptr %16, %18 : tensor<64x64x!tt.ptr, #blocked1> + %19 = tt.addptr %16, %18 : tensor<64x64x!tt.ptr, #blocked1> %20 = triton_gpu.convert_layout %19 : (tensor<64x64x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked1> %21 = triton_gpu.convert_layout %11#0 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked1> %22 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, 
#blocked1>) -> tensor<64x64xi1, #blocked1> @@ -190,17 +190,17 @@ func @vecadd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr> %10 = tt.splat %arg1 : (!tt.ptr) -> tensor<256x!tt.ptr, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> %11 = arith.addi %4, %5 : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> - %12 = tt.getelementptr %8, %9 : tensor<256x!tt.ptr, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %12 = tt.addptr %8, %9 : tensor<256x!tt.ptr, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> %13 = tt.load %12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> %14 = triton_gpu.convert_layout %13 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>> - %15 = tt.getelementptr %10, %11 : tensor<256x!tt.ptr, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %15 = tt.addptr %10, %11 : tensor<256x!tt.ptr, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> %16 = tt.load %15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> %17 = triton_gpu.convert_layout %16 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>> %18 = arith.addf %14, %17 : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>> %19 = tt.splat %arg2 : (!tt.ptr) -> tensor<256x!tt.ptr, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> %20 = arith.addi %2, %3 : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> - %21 = tt.getelementptr %19, %20 : tensor<256x!tt.ptr, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %21 = tt.addptr %19, %20 : tensor<256x!tt.ptr, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> %22 = triton_gpu.convert_layout %18 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> tt.store %21, %22 : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> return -} \ No newline at end of file +} diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index 3e147f8ef..320e2116a 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -55,8 +55,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B %c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> 
tensor<128x128xf32, #C> - %next_a_ptr = tt.getelementptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL> - %next_b_ptr = tt.getelementptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL> + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL> scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> } return @@ -112,8 +112,8 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> - %next_a_ptr = tt.getelementptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL> - %next_b_ptr = tt.getelementptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL> + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL> scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> } } @@ -161,7 +161,7 @@ func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A : %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> %c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> - %next_b_ptr = tt.getelementptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL> scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> } return diff --git a/test/TritonGPU/matmul.mlir b/test/TritonGPU/matmul.mlir index ed48ac371..fc15bdc04 100644 --- a/test/TritonGPU/matmul.mlir +++ b/test/TritonGPU/matmul.mlir @@ -46,7 +46,7 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6 %31 = tt.broadcast %29 : (tensor<1x64xi32>) -> tensor<64x64xi32> %32 = arith.addi %30, %31 : tensor<64x64xi32> %33 = tt.splat %arg0 : (!tt.ptr) -> tensor<64x64x!tt.ptr> - %34 = tt.getelementptr %33, %32 : tensor<64x64x!tt.ptr> + %34 = tt.addptr %33, %32 : tensor<64x64x!tt.ptr> %35 = tt.expand_dims %23 {axis = 1 : i32} : (tensor<64xi32>) -> tensor<64x1xi32> %36 = tt.splat %arg8 : (i32) -> tensor<64x1xi32> %37 = arith.muli %35, %36 : tensor<64x1xi32> @@ -57,7 +57,7 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6 %42 = tt.broadcast %40 : (tensor<1x64xi32>) -> tensor<64x64xi32> %43 = arith.addi %41, %42 : tensor<64x64xi32> %44 = tt.splat %arg1 : (!tt.ptr) -> tensor<64x64x!tt.ptr> - %45 = tt.getelementptr %44, %43 : tensor<64x64x!tt.ptr> + %45 = tt.addptr %44, %43 : tensor<64x64x!tt.ptr> %46 = arith.index_cast %arg5 : i32 to index %47:3 = scf.for %arg12 = %c0 to %46 step %c64 iter_args(%arg13 = %cst_0, %arg14 = %34, %arg15 = %45) -> (tensor<64x64xf32>, tensor<64x64x!tt.ptr>, tensor<64x64x!tt.ptr>) { %76 = tt.load %arg14, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32> @@ -66,10 +66,10 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6 %79 = arith.addf %arg13, %78 : tensor<64x64xf32> %80 = arith.muli %arg7, %c64_i32 : i32 %81 = tt.splat %80 : (i32) -> tensor<64x64xi32> - %82 = tt.getelementptr %arg14, %81 : tensor<64x64x!tt.ptr> + %82 = tt.addptr %arg14, %81 : tensor<64x64x!tt.ptr> %83 = arith.muli %arg8, %c64_i32 : i32 %84 = tt.splat %83 : (i32) -> 
tensor<64x64xi32> - %85 = tt.getelementptr %arg15, %84 : tensor<64x64x!tt.ptr> + %85 = tt.addptr %arg15, %84 : tensor<64x64x!tt.ptr> scf.yield %79, %82, %85 : tensor<64x64xf32>, tensor<64x64x!tt.ptr>, tensor<64x64x!tt.ptr> } %48 = arith.muli %12, %c64_i32 : i32 @@ -90,7 +90,7 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6 %63 = tt.broadcast %61 : (tensor<1x64xi32>) -> tensor<64x64xi32> %64 = arith.addi %62, %63 : tensor<64x64xi32> %65 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x64x!tt.ptr> - %66 = tt.getelementptr %65, %64 : tensor<64x64x!tt.ptr> + %66 = tt.addptr %65, %64 : tensor<64x64x!tt.ptr> %67 = tt.expand_dims %51 {axis = 1 : i32} : (tensor<64xi32>) -> tensor<64x1xi32> %68 = tt.splat %arg3 : (i32) -> tensor<64x1xi32> %69 = arith.cmpi slt, %67, %68 : tensor<64x1xi32> @@ -103,4 +103,4 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6 tt.store %66, %47#0, %75 : tensor<64x64xf32> return } -} \ No newline at end of file +}
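
For quick reference, a minimal sketch of the textual IR accepted by the renamed op; the `f32` pointee, function name, and SSA names are illustrative rather than taken from this patch. With `TT_PtrLike`/`TT_I32Like` operands, `tt.addptr` now covers scalar pointers as well as pointer tensors, mirroring the `addptr_ops` cases in the new `test/Conversion/triton_ops.mlir`:

```mlir
func @addptr_sketch(%ptr: !tt.ptr<f32>, %off: i32) {
  // Scalar pointer + scalar i32 offset (the old tensor-only TT_GEPOp rejected this).
  %0 = tt.addptr %ptr, %off : !tt.ptr<f32>
  // Tensor of pointers + tensor of i32 offsets, as produced by `ptr + offsets`
  // in the Python frontend (semantic.add now calls create_addptr).
  %ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<16x!tt.ptr<f32>>
  %offs = tt.splat %off : (i32) -> tensor<16xi32>
  %1 = tt.addptr %ptrs, %offs : tensor<16x!tt.ptr<f32>>
  return
}
```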
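Likewise, a sketch of the cast ops with the assembly format added above (operand, then `:` source type `->` result type); the element types are again illustrative:

```mlir
func @cast_sketch(%i: i64, %p: !tt.ptr<f32>, %f: f32) {
  // Scalars are now legal operands thanks to TT_I64Like / TT_PtrLike / TT_FloatLike.
  %0 = tt.int_to_ptr %i : i64 -> !tt.ptr<f32>
  %1 = tt.ptr_to_int %p : !tt.ptr<f32> -> i64
  %2 = tt.fp_to_fp %f : f32 -> f16
  return
}
```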
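Finally, a sketch of what the renamed `CombineAddPtrPattern` does under `-triton-combine`; the tensor shapes mirror `test/Triton/combine.mlir`, while the function wrapper and value names are hypothetical:

```mlir
func @chained_addptr(%base: tensor<8x!tt.ptr<f32>>,
                     %idx0: tensor<8xi32>, %idx1: tensor<8xi32>) -> tensor<8x!tt.ptr<f32>> {
  // Two chained pointer increments ...
  %p0 = tt.addptr %base, %idx0 : tensor<8x!tt.ptr<f32>>
  %p1 = tt.addptr %p0, %idx1 : tensor<8x!tt.ptr<f32>>
  // ... which the pattern rewrites, in effect, to:
  //   %sum = arith.addi %idx0, %idx1 : tensor<8xi32>
  //   %p1  = tt.addptr %base, %sum : tensor<8x!tt.ptr<f32>>
  return %p1 : tensor<8x!tt.ptr<f32>>
}
```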