[Triton-MLIR][Backend] Minor fix for allocation and backend in handling tt.ptr tensors (#878)
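
Scratch-buffer allocation for layout conversions previously sized buffers with srcTy.getElementTypeBitWidth(), which is not meaningful when the element type is tt.ptr. Allocation now uses a fixed 64-bit pointer width (kPtrBitWidth) for pointer elements, and the LLVM lowering round-trips pointer values through i64 (ptrtoint before staging in shared memory, inttoptr after reading back), mirroring the existing i1-to-i8 widening. A lit test covers a blocked-to-blocked convert_layout on tensor<32x!tt.ptr<f32>>.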
@@ -27,6 +27,9 @@ namespace mlir {
 //===----------------------------------------------------------------------===//
 namespace triton {
 
+// Bitwidth of pointers
+constexpr int kPtrBitWidth = 64;
+
 static std::pair<SmallVector<unsigned>, SmallVector<unsigned>>
 getCvtOrder(const Attribute &srcLayout, const Attribute &dstLayout) {
   auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>();
@@ -193,7 +196,9 @@ private:
     auto smemShape = getScratchConfigForCvtLayout(cvtLayout, inVec, outVec);
     unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
                                      std::multiplies{});
-    auto bytes = elems * srcTy.getElementTypeBitWidth() / 8;
+    auto bytes = srcTy.getElementType().isa<triton::PointerType>()?
+                     elems * kPtrBitWidth / 8 :
+                     elems * srcTy.getElementTypeBitWidth() / 8;
     allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
   }
 }
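
Note: as a reading aid, here is a minimal standalone C++ model of the sizing rule introduced above. scratchBytes and its parameters are hypothetical names for illustration, not the repository's API; the real code queries the MLIR types directly.

#include <cassert>
#include <functional>
#include <numeric>
#include <vector>

// Fixed width used for tt.ptr elements, mirroring kPtrBitWidth above.
constexpr unsigned kPtrBitWidth = 64;

// Illustrative stand-in for the allocation logic: whether the element
// is a pointer is passed in directly here.
unsigned scratchBytes(const std::vector<unsigned> &smemShape,
                      unsigned elemBitWidth, bool elemIsPtr) {
  unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1u,
                                   std::multiplies<unsigned>{});
  // Pointer tensors are sized by the pointer width; asking the element
  // type for a bit width is only valid for non-pointer elements.
  return elemIsPtr ? elems * kPtrBitWidth / 8 : elems * elemBitWidth / 8;
}

int main() {
  // 32 x !tt.ptr<f32>: 32 elements x 64 bits = 256 bytes of scratch
  // (the bit width argument is ignored for pointer elements).
  assert(scratchBytes({32}, 32, /*elemIsPtr=*/true) == 256);
  // 32 x f32: 32 elements x 32 bits = 128 bytes.
  assert(scratchBytes({32}, 32, /*elemIsPtr=*/false) == 128);
  return 0;
}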
@@ -93,6 +93,8 @@ void llPrintf(StringRef msg, ValueRange args,
               ConversionPatternRewriter &rewriter);
 
 // Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
+#define inttoptr(...) rewriter.create<LLVM::IntToPtrOp>(loc, __VA_ARGS__)
+#define ptrtoint(...) rewriter.create<LLVM::PtrToIntOp>(loc, __VA_ARGS__)
 #define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
 #define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
 #define urem(...) rewriter.create<LLVM::URemOp>(loc, __VA_ARGS__)
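
The two new shorthands follow the exact pattern of the existing zext/udiv/urem macros: they assume a rewriter and a loc are in scope at the use site, so ptrtoint(llvmElemTy, currVal) expands to rewriter.create<LLVM::PtrToIntOp>(loc, llvmElemTy, currVal).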
@@ -2945,8 +2947,13 @@ void ConvertLayoutOpConversion::processReplica(
   }
   auto elemTy = type.getElementType();
   bool isInt1 = elemTy.isInteger(1);
+  bool isPtr = elemTy.isa<triton::PointerType>();
+  auto llvmElemTyOrig = getTypeConverter()->convertType(elemTy);
   if (isInt1)
     elemTy = IntegerType::get(elemTy.getContext(), 8);
+  else if (isPtr)
+    elemTy = IntegerType::get(elemTy.getContext(), 64);
+
   auto llvmElemTy = getTypeConverter()->convertType(elemTy);
 
   for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
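
Note that llvmElemTyOrig is converted from elemTy before the remapping (i1 to i8, tt.ptr to i64), so the read-back path below can restore the original LLVM pointer type; llvmElemTy reflects the remapped storage type used in shared memory.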
@@ -2978,6 +2985,8 @@ void ConvertLayoutOpConversion::processReplica(
         auto currVal = vals[elemId + linearCTAId * accumSizePerThread + v];
         if (isInt1)
           currVal = zext(llvmElemTy, currVal);
+        else if (isPtr)
+          currVal = ptrtoint(llvmElemTy, currVal);
 
         valVec = insert_element(vecTy, valVec, currVal, idx_val(v));
       }
@@ -2990,6 +2999,8 @@ void ConvertLayoutOpConversion::processReplica(
           currVal =
               icmp_ne(currVal, rewriter.create<LLVM::ConstantOp>(
                                    loc, i8_ty, rewriter.getI8IntegerAttr(0)));
+        else if (isPtr)
+          currVal = inttoptr(llvmElemTyOrig, currVal);
         vals[elemId + linearCTAId * accumSizePerThread + v] = currVal;
       }
     }
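
Taken together, the three hunks give pointer values the same staging path through the conversion scratch buffer that i1 values already had, just via i64 instead of i8. Below is a minimal host-side C++ model of the round trip, with plain casts standing in for the generated ptrtoint/inttoptr; this is an illustration assuming a 64-bit platform, not the emitted IR.

#include <cassert>
#include <cstdint>

int main() {
  float value = 1.0f;
  float *p = &value;
  // ptrtoint: stage the pointer as a 64-bit integer (llvmElemTy) before
  // it is written to the shared-memory scratch buffer.
  uint64_t staged = reinterpret_cast<uint64_t>(p);
  // ... store to shared memory, barrier, load back (elided) ...
  // inttoptr: recover the original pointer type (llvmElemTyOrig).
  float *q = reinterpret_cast<float *>(staged);
  assert(q == p && *q == 1.0f);
  return 0;
}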
@@ -742,6 +742,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
 }
 
 // -----
+
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
 module attributes {"triton_gpu.num-warps" = 1 : i32} {
@@ -753,8 +754,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
   }
 }
 
-
 // -----
+
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
 module attributes {"triton_gpu.num-warps" = 1 : i32} {
@@ -766,8 +767,25 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
   }
 }
 
+// -----
+
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
+module attributes {"triton_gpu.num-warps" = 1 : i32} {
+  // CHECK-LABEL: convert_blocked_to_blocked_ptr
+  func @convert_blocked_to_blocked_ptr(%src:tensor<32x!tt.ptr<f32>, #blocked0>) {
+    // CHECK: llvm.ptrtoint
+    // CHECK: llvm.store
+    // CHECK: nvvm.barrier0
+    // CHECK: llvm.inttoptr
+    // CHECK-COUNT-4: llvm.insertvalue
+    %cvt = triton_gpu.convert_layout %src : (tensor<32x!tt.ptr<f32>, #blocked0>) -> tensor<32x!tt.ptr<f32>, #blocked1>
+    return
+  }
+}
+
 // -----
 
 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
 #shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
 #mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
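
The test added above pins down the lowering order for a pointer-element conversion: llvm.ptrtoint before the store to shared memory, nvvm.barrier0 between the write and read phases, llvm.inttoptr after the load, and four llvm.insertvalue ops packing the results, matching sizePerThread = [4] in #blocked1.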
@@ -839,6 +857,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 }
 
 // -----
+
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
 // CHECK-LABEL: atomic_add_f32