diff --git a/include/triton/Conversion/Passes.td b/include/triton/Conversion/Passes.td index 086cf2dd8..e78b4dc4f 100644 --- a/include/triton/Conversion/Passes.td +++ b/include/triton/Conversion/Passes.td @@ -38,6 +38,7 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp" "mlir::gpu::GPUDialect", "mlir::scf::SCFDialect", "mlir::LLVM::LLVMDialect", + "mlir::tensor::TensorDialect", "mlir::triton::TritonDialect", "mlir::triton::gpu::TritonGPUDialect", "mlir::NVVM::NVVMDialect", diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h index d25a4d4e6..3c2953fe1 100644 --- a/include/triton/Dialect/Triton/IR/Dialect.h +++ b/include/triton/Dialect/Triton/IR/Dialect.h @@ -1,7 +1,6 @@ #ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_ #define TRITON_DIALECT_TRITON_IR_DIALECT_H_ -#include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" diff --git a/include/triton/Dialect/Triton/IR/TritonDialect.td b/include/triton/Dialect/Triton/IR/TritonDialect.td index ea82bedd8..07b069e14 100644 --- a/include/triton/Dialect/Triton/IR/TritonDialect.td +++ b/include/triton/Dialect/Triton/IR/TritonDialect.td @@ -27,7 +27,6 @@ def Triton_Dialect : Dialect { "math::MathDialect", "StandardOpsDialect", "scf::SCFDialect", - "gpu::GPUDialect", // Since LLVM 15 // "cf::ControlFlowDialect", diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index 1d3312637..33c7d889f 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -2,6 +2,7 @@ #define TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ #include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dialect.h" @@ -9,6 +10,7 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h.inc" +#include "triton/Dialect/TritonGPU/IR/Traits.h" #define GET_ATTRDEF_CLASSES #include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc" diff --git a/include/triton/Dialect/TritonGPU/IR/Traits.h b/include/triton/Dialect/TritonGPU/IR/Traits.h new file mode 100644 index 000000000..44def9580 --- /dev/null +++ b/include/triton/Dialect/TritonGPU/IR/Traits.h @@ -0,0 +1,31 @@ +#ifndef TRITON_GPU_IR_TRAITS_H_ +#define TRITON_GPU_IR_TRAITS_H_ + +#include "mlir/IR/OpDefinition.h" + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +namespace OpTrait { + +// These functions are out-of-line implementations of the methods in the +// corresponding trait classes. This avoids them being template +// instantiated/duplicated. +namespace impl { +LogicalResult verifyResultsAreSharedEncoding(Operation *op); +} // namespace impl + +template +class ResultsAreSharedEncoding + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyResultsAreSharedEncoding(op); + } +}; + +} // namespace OpTrait +} // namespace mlir + +#endif diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td index d875f3c60..87ec1d36c 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td @@ -16,7 +16,8 @@ def TritonGPU_Dialect : Dialect { let dependentDialects = [ "triton::TritonDialect", - "mlir::gpu::GPUDialect" + "mlir::gpu::GPUDialect", + "tensor::TensorDialect", ]; let extraClassDeclaration = [{ diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index b88b80e7d..4a8824193 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -10,6 +10,8 @@ include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +def ResultsAreSharedEncoding: NativeOpTrait<"ResultsAreSharedEncoding">; + class TTG_Op traits = []> : Op; @@ -75,7 +77,8 @@ def TTG_SelectOp : TTG_Op<"select", [NoSideEffect]> { def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async", - [SameVariadicOperandSize, + [AttrSizedOperandSegments, + ResultsAreSharedEncoding, // MemoryEffects<[MemRead]>, doesn't work with CSE but seems like it should? NoSideEffect, TypesMatchWith<"infer mask type from src type", @@ -93,6 +96,10 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async", It returns a copy of `$dst` with the proper slice updated asynchronously with the value of `$src`. This operation is non-blocking, and `$results` will have the updated value after the corresponding async_wait. + When converting from `tt.load` to `triton_gpu.insert_slice_async`, the `$evict`, `$cache`, and `$isVolatile` fields + might be ignored on certain hardware. For example, on NVIDIA GPUs, the cache policy is determined by the backend, + and `$evict` and `$isVolatile` are ignored because they apply to L1 cache only. + The insert_slice_async operation supports the following arguments: * src: the tensor that is inserted. @@ -149,48 +156,9 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async", let parser = [{ return parseInsertSliceAsyncOp(parser, result); }]; let printer = [{ return printInsertSliceAsyncOp(p, *this); }]; - - // result needs to be of shared layout - let verifier = [{ return ::verify(*this); }]; } -def TTG_ExtractSliceOp : TTG_Op<"extract_slice", [NoSideEffect, InferTypeOpInterface]> { - let summary = "extract slice"; - let description = [{ - The "extract_slice" operation extracts a `$result` tensor from a `$src` tensor as - specified by the operation's `$index` and `$axis` arguments. - - The extract_slice operation supports the following arguments: - - * src: the tensor that is extracted from. - * index: the index at the given `$axis` from which the `$src` tensor is extracted - - Example: - - ``` - // Rank-reducing extract_slice. - %1 = tensor.extract_slice %0, %index {axis = 0} : tensor<8x16x4xf32> -> tensor<1x16x4xf32> - ``` - }]; - - let arguments = (ins TT_Tensor:$src, I32:$index, I32Attr:$axis); - - let results = (outs TT_Tensor:$result); - - let assemblyFormat = [{$src `,` $index attr-dict `:` type($src) `->` type($result)}]; - - let extraClassDeclaration = [{ - static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, - ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands, - ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions, - ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes); - }]; - - // result needs to be of shared layout - let verifier = [{ return ::verify(*this); }]; -} - -def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [NoSideEffect]> { +def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [NoSideEffect, ResultsAreSharedEncoding]> { let summary = "allocate tensor"; let description = [{ @@ -203,9 +171,6 @@ def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [NoSideEffect]> { let assemblyFormat = [{attr-dict `:` type($result)}]; let results = (outs TT_Tensor:$result); - - // result needs to be of shared layout - let verifier = [{ return ::verify(*this); }]; } #endif diff --git a/lib/Analysis/Alias.cpp b/lib/Analysis/Alias.cpp index 31287f5c0..8deebdf09 100644 --- a/lib/Analysis/Alias.cpp +++ b/lib/Analysis/Alias.cpp @@ -1,4 +1,5 @@ #include "triton/Analysis/Alias.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -24,18 +25,18 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation( if (maybeSharedAllocationOp(op)) { // These ops may allocate a new shared memory buffer. auto result = op->getResult(0); - if (isSharedEncoding(result)) { - // FIXME(Keren): extract and insert are always alias for now - if (auto extractSliceOp = dyn_cast(op)) { - // extract_slice %src, %index - aliasInfo = AliasInfo(operands[0]->getValue()); - } else if (auto insertSliceOp = - dyn_cast(op)) { - // insert_slice_async %src, %dst, %index - aliasInfo = AliasInfo(operands[1]->getValue()); - } else { - aliasInfo.insert(result); - } + // FIXME(Keren): extract and insert are always alias for now + if (auto extractSliceOp = dyn_cast(op)) { + // extract_slice %src + aliasInfo = AliasInfo(operands[0]->getValue()); + pessimistic = false; + } else if (auto insertSliceOp = + dyn_cast(op)) { + // insert_slice_async %src, %dst, %index + aliasInfo = AliasInfo(operands[1]->getValue()); + pessimistic = false; + } else if (isSharedEncoding(result)) { + aliasInfo.insert(result); pessimistic = false; } } diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index c13310160..51d4e0e3b 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -1,6 +1,7 @@ #include "triton/Analysis/Allocation.h" #include "mlir/Analysis/Liveness.h" #include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "triton/Analysis/Alias.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -76,13 +77,13 @@ SmallVector getScratchConfigForReduce(triton::ReduceOp op) { auto srcShape = srcTy.getShape(); auto axis = op.axis(); - bool fast_reduce = axis == 1; // FIXME(Qingyi): The fastest-changing dimension + bool fastReduce = axis == 1; // FIXME(Qingyi): The fastest-changing dimension SmallVector smemShape; for (auto d : srcShape) smemShape.push_back(d); - if (fast_reduce) { + if (fastReduce) { unsigned sizeInterWarps = srcLayout.getWarpsPerCTA()[axis]; smemShape[axis] = sizeInterWarps; } else { @@ -123,7 +124,7 @@ private: // For example: %a = scf.if -> yield // %a must be allocated elsewhere by other operations. // FIXME(Keren): extract and insert are always alias for now - if (!maybeSharedAllocationOp(op) || isa(op) || + if (!maybeSharedAllocationOp(op) || isa(op) || isa(op)) { return; } diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp index 6b03c9947..eab1636e5 100644 --- a/lib/Analysis/Membar.cpp +++ b/lib/Analysis/Membar.cpp @@ -2,6 +2,7 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" namespace mlir { @@ -43,8 +44,7 @@ void MembarAnalysis::dfsOperation(Operation *operation, void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo, OpBuilder *builder) { if (isa(op) || isa(op) || isa(op) || - isa(op) || - isa(op)) { + isa(op) || isa(op)) { // Do not insert barriers before control flow operations and // alloc/extract/insert // alloc is an allocation op without memory write. diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index dcd3def89..5fcb9654d 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -24,7 +24,8 @@ bool maybeSharedAllocationOp(Operation *op) { mlir::TypeID::get() || dialect->getTypeID() == mlir::TypeID::get() || dialect->getTypeID() == - mlir::TypeID::get()); + mlir::TypeID::get() || + dialect->getTypeID() == mlir::TypeID::get()); } std::string getValueOperandName(Value value, AsmState &state) { diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 93ad97f80..17d189216 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" @@ -50,7 +51,8 @@ static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; } namespace { // Create a 32-bit integer constant. -Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) { +static Value createConstantI32(Location loc, PatternRewriter &rewriter, + int32_t v) { auto i32ty = rewriter.getIntegerType(32); return rewriter.create(loc, i32ty, IntegerAttr::get(i32ty, v)); @@ -63,17 +65,17 @@ Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) { } // Create a index type constant. -Value createIndexConstant(OpBuilder &builder, Location loc, +static Value createIndexConstant(OpBuilder &builder, Location loc, - TypeConverter *converter, int64_t value) { + TypeConverter *converter, int64_t value) { Type ty = converter->convertType(builder.getIndexType()); return builder.create(loc, ty, builder.getIntegerAttr(ty, value)); } // Create an integer constant of \param width bits. -Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, - int64_t value) { +static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, + short width, int64_t value) { Type ty = builder.getIntegerType(width); return builder.create(loc, ty, builder.getIntegerAttr(ty, value)); @@ -369,8 +371,8 @@ static T getLinearIndex(ArrayRef multiDimIndex, ArrayRef shape) { return linearIndex; } -Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, - Value val, Value pred) { +static Value storeShared(ConversionPatternRewriter &rewriter, Location loc, + Value ptr, Value val, Value pred) { MLIRContext *ctx = rewriter.getContext(); unsigned bits = val.getType().getIntOrFloatBitWidth(); const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r"); @@ -383,6 +385,50 @@ Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, return builder.launch(rewriter, loc, void_ty(ctx)); } +struct SharedMemoryObject { + Value base; // i32 ptr. The start address of the shared memory object. + // We need to store strides as Values but not integers because the + // extract_slice instruction can take a slice at artibary offsets. + // Take $a[16:32, 16:32] as an example, though we know the stride of $a[0] is + // 32, we need to let the instruction that uses $a to be aware of that. + // Otherwise, when we use $a, we only know that the shape of $a is 16x16. If + // we store strides into an attribute array of integers, the information + // cannot pass through block argument assignment because attributes are + // associated with operations but not Values. + // TODO(Keren): We may need to figure out a way to store strides as integers + // if we want to support more optimizations. + SmallVector + strides; // i32 int. The strides of the shared memory object. + + SharedMemoryObject(Value base, ArrayRef strides) + : base(base), strides(strides.begin(), strides.end()) {} + + SharedMemoryObject(Value base, ArrayRef shape, Location loc, + ConversionPatternRewriter &rewriter) + : base(base) { + auto stride = 1; + for (auto dim : llvm::reverse(shape)) { + this->strides.emplace_back(i32_val(stride)); + stride *= dim; + } + this->strides = llvm::to_vector<4>(llvm::reverse(this->strides)); + } + + SmallVector getElems() const { + SmallVector elems; + elems.push_back(base); + elems.append(strides.begin(), strides.end()); + return elems; + } + + SmallVector getTypes() const { + SmallVector types; + types.push_back(base.getType()); + types.append(strides.size(), IntegerType::get(base.getContext(), 32)); + return types; + } +}; + struct ConvertTritonGPUOpToLLVMPatternBase { static SmallVector getElementsFromStruct(Location loc, Value llvmStruct, @@ -489,6 +535,16 @@ public: return linear; } + Value dot(ConversionPatternRewriter &rewriter, Location loc, + ArrayRef offsets, ArrayRef strides) const { + assert(offsets.size() == strides.size()); + Value ret = idx_val(0); + for (auto [offset, stride] : llvm::zip(offsets, strides)) { + ret = add(ret, mul(offset, stride)); + } + return ret; + } + // Get an index-base for each dimension for a \param blocked_layout. SmallVector emitBaseIndexForBlockedLayout(Location loc, @@ -671,6 +727,25 @@ public: return base; } + static SharedMemoryObject + getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct, + ConversionPatternRewriter &rewriter) { + auto elems = getElementsFromStruct(loc, llvmStruct, rewriter); + return SharedMemoryObject(/*base=*/elems[0], + /*strides=*/{elems.begin() + 1, elems.end()}); + } + + static Value + getStructFromSharedMemoryObject(Location loc, + const SharedMemoryObject &smemObj, + ConversionPatternRewriter &rewriter) { + auto elems = smemObj.getElems(); + auto types = smemObj.getTypes(); + auto structTy = + LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types); + return getStructFromElements(loc, elems, rewriter, structTy); + } + protected: const Allocation *allocation; Value smem; @@ -1734,46 +1809,63 @@ struct AllocTensorOpConversion auto resultTy = op.getType().dyn_cast(); auto llvmElemTy = getTypeConverter()->convertType(resultTy.getElementType()); - auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3); - Value resultVal = - rewriter.create(loc, elemPtrTy, smemBase); - rewriter.replaceOp(op, resultVal); + auto elemPtrTy = ptr_ty(llvmElemTy, 3); + smemBase = bitcast(smemBase, elemPtrTy); + auto smemObj = + SharedMemoryObject(smemBase, resultTy.getShape(), loc, rewriter); + auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); + rewriter.replaceOp(op, retVal); return success(); } }; struct ExtractSliceOpConversion - : public ConvertTritonGPUOpToLLVMPattern { + : public ConvertTritonGPUOpToLLVMPattern { using ConvertTritonGPUOpToLLVMPattern< - triton::gpu::ExtractSliceOp>::ConvertTritonGPUOpToLLVMPattern; + tensor::ExtractSliceOp>::ConvertTritonGPUOpToLLVMPattern; LogicalResult - matchAndRewrite(triton::gpu::ExtractSliceOp op, OpAdaptor adaptor, + matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + // %dst = extract_slice %src[%offsets] Location loc = op->getLoc(); - auto srcTy = op.src().getType().dyn_cast(); + auto srcTy = op.source().getType().dyn_cast(); auto srcLayout = srcTy.getEncoding().dyn_cast(); assert(srcLayout && "Unexpected resultLayout in ExtractSliceOpConversion"); + assert(op.hasUnitStride() && + "Only unit stride supported by ExtractSliceOpConversion"); - // axis > 0 will result in non-contiguous memory access if the result - // tensor is an alias of the source tensor. - auto axis = op->getAttrOfType("axis").getInt(); - assert(axis == 0 && "extract_slice: Only axis=0 is supported for now"); - - // Example: - // %dst = extract_slice %src, %index {axis = 0} - // src.shape = [11, 2, 3, 4, 1] - // offset = %index * 2 * 3 * 4 * 1 - auto dstTy = op.getType().dyn_cast(); - auto base = product(dstTy.getShape()); - auto baseVal = createIndexAttrConstant( - rewriter, loc, getTypeConverter()->getIndexType(), base); - Value offset = mul(adaptor.index(), baseVal); - - auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType()); - auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3); - Value resultVal = gep(elemPtrTy, adaptor.src(), offset); - rewriter.replaceOp(op, resultVal); + // newBase = base + offset + // Triton support either static and dynamic offsets + auto smemObj = + getSharedMemoryObjectFromStruct(loc, adaptor.source(), rewriter); + SmallVector offsetVals; + auto mixedOffsets = op.getMixedOffsets(); + for (auto i = 0; i < mixedOffsets.size(); ++i) { + if (op.isDynamicOffset(i)) + offsetVals.emplace_back(adaptor.offsets()[i]); + else + offsetVals.emplace_back(i32_val(op.getStaticOffset(i))); + } + // Compute the offset based on the original strides of the shared memory + // object + auto offset = dot(rewriter, loc, offsetVals, smemObj.strides); + // newShape = rank_reduce(shape) + // Triton only supports static tensor sizes + SmallVector strideVals; + auto staticSizes = op.static_sizes(); + for (auto i = 0; i < op.static_sizes().size(); ++i) { + if (op.getStaticSize(i) != 1) { + strideVals.emplace_back(smemObj.strides[i]); + } + } + auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); + auto elemPtrTy = ptr_ty(llvmElemTy, 3); + auto resTy = op.getType().dyn_cast(); + smemObj = + SharedMemoryObject(gep(elemPtrTy, smemObj.base, offset), strideVals); + auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); + rewriter.replaceOp(op, retVal); return success(); } }; @@ -2262,8 +2354,9 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared( Value src = op.src(); Value dst = op.result(); auto srcTy = src.getType().cast(); - auto dstTy = dst.getType().cast(); auto srcShape = srcTy.getShape(); + auto dstTy = dst.getType().cast(); + auto dstShape = dstTy.getShape(); assert(srcShape.size() == 2 && "Unexpected rank of ConvertLayout(blocked->shared)"); auto srcBlockedLayout = srcTy.getEncoding().cast(); @@ -2309,6 +2402,8 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared( Value smemBase = getSharedMemoryBase(loc, rewriter, dst); auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3); smemBase = bitcast(smemBase, elemPtrTy); + auto smemObj = SharedMemoryObject(smemBase, dstShape, loc, rewriter); + auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); unsigned numWordsEachRep = product(wordsInEachRep); SmallVector wordVecs(numWordsEachRep); // TODO: We should get less barriers if it is handled by membar pass @@ -2369,8 +2464,10 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared( } } } - barrier(); - rewriter.replaceOp(op, smemBase); + // Barrier is not necessary. + // The membar pass knows that it writes to shared memory and will handle it + // properly. + rewriter.replaceOp(op, retVal); return success(); } @@ -2380,9 +2477,10 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared( class MMA16816SmemLoader { public: MMA16816SmemLoader(int wpt, ArrayRef order, uint32_t kOrder, - ArrayRef tileShape, ArrayRef instrShape, - ArrayRef matShape, int perPhase, int maxPhase, - int elemBytes, ConversionPatternRewriter &rewriter, + ArrayRef smemStrides, ArrayRef tileShape, + ArrayRef instrShape, ArrayRef matShape, + int perPhase, int maxPhase, int elemBytes, + ConversionPatternRewriter &rewriter, TypeConverter *typeConverter, const Location &loc) : order(order.begin(), order.end()), kOrder(kOrder), tileShape(tileShape.begin(), tileShape.end()), @@ -2393,8 +2491,8 @@ public: cMatShape = matShape[order[0]]; sMatShape = matShape[order[1]]; - cTileStride = tileShape[order[1]]; - sTileStride = tileShape[order[0]]; + cTileStride = smemStrides[order[0]]; + sTileStride = smemStrides[order[1]]; // rule: k must be the fast-changing axis. needTrans = kOrder != order[0]; @@ -2497,8 +2595,7 @@ public: for (int i = 0; i < numPtr; ++i) { Value cMatOffI = add(cMatOff, i32_val(i * pLoadStrideInMat)); cMatOffI = xor_(cMatOffI, phase); - offs[i] = add(mul(cMatOffI, i32_val(cMatShape)), - mul(sOff, i32_val(sTileStride))); + offs[i] = add(mul(cMatOffI, i32_val(cMatShape)), mul(sOff, sTileStride)); } return offs; @@ -2534,7 +2631,7 @@ public: Value cOff = add(cOffInMat, mul(cMatOffI, i32_val(cMatShape))); cOff = urem(cOff, i32_val(tileShape[order[0]])); sOff = urem(sOff, i32_val(tileShape[order[1]])); - offs[2 * i + nkMatArrInt] = add(cOff, mul(sOff, i32_val(sTileStride))); + offs[2 * i + nkMatArrInt] = add(cOff, mul(sOff, sTileStride)); } } return offs; @@ -2574,7 +2671,7 @@ public: // To prevent out-of-bound access when tile is too small. cOff = urem(cOff, i32_val(tileShape[order[0]])); sOff = urem(sOff, i32_val(tileShape[order[1]])); - offs[ptrOff] = add(cOff, mul(sOff, i32_val(sTileStride))); + offs[ptrOff] = add(cOff, mul(sOff, sTileStride)); } } } @@ -2608,14 +2705,15 @@ public: Value ptr = getPtr(ptrIdx); if (canUseLdmatrix) { - int sOffset = - matIdx[order[1]] * sMatStride * sMatShape * sTileStride * elemBytes; + Value sOffset = + mul(i32_val(matIdx[order[1]] * sMatStride * sMatShape), sTileStride); + Value sOffsetPtr = gep(shemPtrTy, ptr, sOffset); PTXBuilder builder; // ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a // thread. auto resArgs = builder.newListOperand(4, "=r"); - auto addrArg = builder.newAddrOperand(ptr, "r", sOffset); + auto addrArg = builder.newAddrOperand(sOffsetPtr, "r"); auto ldmatrix = builder.create("ldmatrix.sync.aligned.m8n8.x4") ->o("trans", needTrans /*predicate*/) @@ -2640,26 +2738,24 @@ public: needTrans) { // Use lds.32 to load tf32 matrices Value ptr2 = getPtr(ptrIdx + 1); assert(sMatStride == 1); - int sOffsetElem = - matIdx[order[1]] * (sMatStride * sMatShape) * sTileStride; - int sOffsetArrElem = 1 * (sMatStride * sMatShape) * sTileStride; + int sOffsetElem = matIdx[order[1]] * (sMatStride * sMatShape); + Value sOffsetElemVal = mul(i32_val(sOffsetElem), sTileStride); + int sOffsetArrElem = sMatStride * sMatShape; + Value sOffsetArrElemVal = + add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sTileStride)); Value elems[4]; Type elemTy = type::f32Ty(ctx); if (kOrder == 1) { - elems[0] = load(gep(elemTy, ptr, i32_val(sOffsetElem))); - elems[1] = load(gep(elemTy, ptr2, i32_val(sOffsetElem))); - elems[2] = - load(gep(elemTy, ptr, i32_val(sOffsetElem + sOffsetArrElem))); - elems[3] = - load(gep(elemTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem))); + elems[0] = load(gep(elemTy, ptr, sOffsetElemVal)); + elems[1] = load(gep(elemTy, ptr2, sOffsetElemVal)); + elems[2] = load(gep(elemTy, ptr, sOffsetArrElemVal)); + elems[3] = load(gep(elemTy, ptr2, sOffsetArrElemVal)); } else { - elems[0] = load(gep(elemTy, ptr, i32_val(sOffsetElem))); - elems[2] = load(gep(elemTy, ptr2, i32_val(sOffsetElem))); - elems[1] = - load(gep(elemTy, ptr, i32_val(sOffsetElem + sOffsetArrElem))); - elems[3] = - load(gep(elemTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem))); + elems[0] = load(gep(elemTy, ptr, sOffsetElemVal)); + elems[2] = load(gep(elemTy, ptr2, sOffsetElemVal)); + elems[1] = load(gep(elemTy, ptr, sOffsetArrElemVal)); + elems[3] = load(gep(elemTy, ptr2, sOffsetArrElemVal)); } return {elems[0], elems[1], elems[2], elems[3]}; @@ -2680,9 +2776,11 @@ public: }; assert(sMatStride == 1); - int sOffsetElem = - matIdx[order[1]] * (sMatStride * sMatShape) * sTileStride; - int sOffsetArrElem = 1 * (sMatStride * sMatShape) * sTileStride; + int sOffsetElem = matIdx[order[1]] * (sMatStride * sMatShape); + Value sOffsetElemVal = mul(i32_val(sOffsetElem), sTileStride); + int sOffsetArrElem = 1 * (sMatStride * sMatShape); + Value sOffsetArrElemVal = + add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sTileStride)); std::array i8v4Elems; std::array i32Elems; @@ -2692,16 +2790,14 @@ public: Value i8Elems[4][4]; Type elemTy = type::i8Ty(ctx); if (kOrder == 1) { - Value offset = i32_val(sOffsetElem); - for (int i = 0; i < 2; ++i) for (int j = 0; j < 4; ++j) - i8Elems[i][j] = load(gep(elemTy, ptrs[i][j], offset)); + i8Elems[i][j] = load(gep(elemTy, ptrs[i][j], sOffsetElemVal)); - offset = i32_val(sOffsetElem + sOffsetArrElem); for (int i = 2; i < 4; ++i) for (int j = 0; j < 4; ++j) - i8Elems[i][j] = load(gep(elemTy, ptrs[i - 2][j], offset)); + i8Elems[i][j] = + load(gep(elemTy, ptrs[i - 2][j], sOffsetArrElemVal)); for (int m = 0; m < 4; ++m) { for (int e = 0; e < 4; ++e) @@ -2710,16 +2806,14 @@ public: i32Elems[m] = bitcast(i8v4Elems[m], i32_ty); } } else { // k first - Value offset = i32_val(sOffsetElem); for (int j = 0; j < 4; ++j) - i8Elems[0][j] = load(gep(elemTy, ptrs[0][j], offset)); + i8Elems[0][j] = load(gep(elemTy, ptrs[0][j], sOffsetElemVal)); for (int j = 0; j < 4; ++j) - i8Elems[2][j] = load(gep(elemTy, ptrs[1][j], offset)); - offset = i32_val(sOffsetElem + sOffsetArrElem); + i8Elems[2][j] = load(gep(elemTy, ptrs[1][j], sOffsetElemVal)); for (int j = 0; j < 4; ++j) - i8Elems[1][j] = load(gep(elemTy, ptrs[0][j], offset)); + i8Elems[1][j] = load(gep(elemTy, ptrs[0][j], sOffsetArrElemVal)); for (int j = 0; j < 4; ++j) - i8Elems[3][j] = load(gep(elemTy, ptrs[1][j], offset)); + i8Elems[3][j] = load(gep(elemTy, ptrs[1][j], sOffsetArrElemVal)); for (int m = 0; m < 4; ++m) { for (int e = 0; e < 4; ++e) @@ -2752,8 +2846,8 @@ private: int cMatShape; int sMatShape; - int cTileStride; - int sTileStride; + Value cTileStride; + Value sTileStride; bool needTrans; bool canUseLdmatrix; @@ -2922,12 +3016,12 @@ struct DotOpMmaV1ConversionHelper { } // Loading $a from smem to registers, returns a LLVM::Struct. - Value loadA(Value A, Value llA, Value thread, Value smem, Location loc, - ConversionPatternRewriter &rewriter) const; + Value loadA(Value A, const SharedMemoryObject &smemObj, Value thread, + Location loc, ConversionPatternRewriter &rewriter) const; // Loading $b from smem to registers, returns a LLVM::Struct. - Value loadB(Value B, Value llB, Value thread, Value smem, Location loc, - ConversionPatternRewriter &rewriter) const; + Value loadB(Value B, const SharedMemoryObject &smemObj, Value thread, + Location loc, ConversionPatternRewriter &rewriter) const; // Loading $c to registers, returns a LLVM::Struct. Value loadC(Value C, Value llC, ConversionPatternRewriter &rewriter) const; @@ -3334,7 +3428,7 @@ struct MMA16816ConversionHelper { } // Loading $a from smem to registers, returns a LLVM::Struct. - Value loadA(Value tensor, Value llTensor) const { + Value loadA(Value tensor, const SharedMemoryObject &smemObj) const { auto aTensorTy = tensor.getType().cast(); auto shape = aTensorTy.getShape(); @@ -3348,7 +3442,7 @@ struct MMA16816ConversionHelper { if (aTensorTy.getEncoding().isa()) { // load from smem loadFn = getLoadMatrixFn( - tensor, llTensor, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/, + tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/, 1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/, {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/); } else if (aTensorTy.getEncoding().isa()) { @@ -3370,7 +3464,7 @@ struct MMA16816ConversionHelper { } // Loading $b from smem to registers, returns a LLVM::Struct. - Value loadB(Value tensor, Value llTensor) { + Value loadB(Value tensor, const SharedMemoryObject &smemObj) { ValueTable hb; auto tensorTy = tensor.getType().cast(); auto shape = tensorTy.getShape(); @@ -3380,7 +3474,7 @@ struct MMA16816ConversionHelper { int numRepN = getNumRepN(tensorTy, shape[1]); auto loadFn = getLoadMatrixFn( - tensor, llTensor, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/, + tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/, 0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/, {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/); @@ -3485,10 +3579,10 @@ struct MMA16816ConversionHelper { private: std::function - getLoadMatrixFn(Value tensor, Value llTensor, MmaEncodingAttr mmaLayout, - int wpt, uint32_t kOrder, ArrayRef instrShape, - ArrayRef matShape, Value warpId, - ValueTable &vals) const { + getLoadMatrixFn(Value tensor, const SharedMemoryObject &smemObj, + MmaEncodingAttr mmaLayout, int wpt, uint32_t kOrder, + ArrayRef instrShape, ArrayRef matShape, + Value warpId, ValueTable &vals) const { auto tensorTy = tensor.getType().cast(); // We assumes that the input operand of Dot should be from shared layout. // TODO(Superjomn) Consider other layouts if needed later. @@ -3507,10 +3601,10 @@ private: // (a, b) is the coordinate. auto load = [=, &vals, &ld2](int a, int b) { - MMA16816SmemLoader loader(wpt, sharedLayout.getOrder(), kOrder, - tensorTy.getShape() /*tileShape*/, instrShape, - matShape, perPhase, maxPhase, elemBytes, - rewriter, typeConverter, loc); + MMA16816SmemLoader loader( + wpt, sharedLayout.getOrder(), kOrder, smemObj.strides, + tensorTy.getShape() /*tileShape*/, instrShape, matShape, perPhase, + maxPhase, elemBytes, rewriter, typeConverter, loc); SmallVector offs = loader.computeOffsets(warpId, lane); const int numPtrs = loader.getNumPtr(); @@ -3519,8 +3613,8 @@ private: Type smemPtrTy = helper.getShemPtrTy(); for (int i = 0; i < numPtrs; ++i) { - ptrs[i] = - bitcast(gep(smemPtrTy, llTensor, ValueRange({offs[i]})), smemPtrTy); + ptrs[i] = bitcast(gep(smemPtrTy, smemObj.base, ValueRange({offs[i]})), + smemPtrTy); } auto [ha0, ha1, ha2, ha3] = loader.loadX4( @@ -3612,6 +3706,7 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand( dotOperandLayout.getParent().dyn_cast_or_null(); assert(mmaLayout); + auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.src(), rewriter); Value res; if (mmaLayout.getVersion() == 2) { MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc), @@ -3620,21 +3715,21 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand( if (dotOperandLayout.getOpIdx() == 0) { // operand $a - res = mmaHelper.loadA(src, adaptor.src()); + res = mmaHelper.loadA(src, smemObj); } else if (dotOperandLayout.getOpIdx() == 1) { // operand $b - res = mmaHelper.loadB(src, adaptor.src()); + res = mmaHelper.loadB(src, smemObj); } } else if (mmaLayout.getVersion() == 1) { DotOpMmaV1ConversionHelper helper(mmaLayout); if (dotOperandLayout.getOpIdx() == 0) { // operand $a - res = helper.loadA(src, adaptor.src(), getThreadId(rewriter, loc), - adaptor.src(), loc, rewriter); + res = + helper.loadA(src, smemObj, getThreadId(rewriter, loc), loc, rewriter); } else if (dotOperandLayout.getOpIdx() == 1) { // operand $b - res = helper.loadB(src, adaptor.src(), getThreadId(rewriter, loc), - adaptor.src(), loc, rewriter); + res = + helper.loadB(src, smemObj, getThreadId(rewriter, loc), loc, rewriter); } } else { assert(false && "Unsupported mma layout found"); @@ -3671,8 +3766,12 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adaptor, loadedA = adaptor.a(); loadedB = adaptor.b(); } else { - loadedA = mmaHelper.loadA(op.a(), adaptor.a()); - loadedB = mmaHelper.loadB(op.b(), adaptor.b()); + SharedMemoryObject smemA = + getSharedMemoryObjectFromStruct(loc, adaptor.a(), rewriter); + SharedMemoryObject smemB = + getSharedMemoryObjectFromStruct(loc, adaptor.b(), rewriter); + loadedA = mmaHelper.loadA(op.a(), smemA); + loadedB = mmaHelper.loadB(op.b(), smemB); } loadedC = mmaHelper.loadC(op.c(), adaptor.c()); @@ -3797,8 +3896,12 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor, } Value DotOpMmaV1ConversionHelper::loadA( - Value tensor, Value llTensor, Value thread, Value smem, Location loc, + Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc, ConversionPatternRewriter &rewriter) const { + // smem + Value smem = smemObj.base; + auto strides = smemObj.strides; + auto *ctx = rewriter.getContext(); auto tensorTy = tensor.getType().cast(); auto shape = tensorTy.getShape(); @@ -3818,10 +3921,10 @@ Value DotOpMmaV1ConversionHelper::loadA( int vecA = sharedLayout.getVec(); - int strideAM = isARow ? shape[1] : 1; - int strideAK = isARow ? 1 : shape[0]; - int strideA0 = isARow ? strideAK : strideAM; - int strideA1 = isARow ? strideAM : strideAK; + Value strideAM = isARow ? strides[0] : i32_val(1); + Value strideAK = isARow ? i32_val(1) : strides[1]; + Value strideA0 = isARow ? strideAK : strideAM; + Value strideA1 = isARow ? strideAM : strideAK; int strideRepM = wpt[0] * fpw[0] * 8; int strideRepK = 1; @@ -3847,8 +3950,7 @@ Value DotOpMmaV1ConversionHelper::loadA( offA0I = udiv(offA0I, i32_val(vecA)); offA0I = xor_(offA0I, phaseA); offA0I = xor_(offA0I, i32_val(vecA)); - offA[i] = - add(mul(offA0I, i32_val(strideA0)), mul(offA1, i32_val(strideA1))); + offA[i] = add(mul(offA0I, strideA0), mul(offA1, strideA1)); } Type f16x2Ty = vec_ty(f16_ty, 2); @@ -3877,8 +3979,9 @@ Value DotOpMmaV1ConversionHelper::loadA( int stepAM = isARow ? m : m / numPtrA * numPtrA; int stepAK = isARow ? k / (numPtrA * vecA) * (numPtrA * vecA) : k; - Value pa = gep(f16PtrTy, thePtrA, - i32_val(stepAM * strideRepM * strideAM + stepAK * strideAK)); + Value offset = add(mul(i32_val(stepAM * strideRepM), strideAM), + mul(i32_val(stepAK), strideAK)); + Value pa = gep(f16PtrTy, thePtrA, offset); Type aPtrTy = ptr_ty(vec_ty(i32_ty, std::max(vecA / 2, 1)), 3); Value ha = load(bitcast(pa, aPtrTy)); // record lds that needs to be moved @@ -3915,8 +4018,12 @@ Value DotOpMmaV1ConversionHelper::loadA( } Value DotOpMmaV1ConversionHelper::loadB( - Value tensor, Value llTensor, Value thread, Value smem, Location loc, + Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc, ConversionPatternRewriter &rewriter) const { + // smem + Value smem = smemObj.base; + auto strides = smemObj.strides; + auto *ctx = rewriter.getContext(); auto tensorTy = tensor.getType().cast(); auto shape = tensorTy.getShape(); @@ -3929,10 +4036,10 @@ Value DotOpMmaV1ConversionHelper::loadB( SmallVector rep({0, 2 * packSize1, 1}); // pad M with 0 SmallVector spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0 int vecB = sharedLayout.getVec(); - int strideBN = isBRow ? 1 : shape[0]; - int strideBK = isBRow ? shape[1] : 1; - int strideB0 = isBRow ? strideBN : strideBK; - int strideB1 = isBRow ? strideBK : strideBN; + Value strideBN = isBRow ? i32_val(1) : strides[1]; + Value strideBK = isBRow ? strides[0] : i32_val(1); + Value strideB0 = isBRow ? strideBN : strideBK; + Value strideB1 = isBRow ? strideBK : strideBN; int strideRepN = wpt[1] * fpw[1] * 8; int strideRepK = 1; @@ -3957,8 +4064,7 @@ Value DotOpMmaV1ConversionHelper::loadB( offB0I = udiv(offB0I, i32_val(vecB)); offB0I = xor_(offB0I, phaseB); offB0I = mul(offB0I, i32_val(vecB)); - offB[i] = - add(mul(offB0I, i32_val(strideB0)), mul(offB1, i32_val(strideB1))); + offB[i] = add(mul(offB0I, strideB0), mul(offB1, strideB1)); } Type f16PtrTy = ptr_ty(f16_ty); @@ -3979,8 +4085,9 @@ Value DotOpMmaV1ConversionHelper::loadB( int stepBN = isBRow ? n / numPtrB * numPtrB : n; int stepBK = isBRow ? K : K / (numPtrB * vecB) * (numPtrB * vecB); - Value pb = gep(f16PtrTy, thePtrB, - i32_val(stepBN * strideRepN * strideBN + stepBK * strideBK)); + Value offset = add(mul(i32_val(stepBN * strideRepN), strideBN), + mul(i32_val(stepBK), strideBK)); + Value pb = gep(f16PtrTy, thePtrB, offset); Value hb = load(bitcast(pb, ptr_ty(vec_ty(i32_ty, std::max(vecB / 2, 1)), 3))); // record lds that needs to be moved @@ -4171,7 +4278,17 @@ public: return LLVM::LLVMStructType::getLiteral(ctx, types); } else if (auto shared_layout = layout.dyn_cast_or_null()) { - return LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3); + SmallVector types; + // base ptr + auto ptrType = + LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3); + types.push_back(ptrType); + // shape dims + auto rank = type.getRank(); + for (auto i = 0; i < rank; i++) { + types.push_back(IntegerType::get(ctx, 32)); + } + return LLVM::LLVMStructType::getLiteral(ctx, types); } else if (auto mmaLayout = layout.dyn_cast_or_null()) { if (mmaLayout.getVersion() == 2) { auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(type); @@ -4309,15 +4426,26 @@ struct InsertSliceAsyncOpConversion auto srcElems = getLLVMElems(src, llSrc, rewriter, loc); // %dst + auto dstTy = dst.getType().cast(); + auto dstShape = dstTy.getShape(); + auto smemObj = getSharedMemoryObjectFromStruct(loc, llDst, rewriter); auto axis = op->getAttrOfType("axis").getInt(); - assert(axis == 0 && "insert_slice_async: Only axis=0 is supported for now"); - auto dstBase = createIndexAttrConstant(rewriter, loc, - getTypeConverter()->getIndexType(), - product(srcTy.getShape())); - Value offset = mul(llIndex, dstBase); - auto dstPtrTy = LLVM::LLVMPointerType::get( - getTypeConverter()->convertType(resTy.getElementType()), 3); - Value dstPtrBase = gep(dstPtrTy, llDst, offset); + SmallVector offsetVals; + SmallVector srcStrides; + for (auto i = 0; i < dstShape.size(); ++i) { + if (i == axis) { + offsetVals.emplace_back(llIndex); + } else { + offsetVals.emplace_back(i32_val(0)); + srcStrides.emplace_back(smemObj.strides[i]); + } + } + // Compute the offset based on the original dimensions of the shared memory + // object + auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides); + auto dstPtrTy = + ptr_ty(getTypeConverter()->convertType(resTy.getElementType()), 3); + Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset); // %mask SmallVector maskElems; @@ -4345,11 +4473,10 @@ struct InsertSliceAsyncOpConversion unsigned maxPhase = resSharedLayout.getMaxPhase(); auto sizePerThread = srcBlockedLayout.getSizePerThread(); auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout); - auto inOrder = srcBlockedLayout.getOrder(); // If perPhase * maxPhase > threadsPerCTA, we need to swizzle over - // elements across phases. If perPhase * maxPhase == threadsPerCTA, + // elements across phases. If perPhase * maxPhase <= threadsPerCTA, // swizzle is not allowd auto numSwizzleRows = std::max( (perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1); @@ -4377,7 +4504,6 @@ struct InsertSliceAsyncOpConversion vecIdxCol / numVecCols * numVecCols * threadsPerCTA[inOrder[0]]; auto baseOffsetRow = vecIdxRow / numSwizzleRows * numSwizzleRows * threadsPerCTA[inOrder[1]]; - auto baseOffset = (baseOffsetRow * srcShape[inOrder[0]] + baseOffsetCol); auto tileVecIdxCol = vecIdxCol % numVecCols; auto tileVecIdxRow = vecIdxRow % numSwizzleRows; @@ -4399,8 +4525,10 @@ struct InsertSliceAsyncOpConversion auto srcIdx = srcIndices[tileVecIdxRow * sizePerThread[inOrder[0]]]; Value phase = urem(udiv(srcIdx[inOrder[1]], i32_val(perPhase)), i32_val(maxPhase)); - Value rowOffset = - mul(srcIdx[inOrder[1]], i32_val(srcShape[inOrder[0]])); + // srcShape and smemObj.shape maybe different if smemObj is a + // slice of the original shared memory object. + // So we need to use the original shape to compute the offset + Value rowOffset = mul(srcIdx[inOrder[1]], srcStrides[inOrder[1]]); Value colOffset = add(srcIdx[inOrder[0]], i32_val(tileVecIdxCol * minVec)); Value swizzleIdx = udiv(colOffset, i32_val(outVec)); @@ -4420,21 +4548,25 @@ struct InsertSliceAsyncOpConversion auto numWords = vecBitWidth / bitWidth; auto numWordElems = bitWidth / resElemTy.getIntOrFloatBitWidth(); - // XXX(Keren): Tune CG and CA here. + // Tune CG and CA here. auto byteWidth = bitWidth / 8; CacheModifier srcCacheModifier = byteWidth == 16 ? CacheModifier::CG : CacheModifier::CA; assert(byteWidth == 16 || byteWidth == 8 || byteWidth == 4); auto resByteWidth = resElemTy.getIntOrFloatBitWidth() / 8; - auto tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}]; + Value tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}]; + Value baseOffset = + add(mul(i32_val(baseOffsetRow), srcStrides[inOrder[1]]), + i32_val(baseOffsetCol)); + Value basePtr = gep(dstPtrTy, tileOffset, baseOffset); for (size_t wordIdx = 0; wordIdx < numWords; ++wordIdx) { PTXBuilder ptxBuilder; auto wordElemIdx = wordIdx * numWordElems; auto ©AsyncOp = *ptxBuilder.create(srcCacheModifier); - auto *dstOperand = ptxBuilder.newAddrOperand( - tileOffset, "r", (wordElemIdx + baseOffset) * resByteWidth); + auto *dstOperand = + ptxBuilder.newAddrOperand(basePtr, "r", wordElemIdx * resByteWidth); auto *srcOperand = ptxBuilder.newAddrOperand(srcElems[elemIdx + wordElemIdx], "l"); auto *copySize = ptxBuilder.newConstantOperand(byteWidth); diff --git a/lib/Dialect/Triton/IR/Traits.cpp b/lib/Dialect/Triton/IR/Traits.cpp index c3ce7e9d4..e457a8564 100644 --- a/lib/Dialect/Triton/IR/Traits.cpp +++ b/lib/Dialect/Triton/IR/Traits.cpp @@ -66,4 +66,4 @@ mlir::LogicalResult mlir::OpTrait::impl::verifyTensorSize(Operation *op) { } } return success(); -} \ No newline at end of file +} diff --git a/lib/Dialect/TritonGPU/IR/CMakeLists.txt b/lib/Dialect/TritonGPU/IR/CMakeLists.txt index 7df435006..903dfc318 100644 --- a/lib/Dialect/TritonGPU/IR/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/IR/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(TritonGPUIR Dialect.cpp + Traits.cpp DEPENDS TritonGPUTableGen diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 96b0925bc..c7968c76c 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -474,7 +474,7 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const { ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser, OperationState &result) { - SmallVector allOperands; + SmallVector allOperands; Type srcType, dstType; SMLoc allOperandLoc = parser.getCurrentLocation(); if (parser.parseOperandList(allOperands) || @@ -489,14 +489,27 @@ ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser, operandTypes.push_back(dstType); // dst operandTypes.push_back( IntegerType::get(parser.getBuilder().getContext(), 32)); // index - if (allOperands.size() >= 4) + + int hasMask = 0, hasOther = 0; + if (allOperands.size() >= 4) { operandTypes.push_back(triton::getI1SameShape(srcType)); // mask - if (allOperands.size() >= 5) + hasMask = 1; + } + if (allOperands.size() >= 5) { operandTypes.push_back(triton::getPointeeType(srcType)); // other + hasOther = 1; + } if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc, result.operands)) return failure(); + + // Deduce operand_segment_sizes from the number of the operands. + auto operand_segment_sizesAttrName = + InsertSliceAsyncOp::operand_segment_sizesAttrName(result.name); + result.addAttribute( + operand_segment_sizesAttrName, + parser.getBuilder().getI32VectorAttr({1, 1, 1, hasMask, hasOther})); return success(); } @@ -504,39 +517,16 @@ void printInsertSliceAsyncOp(OpAsmPrinter &printer, InsertSliceAsyncOp insertSliceAsyncOp) { printer << " "; printer << insertSliceAsyncOp.getOperation()->getOperands(); - printer.printOptionalAttrDict(insertSliceAsyncOp->getAttrs(), - /*elidedAttrs=*/{}); + // "operand_segment_sizes" can be deduced, so we don't print it. + printer.printOptionalAttrDict( + insertSliceAsyncOp->getAttrs(), + {insertSliceAsyncOp.operand_segment_sizesAttrName()}); printer << " : "; printer.printStrippedAttrOrType(insertSliceAsyncOp.src().getType()); printer << " -> "; printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType()); } -//===----------------------------------------------------------------------===// -// ExtractSliceOp -//===----------------------------------------------------------------------===// - -mlir::LogicalResult ExtractSliceOp::inferReturnTypes( - ::mlir::MLIRContext *context, llvm::Optional<::mlir::Location> location, - ::mlir::ValueRange operands, mlir::DictionaryAttr attributes, - ::mlir::RegionRange regions, - llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { - auto srcType = operands[0].getType().cast(); - auto encoding = srcType.getEncoding(); - auto srcShape = srcType.getShape(); - auto axis = attributes.get("axis").cast().getInt(); - if (axis < 0 || (size_t)axis > srcShape.size()) - return failure(); - SmallVector dstShape; - for (size_t i = 0; i < srcShape.size(); i++) - if (i != (size_t)axis) - dstShape.push_back(srcShape[i]); - auto returnType = - RankedTensorType::get(dstShape, srcType.getElementType(), encoding); - inferredReturnTypes.assign({returnType}); - return success(); -} - //===----------------------------------------------------------------------===// // DotOperand Encoding //===----------------------------------------------------------------------===// @@ -631,32 +621,6 @@ void TritonGPUDialect::initialize() { addInterfaces(); } -//===----------------------------------------------------------------------===// -// Verification -//===----------------------------------------------------------------------===// - -static LogicalResult verify(InsertSliceAsyncOp op) { - if (!isSharedEncoding(op.getResult())) { - return op.emitOpError( - "insert_slice_async should return a shared memory tensor"); - } - return success(); -} - -static LogicalResult verify(ExtractSliceOp op) { - if (!isSharedEncoding(op.getResult())) { - return op.emitOpError("extract_slice should return a shared memory tensor"); - } - return success(); -} - -static LogicalResult verify(AllocTensorOp op) { - if (!isSharedEncoding(op.getResult())) { - return op.emitOpError("alloc_tensor should return a shared memory tensor"); - } - return success(); -} - #define GET_OP_CLASSES #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc" diff --git a/lib/Dialect/TritonGPU/IR/Traits.cpp b/lib/Dialect/TritonGPU/IR/Traits.cpp new file mode 100644 index 000000000..03253e12c --- /dev/null +++ b/lib/Dialect/TritonGPU/IR/Traits.cpp @@ -0,0 +1,14 @@ +#include "triton/Dialect/TritonGPU/IR/Traits.h" +#include "triton/Analysis/Utility.h" + +mlir::LogicalResult +mlir::OpTrait::impl::verifyResultsAreSharedEncoding(Operation *op) { + if (failed(verifyAtLeastNResults(op, 1))) + return failure(); + + for (auto result : op->getResults()) + if (!isSharedEncoding(result)) + return op->emitOpError() << "requires all results to be shared encoding"; + + return success(); +}; diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index 88988be3a..416c82671 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -111,37 +111,41 @@ public: // cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2)) auto insert_slice = dyn_cast(arg); if (insert_slice) { - auto newType = op->getResult(0).getType(); + auto newType = op->getResult(0).getType().cast(); // Ensure that the new insert_slice op is placed in the same place as the // old insert_slice op. Otherwise, the new insert_slice op may be placed // after the async_wait op, which is not allowed. OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(insert_slice); - auto new_arg = rewriter.create( + auto newArg = rewriter.create( op->getLoc(), newType, insert_slice.dst()); rewriter.replaceOpWithNewOp( - op, newType, insert_slice.src(), new_arg.getResult(), + op, newType, insert_slice.src(), newArg.getResult(), insert_slice.index(), insert_slice.mask(), insert_slice.other(), - insert_slice.cache(), insert_slice.evict(), insert_slice.isVolatile(), - insert_slice.axis()); + insert_slice.cache(), insert_slice.evict(), + insert_slice.isVolatile(), insert_slice.axis()); return mlir::success(); } - // cvt(extract_slice(x), type2) ->extract_slice(cvt(x, type2)) - auto extract_slice = dyn_cast(arg); + // cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2)) + auto extract_slice = dyn_cast(arg); if (extract_slice) { - auto origType = extract_slice.src().getType().cast(); + auto origType = extract_slice.source().getType().cast(); + auto newType = RankedTensorType::get( + origType.getShape(), origType.getElementType(), + op->getResult(0).getType().cast().getEncoding()); + auto resType = op->getResult(0).getType().cast(); // Ensure that the new extract_slice op is placed in the same place as the // old extract_slice op. Otherwise, the new extract_slice op may be placed // after the async_wait op, which is not allowed. OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(extract_slice); - auto newType = RankedTensorType::get( - origType.getShape(), origType.getElementType(), - op->getResult(0).getType().cast().getEncoding()); - auto new_arg = rewriter.create( - op->getLoc(), newType, extract_slice.src()); - rewriter.replaceOpWithNewOp( - op, new_arg.getResult(), extract_slice.index(), extract_slice.axis()); + auto newArg = rewriter.create( + op->getLoc(), newType, extract_slice.source()); + rewriter.replaceOpWithNewOp( + op, resType, newArg.getResult(), extract_slice.offsets(), + extract_slice.sizes(), extract_slice.strides(), + extract_slice.static_offsets(), extract_slice.static_sizes(), + extract_slice.static_strides()); return mlir::success(); } // cvt(type2, x) @@ -198,7 +202,7 @@ static LogicalResult invertEncoding(Attribute targetEncoding, Operation *op, inline bool expensive_to_remat(Operation *op) { if (!op) return true; - if (isa(op)) return true; diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index abbef2efe..8dd03be9e 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -339,14 +339,20 @@ void LoopPipeliner::emitPrologue() { builder.create(iv.getLoc(), 1, 32)); } // for (int stage = 0; stage < numStages - 1; ++stage) + auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); }; + // async.wait & extract_slice builder.create(loads[0].getLoc(), loads.size() * (numStages - 2)); loopIterIdx = builder.create(iv.getLoc(), 0, 32); for (Value loadOp : loads) { - Value extractSlice = builder.create( - loadOp.getLoc(), loadsMapping[loadOp].getType(), - loadStageBuffer[loadOp][numStages - 1], loopIterIdx, /*axis*/ 0); + auto sliceType = loadsMapping[loadOp].getType().cast(); + Value extractSlice = builder.create( + loadOp.getLoc(), sliceType, loadStageBuffer[loadOp][numStages - 1], + SmallVector{intAttr(0), intAttr(0), intAttr(0)}, + SmallVector{intAttr(1), intAttr(sliceType.getShape()[0]), + intAttr(sliceType.getShape()[1])}, + SmallVector{intAttr(1), intAttr(1), intAttr(1)}); loadsExtract[loadOp] = extractSlice; } // bump up loopIterIdx, this is used for getting the correct slice for the @@ -477,6 +483,10 @@ scf::ForOp LoopPipeliner::createNewForOp() { Value extractSliceIndex = builder.create( nextIV.getLoc(), loopIterIdx, builder.create(nextIV.getLoc(), numStages, 32)); + extractSliceIndex = builder.create( + extractSliceIndex.getLoc(), builder.getIndexType(), extractSliceIndex); + + auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); }; for (Operation *op : orderedDeps) { Operation *nextOp = nullptr; @@ -503,9 +513,14 @@ scf::ForOp LoopPipeliner::createNewForOp() { nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(), loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0); nextBuffers.push_back(insertAsyncOp); - nextOp = builder.create( - op->getLoc(), loadsMapping[loadOp].getType(), insertAsyncOp, - extractSliceIndex, /*axis*/ 0); + auto sliceType = loadsMapping[loadOp].getType().cast(); + nextOp = builder.create( + op->getLoc(), sliceType, insertAsyncOp, + SmallVector{extractSliceIndex, intAttr(0), intAttr(0)}, + SmallVector{intAttr(1), + intAttr(sliceType.getShape()[0]), + intAttr(sliceType.getShape()[1])}, + SmallVector{intAttr(1), intAttr(1), intAttr(1)}); extractSlices.push_back(nextOp->getResult(0)); } else nextOp = builder.clone(*op, nextMapping); diff --git a/python/tests/test_elementwise.py b/python/tests/test_elementwise.py index f27990e74..8f0b2682f 100644 --- a/python/tests/test_elementwise.py +++ b/python/tests/test_elementwise.py @@ -137,7 +137,7 @@ def kernel(X0, X1, Y, BLOCK: tl.constexpr): # reference result if expr == "cdiv": - y_ref = (x0 + x1 - 1) // x1 + y_ref = torch.div(x0 + x1 - 1, x1, rounding_mode='trunc') elif expr == "umulhi": y_ref = ((x0.to(torch.int64) * x1) >> 32).to(torch.int32) else: diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 38aeb70a4..b5b7e9cdb 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -25,6 +25,7 @@ from filelock import FileLock import triton import triton._C.libtriton.triton as _triton +from .tools.disasm import extract def str_to_ty(name): @@ -875,8 +876,6 @@ def ttir_to_ttgir(mod, num_warps, num_stages): pm = _triton.ir.pass_manager(mod.context) pm.add_convert_triton_to_tritongpu_pass(num_warps) pm.enable_debug() - # Get error in backend due to wrong conversion in expanding async-related instruction. - # TODO[Superjomn]: Open it when fixed. pm.add_tritongpu_pipeline_pass(num_stages) pm.add_canonicalizer_pass() pm.add_cse_pass() @@ -1396,6 +1395,19 @@ class CompiledKernel: self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args) return + def get_sass(self, fun=None): + if 'sass' in self.asm: + return self.asm['sass'] + fd, path = tempfile.mkstemp() + try: + with open(fd, 'wb') as cubin: + cubin.write(self.asm['cubin']) + self.sass = extract(path, fun) + finally: + os.remove(path) + self.asm['sass'] = self.sass + return self.sass + class CudaUtils(object): diff --git a/test/Analysis/test-alias.mlir b/test/Analysis/test-alias.mlir index d3873e127..fc6e1e289 100644 --- a/test/Analysis/test-alias.mlir +++ b/test/Analysis/test-alias.mlir @@ -18,10 +18,10 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { - %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL> + %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> // CHECK: %4 -> %4 %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> - %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #BL> + %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> // CHECK-NEXT: %6 -> %6 %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> @@ -60,7 +60,7 @@ func @insert_slice_async(%A : !tt.ptr, %i1 : i1) { %tensor = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A> %index = arith.constant 0 : i32 // CHECK: %2 -> %cst_0 - %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<16x16x!tt.ptr, #AL> -> tensor<1x16x16xf16, #A> + %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr, #AL> -> tensor<1x16x16xf16, #A> return } @@ -68,9 +68,9 @@ func @insert_slice_async(%A : !tt.ptr, %i1 : i1) { func @extract_slice(%A : !tt.ptr) { // CHECK: %cst -> %cst %cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A> - %index = arith.constant 0 : i32 + %index = arith.constant 0 : index // CHECK-NEXT: %0 -> %cst - %cst1 = triton_gpu.extract_slice %cst0, %index { axis = 0 : i32 } : tensor<1x16x16xf16, #A> -> tensor<16x16xf16, #A> + %cst1 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A> to tensor<16x16xf16, #A> return } @@ -144,9 +144,9 @@ func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !t // CHECK-NEXT: %0#2 -> %cst,%cst_0 %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { scf.if %i1 { - %index = arith.constant 8 : i32 + %index = arith.constant 8 : index // CHECK-NEXT: %1 -> %cst,%cst_0 - %cst0 = triton_gpu.extract_slice %a_shared, %index { axis = 0 : i32 } : tensor<128x32xf16, #A> -> tensor<32xf16, #A> + %cst0 = tensor.extract_slice %a_shared[%index, 0][1, 32][1, 1] : tensor<128x32xf16, #A> to tensor<32xf16, #A> scf.yield } scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A> diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index 9c7e7fc66..85e50005c 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -178,7 +178,7 @@ func @insert_slice_async(%A : !tt.ptr, %i1 : i1) { // CHECK: offset = 0, size = 512 %tensor = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A> %index = arith.constant 0 : i32 - %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<16x16x!tt.ptr, #AL> -> tensor<1x16x16xf16, #A> + %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr, #AL> -> tensor<1x16x16xf16, #A> return // CHECK-NEXT: size = 512 } @@ -187,8 +187,8 @@ func @insert_slice_async(%A : !tt.ptr, %i1 : i1) { func @extract_slice(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 %cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A> - %index = arith.constant 0 : i32 - %cst1 = triton_gpu.extract_slice %cst0, %index { axis = 0 : i32 } : tensor<1x16x16xf16, #A> -> tensor<16x16xf16, #A> + %index = arith.constant 0 : index + %cst1 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1,1,1] : tensor<1x16x16xf16, #A> to tensor<16x16xf16, #A> return // CHECK-NEXT: size = 512 } @@ -271,8 +271,8 @@ func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr, % %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { scf.if %i1 { - %index = arith.constant 8 : i32 - %cst0 = triton_gpu.extract_slice %a_shared, %index { axis = 0 : i32 } : tensor<128x32xf16, #A> -> tensor<32xf16, #A> + %index = arith.constant 8 : index + %cst0 = tensor.extract_slice %a_shared[%index, 0][1, 32][1, 1] : tensor<128x32xf16, #A> to tensor<32xf16, #A> scf.yield } scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A> diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index a1c3eab76..14d3844a4 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -22,9 +22,9 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { - %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL> + %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> - %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #BL> + %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> // CHECK: Membar 13 %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> @@ -41,7 +41,7 @@ func @raw_single_block(%A : !tt.ptr) { %cst1 = arith.constant dense : tensor<128x32xi1, #AL> %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> - %a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL> + %a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> %a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> // CHECK: Membar 5 %a2 = triton_gpu.convert_layout %a1 : (tensor<128x32xf16, #A>) -> tensor<128x32xf16, #A> @@ -53,7 +53,7 @@ func @war_single_block(%A : !tt.ptr) { %cst1 = arith.constant dense : tensor<128x32xi1, #AL> %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> - %a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL> + %a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> %a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> // CHECK: Membar 5 %a2 = triton_gpu.convert_layout %a1 : (tensor<128x32xf16, #A>) -> tensor<128x32xf16, #AL> @@ -98,8 +98,8 @@ func @alloc() { // CHECK-LABEL: extract_slice func @extract_slice() { %cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A> - %index = arith.constant 0 : i32 - %cst1 = triton_gpu.extract_slice %cst0, %index { axis = 0 : i32 } : tensor<1x16x16xf16, #A> -> tensor<16x16xf16, #A> + %index = arith.constant 0 : index + %cst1 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A> to tensor<16x16xf16, #A> // CHECK: Membar 3 %cst2 = triton_gpu.convert_layout %cst1 : (tensor<16x16xf16, #A>) -> tensor<16x16xf16, #AL> // CHECK-NEXT: Membar 5 @@ -114,7 +114,7 @@ func @insert_slice_async(%A : !tt.ptr, %i1 : i1) { %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> %tensor = triton_gpu.alloc_tensor : tensor<1x16x16xf16, #A> %index = arith.constant 0 : i32 - %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<16x16x!tt.ptr, #AL> -> tensor<1x16x16xf16, #A> + %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr, #AL> -> tensor<1x16x16xf16, #A> %b = tt.cat %a, %a {axis = 0} : (tensor<1x16x16xf16, #A>, tensor<1x16x16xf16, #A>) -> tensor<2x16x16xf16, #A> // CHECK: Membar 7 %c = tt.cat %b, %b {axis = 0} : (tensor<2x16x16xf16, #A>, tensor<2x16x16xf16, #A>) -> tensor<4x16x16xf16, #A> diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 807bd1396..59ec5a927 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -346,18 +346,24 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK: llvm.mlir.global external @global_smem // CHECK-LABEL: basic_extract_slice func @basic_extract_slice() { - // CHECK: %[[BASE0:.*]] = llvm.mlir.addressof @global_smem - // CHECK-NEXT: %[[BASE1:.*]] = llvm.bitcast %[[BASE0]] - // CHECK-NEXT: %[[OFFSET0:.*]] = llvm.mlir.constant - // CHECK-NEXT: %[[OFFSET1:.*]] = llvm.mlir.constant - // CHECK-NEXT: llvm.getelementptr %[[BASE1]][%[[OFFSET1]]] - // CHECK-NEXT: %[[BASE2:.*]] = llvm.bitcast - // CHECK-NEXT: %[[OFFSET2:.*]] = llvm.mlir.constant - // CHECK-NEXT: %[[OFFSET3:.*]] = llvm.mul %[[OFFSET0]], %[[OFFSET2]] - // CHECK-NEXT: llvm.getelementptr %[[BASE2]][%[[OFFSET3]]] - %index = arith.constant 1 : i32 + // CHECK: llvm.mlir.addressof @global_smem + // CHECK: llvm.extractvalue + // CHECK-NEXT: llvm.extractvalue + // CHECK-NEXT: llvm.extractvalue + // CHECK-NEXT: llvm.extractvalue + // CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32 + // CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32 + // CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32 + // CHECK-NEXT: llvm.mul + // CHECK-NEXT: llvm.add + // CHECK-NEXT: llvm.mul + // CHECK-NEXT: llvm.add + // CHECK-NEXT: llvm.mul + // CHECK-NEXT: llvm.add + // CHECK-NEXT: llvm.getelementptr + %index = arith.constant 1 : index %0 = triton_gpu.alloc_tensor : tensor<128x16x32xf32, #shared0> - %1 = triton_gpu.extract_slice %0, %index {axis = 0: i32} : tensor<128x16x32xf32, #shared0> -> tensor<16x32xf32, #shared0> + %1 = tensor.extract_slice %0[%index, 0, 0][1, 16, 32][1, 1, 1] : tensor<128x16x32xf32, #shared0> to tensor<16x32xf32, #shared0> return } } @@ -488,22 +494,38 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { %tensor = triton_gpu.alloc_tensor : tensor<2x32x32xf32, #A> %index = arith.constant 1 : i32 + // CHECK: llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.add // CHECK: llvm.inline_asm // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 + // CHECK: llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.add // CHECK: llvm.inline_asm // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 + // CHECK: llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.add // CHECK: llvm.inline_asm // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 + // CHECK: llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.add // CHECK: llvm.inline_asm // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 + // CHECK: llvm.mlir.constant(16 : i32) : i32 + // CHECK: llvm.add // CHECK: llvm.inline_asm - // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4 + // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 + // CHECK: llvm.mlir.constant(16 : i32) : i32 + // CHECK: llvm.add // CHECK: llvm.inline_asm - // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4 + // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 + // CHECK: llvm.mlir.constant(16 : i32) : i32 + // CHECK: llvm.add // CHECK: llvm.inline_asm - // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4 + // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 + // CHECK: llvm.mlir.constant(16 : i32) : i32 + // CHECK: llvm.add // CHECK: llvm.inline_asm - // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4 + // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 // CHECK: llvm.inline_asm // CHECK-SAME: cp.async.commit_group %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32x!tt.ptr, #AL> -> tensor<2x32x32xf32, #A> diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 64025d1bf..74c4067b6 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -62,7 +62,7 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> { // CHECK-LABEL: transpose func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) { // CHECK-NOT: triton_gpu.convert_layout - // CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, [[row_layout]]> + // CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, [[row_layout]]> // CHECK: [[cvt_val:%.*]] = triton_gpu.convert_layout [[loaded_val]] : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]> // CHECK: tt.store {{.*}}, [[cvt_val]], %cst_1 : tensor<64x64xf32, [[col_layout]]> // CHECK: return @@ -91,7 +91,7 @@ func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt %19 = triton_gpu.convert_layout %10 : (tensor<64x64x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked3> %20 = triton_gpu.convert_layout %cst_0 : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3> %21 = triton_gpu.convert_layout %cst : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked3> - %22 = tt.load %19, %20, %21 {cache = 1 : i32, evict = 1 : i32, isVolatile = false, isOtherUnspecified = false} : tensor<64x64xf32, #blocked3> + %22 = tt.load %19, %20, %21 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked3> %23 = triton_gpu.convert_layout %22 : (tensor<64x64xf32, #blocked3>) -> tensor<64x64xf32, #blocked1> %24 = triton_gpu.convert_layout %18 : (tensor<64x64x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked4> %25 = triton_gpu.convert_layout %23 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked4> @@ -133,7 +133,7 @@ func @loop(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %ar %23 = triton_gpu.convert_layout %arg7 : (tensor<64x64x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked3> %24 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3> %25 = triton_gpu.convert_layout %cst_1 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked3> - %26 = tt.load %23, %24, %25 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, #blocked3> + %26 = tt.load %23, %24, %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked3> %27 = triton_gpu.convert_layout %26 : (tensor<64x64xf32, #blocked3>) -> tensor<64x64xf32, #blocked1> %28 = arith.addf %arg6, %27 : tensor<64x64xf32, #blocked1> %29 = tt.addptr %arg7, %cst_0 : tensor<64x64x!tt.ptr, #blocked1> diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index a1d333cb6..731916e8f 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -20,17 +20,18 @@ // CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] // CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]], %[[CONSTANT_0]] -// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]], %[[CONSTANT_0]] +// CHECK: %[[A0:.*]] = tensor.extract_slice %[[A1BUFFER]][0, 0, 0] +// CHECK: %[[B0:.*]] = tensor.extract_slice %[[B1BUFFER]][0, 0, 0] // CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]] // CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}} // CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]] -// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]] +// CHECK-DAG: %[[EXTRACT_INT:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]] +// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.index_cast %[[EXTRACT_INT]] : i32 to index // CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]] // CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]] // CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]], %[[EXTRACT_IDX]] -// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]], %[[EXTRACT_IDX]] +// CHECK: %[[NEXT_A:.*]] = tensor.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0] +// CHECK: %[[NEXT_B:.*]] = tensor.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0] // CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] // CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]] @@ -76,17 +77,18 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B // CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] // CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]], %[[CONSTANT_0]] -// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]], %[[CONSTANT_0]] +// CHECK: %[[A0:.*]] = tensor.extract_slice %[[A1BUFFER]][0, 0, 0] +// CHECK: %[[B0:.*]] = tensor.extract_slice %[[B1BUFFER]][0, 0, 0] // CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]] // CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}} // CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]] -// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]] +// CHECK-DAG: %[[EXTRACT_INT:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]] +// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.index_cast %[[EXTRACT_INT]] : i32 to index // CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]] // CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]] // CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]], %[[EXTRACT_IDX]] -// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]], %[[EXTRACT_IDX]] +// CHECK: %[[NEXT_A:.*]] = tensor.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0] +// CHECK: %[[NEXT_B:.*]] = tensor.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0] // CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] // CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]] @@ -130,14 +132,15 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr