[Triton-MLIR][Backend] Minor fix for allocation and backend in handling tt.ptr tensors (#878)

This commit is contained in:
goostavz
2022-11-15 18:08:07 +08:00
committed by GitHub
parent a22ff39017
commit 37f5846280
3 changed files with 37 additions and 2 deletions

View File

@@ -27,6 +27,9 @@ namespace mlir {
//===----------------------------------------------------------------------===//
namespace triton {
// Bitwidth of pointers
constexpr int kPtrBitWidth = 64;
static std::pair<SmallVector<unsigned>, SmallVector<unsigned>>
getCvtOrder(const Attribute &srcLayout, const Attribute &dstLayout) {
auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>();
@@ -193,7 +196,9 @@ private:
auto smemShape = getScratchConfigForCvtLayout(cvtLayout, inVec, outVec);
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
std::multiplies{});
auto bytes = elems * srcTy.getElementTypeBitWidth() / 8;
auto bytes = srcTy.getElementType().isa<triton::PointerType>()?
elems * kPtrBitWidth / 8 :
elems * srcTy.getElementTypeBitWidth() / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
}
}