[Triton-MLIR][Backend] Minor fix for allocation and backend in handling tt.ptr tensors (#878)

This commit is contained in:
goostavz
2022-11-15 18:08:07 +08:00
committed by GitHub
parent a22ff39017
commit 37f5846280
3 changed files with 37 additions and 2 deletions

View File

@@ -27,6 +27,9 @@ namespace mlir {
//===----------------------------------------------------------------------===//
namespace triton {
// Bit width assumed for pointer-typed elements (triton::PointerType).
// Used in place of getElementTypeBitWidth() when converting an element count
// into a scratch-buffer size in bytes (kPtrBitWidth / 8 bytes per element).
// NOTE(review): presumably matches the target's 64-bit address width — confirm.
constexpr int kPtrBitWidth = 64;
static std::pair<SmallVector<unsigned>, SmallVector<unsigned>>
getCvtOrder(const Attribute &srcLayout, const Attribute &dstLayout) {
auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>();
@@ -193,7 +196,9 @@ private:
auto smemShape = getScratchConfigForCvtLayout(cvtLayout, inVec, outVec);
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
std::multiplies{});
auto bytes = elems * srcTy.getElementTypeBitWidth() / 8;
auto bytes = srcTy.getElementType().isa<triton::PointerType>()?
elems * kPtrBitWidth / 8 :
elems * srcTy.getElementTypeBitWidth() / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
}
}

View File

@@ -93,6 +93,8 @@ void llPrintf(StringRef msg, ValueRange args,
ConversionPatternRewriter &rewriter);
// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive.
// Every macro expands to `rewriter.create<Op>(loc, <args>)`, so identifiers
// named `rewriter` and `loc` must be in scope at each point of use; the macro
// arguments are forwarded unchanged as the remaining builder arguments.
// inttoptr -> LLVM::IntToPtrOp
#define inttoptr(...) rewriter.create<LLVM::IntToPtrOp>(loc, __VA_ARGS__)
// ptrtoint -> LLVM::PtrToIntOp
#define ptrtoint(...) rewriter.create<LLVM::PtrToIntOp>(loc, __VA_ARGS__)
// zext -> LLVM::ZExtOp
#define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
// udiv -> LLVM::UDivOp
#define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
// urem -> LLVM::URemOp
#define urem(...) rewriter.create<LLVM::URemOp>(loc, __VA_ARGS__)
@@ -2945,8 +2947,13 @@ void ConvertLayoutOpConversion::processReplica(
}
auto elemTy = type.getElementType();
bool isInt1 = elemTy.isInteger(1);
bool isPtr = elemTy.isa<triton::PointerType>();
auto llvmElemTyOrig = getTypeConverter()->convertType(elemTy);
if (isInt1)
elemTy = IntegerType::get(elemTy.getContext(), 8);
else if (isPtr)
elemTy = IntegerType::get(elemTy.getContext(), 64);
auto llvmElemTy = getTypeConverter()->convertType(elemTy);
for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
@@ -2978,6 +2985,8 @@ void ConvertLayoutOpConversion::processReplica(
auto currVal = vals[elemId + linearCTAId * accumSizePerThread + v];
if (isInt1)
currVal = zext(llvmElemTy, currVal);
else if (isPtr)
currVal = ptrtoint(llvmElemTy, currVal);
valVec = insert_element(vecTy, valVec, currVal, idx_val(v));
}
@@ -2990,6 +2999,8 @@ void ConvertLayoutOpConversion::processReplica(
currVal =
icmp_ne(currVal, rewriter.create<LLVM::ConstantOp>(
loc, i8_ty, rewriter.getI8IntegerAttr(0)));
else if (isPtr)
currVal = inttoptr(llvmElemTyOrig, currVal);
vals[elemId + linearCTAId * accumSizePerThread + v] = currVal;
}
}