From 37f5846280cbea1bfbef05b660d0da8e1d30ebb9 Mon Sep 17 00:00:00 2001
From: goostavz <109190422+goostavz@users.noreply.github.com>
Date: Tue, 15 Nov 2022 18:08:07 +0800
Subject: [PATCH] [Triton-MLIR][Backend] Minor fix for allocation and backend
 in handling tt.ptr tensors (#878)

---
 lib/Analysis/Allocation.cpp              |  7 ++++++-
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp  | 11 ++++++++++
 test/Conversion/tritongpu_to_llvm.mlir   | 21 ++++++++++++++++++-
 3 files changed, 37 insertions(+), 2 deletions(-)

diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
index 39659cbf3..97734c04c 100644
--- a/lib/Analysis/Allocation.cpp
+++ b/lib/Analysis/Allocation.cpp
@@ -27,6 +27,9 @@ namespace mlir {
 //===----------------------------------------------------------------------===//
 
 namespace triton {
+// Bitwidth of pointers
+constexpr int kPtrBitWidth = 64;
+
 static std::pair<SmallVector<unsigned>, SmallVector<unsigned>>
 getCvtOrder(const Attribute &srcLayout, const Attribute &dstLayout) {
   auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>();
@@ -193,7 +196,9 @@ private:
       auto smemShape = getScratchConfigForCvtLayout(cvtLayout, inVec, outVec);
       unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
                                        std::multiplies{});
-      auto bytes = elems * srcTy.getElementTypeBitWidth() / 8;
+      auto bytes = srcTy.getElementType().isa<triton::PointerType>()?
+                       elems * kPtrBitWidth / 8 :
+                       elems * srcTy.getElementTypeBitWidth() / 8;
       allocation->addBuffer(op, bytes);
     }
   }
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 90abf61c6..d8cb20ae7 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -93,6 +93,8 @@ void llPrintf(StringRef msg, ValueRange args,
               ConversionPatternRewriter &rewriter);
 
 // Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
+#define inttoptr(...) rewriter.create<LLVM::IntToPtrOp>(loc, __VA_ARGS__)
+#define ptrtoint(...) rewriter.create<LLVM::PtrToIntOp>(loc, __VA_ARGS__)
 #define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
 #define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
 #define urem(...) rewriter.create<LLVM::URemOp>(loc, __VA_ARGS__)
@@ -2945,8 +2947,13 @@ void ConvertLayoutOpConversion::processReplica(
   }
   auto elemTy = type.getElementType();
   bool isInt1 = elemTy.isInteger(1);
+  bool isPtr = elemTy.isa<triton::PointerType>();
+  auto llvmElemTyOrig = getTypeConverter()->convertType(elemTy);
   if (isInt1)
     elemTy = IntegerType::get(elemTy.getContext(), 8);
+  else if (isPtr)
+    elemTy = IntegerType::get(elemTy.getContext(), 64);
+
   auto llvmElemTy = getTypeConverter()->convertType(elemTy);
 
   for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
@@ -2978,6 +2985,8 @@
           auto currVal = vals[elemId + linearCTAId * accumSizePerThread + v];
           if (isInt1)
             currVal = zext(llvmElemTy, currVal);
+          else if (isPtr)
+            currVal = ptrtoint(llvmElemTy, currVal);
           valVec = insert_element(vecTy, valVec, currVal, idx_val(v));
         }
 
@@ -2990,6 +2999,8 @@
             currVal = icmp_ne(currVal,
                               rewriter.create<LLVM::ConstantOp>(
                                   loc, i8_ty, rewriter.getI8IntegerAttr(0)));
+          else if (isPtr)
+            currVal = inttoptr(llvmElemTyOrig, currVal);
           vals[elemId + linearCTAId * accumSizePerThread + v] = currVal;
         }
       }
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index cba6082ec..bcb729ef0 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -742,6 +742,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
 }
 
 // -----
+
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
 module attributes {"triton_gpu.num-warps" = 1 : i32} {
@@ -753,8 +754,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
   }
 }
 
-
 // -----
+
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
 module attributes {"triton_gpu.num-warps" = 1 : i32} {
@@ -766,8 +767,25 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
   }
 }
 
+// -----
+
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
+module attributes {"triton_gpu.num-warps" = 1 : i32} {
+  // CHECK-LABEL: convert_blocked_to_blocked_ptr
+  func @convert_blocked_to_blocked_ptr(%src:tensor<32x!tt.ptr<f32>, #blocked0>) {
+    // CHECK: llvm.ptrtoint
+    // CHECK: llvm.store
+    // CHECK: nvvm.barrier0
+    // CHECK: llvm.inttoptr
+    // CHECK-COUNT-4: llvm.insertvalue
+    %cvt = triton_gpu.convert_layout %src : (tensor<32x!tt.ptr<f32>, #blocked0>) -> tensor<32x!tt.ptr<f32>, #blocked1>
+    return
+  }
+}
 
 // -----
+
 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
 #shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
 #mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
@@ -839,6 +857,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 }
 
 // -----
+
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-LABEL: atomic_add_f32