[OPTIMIZER] Fixed up order of shared layouts (#881)

Philippe Tillet
2022-11-21 06:25:02 +01:00
committed by GitHub
parent 4d64ffb5fe
commit 23f71daa27
6 changed files with 27 additions and 27 deletions

View File

@@ -464,21 +464,6 @@ struct SharedMemoryObject {
}
}
-  // XXX(Keren): a special allocator for 3d tensors. It's a workaround for
-  // now since we don't have a correct way to encoding 3d tensors in the
-  // pipeline pass.
-  SharedMemoryObject(Value base, ArrayRef<int64_t> shape, Location loc,
-                     ConversionPatternRewriter &rewriter)
-      : base(base) {
-    auto stride = 1;
-    for (auto dim : llvm::reverse(shape)) {
-      strides.emplace_back(i32_val(stride));
-      offsets.emplace_back(i32_val(0));
-      stride *= dim;
-    }
-    strides = llvm::to_vector<4>(llvm::reverse(strides));
-  }
SmallVector<Value> getElems() const {
SmallVector<Value> elems;
elems.push_back(base);
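For reference, the constructor removed above derived strides from the shape alone, which bakes in the default order where the last dimension is contiguous. Below is a minimal standalone C++ sketch of that computation, plus an order-aware variant; the second function is an assumption about what the order-taking SharedMemoryObject constructor used later in this commit computes, since its body is not shown in this diff.

#include <cstdint>
#include <vector>

// Shape-only stride computation, mirroring the removed workaround constructor:
// it walks the shape in reverse, so the last dimension always ends up
// contiguous (stride 1), regardless of the layout's actual order.
std::vector<int64_t> stridesFromShape(const std::vector<int64_t> &shape) {
  std::vector<int64_t> strides(shape.size());
  int64_t stride = 1;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    strides[i] = stride;
    stride *= shape[i];
  }
  return strides;
}

// Order-aware variant (assumption, not the actual constructor body):
// order[0] is the fastest-changing axis and gets stride 1; each later axis in
// `order` strides over everything listed before it.
std::vector<int64_t>
stridesFromShapeAndOrder(const std::vector<int64_t> &shape,
                         const std::vector<unsigned> &order) {
  std::vector<int64_t> strides(shape.size());
  int64_t stride = 1;
  for (unsigned axis : order) {
    strides[axis] = stride;
    stride *= shape[axis];
  }
  return strides;
}

For example, shape {4, 32, 64} with order {1, 2, 0} yields strides {2048, 1, 32}, which the shape-only version can never produce.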
@@ -2251,8 +2236,18 @@ struct AllocTensorOpConversion
getTypeConverter()->convertType(resultTy.getElementType());
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
smemBase = bitcast(smemBase, elemPtrTy);
+    auto order = resultTy.getEncoding().cast<SharedEncodingAttr>().getOrder();
+    // workaround for 3D tensors
+    // TODO: We need to modify the pipeline pass to give a proper shared encoding to 3D tensors
+    SmallVector<unsigned> newOrder;
+    if (resultTy.getShape().size() == 3)
+      newOrder = {1 + order[0], 1 + order[1], 0};
+    else
+      newOrder = SmallVector<unsigned>(order.begin(), order.end());
auto smemObj =
-        SharedMemoryObject(smemBase, resultTy.getShape(), loc, rewriter);
+        SharedMemoryObject(smemBase, resultTy.getShape(), newOrder, loc, rewriter);
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
rewriter.replaceOp(op, retVal);
return success();
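The branch above maps a 2D shared order onto the pipeliner's 3D staging buffer. A minimal plain-C++ sketch of that mapping, assuming `order` lists axes from fastest- to slowest-varying (the function name is illustrative, not part of the codebase):

#include <cassert>
#include <vector>

// Dimension 0 of the pipelined buffer is the inserted stage axis, so the
// original 2D order is shifted by one and the stage axis is appended as the
// slowest-varying dimension.
std::vector<unsigned> orderFor3DBuffer(const std::vector<unsigned> &order2d) {
  assert(order2d.size() == 2);
  return {1 + order2d[0], 1 + order2d[1], 0};
}

// Example: a tile with order {1, 0} (last dim contiguous) becomes a 3D buffer
// with order {2, 1, 0}; a {0, 1} tile becomes {1, 2, 0}.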
@@ -2302,6 +2297,10 @@ struct ExtractSliceOpConversion
strideVals.emplace_back(smemObj.strides[i]);
}
}
+    // llvm::outs() << "extract slice\n";
+    // llvm::outs() << strideVals[0] << " " << smemObj.strides[1] << "\n";
+    // llvm::outs() << strideVals[1] << " " << smemObj.strides[2] << "\n";
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
auto resTy = op.getType().dyn_cast<RankedTensorType>();
@@ -3262,8 +3261,8 @@ public:
cMatShape = matShape[order[0]];
sMatShape = matShape[order[1]];
-    cStride = smemStrides[1];
-    sStride = smemStrides[0];
+    cStride = smemStrides[order[0]];
+    sStride = smemStrides[order[1]];
// rule: k must be the fast-changing axis.
needTrans = kOrder != order[0];
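The change above stops hard-coding which stride belongs to which axis. A small illustrative sketch (the names below are stand-ins, not the actual MMA helper): the stride of the contiguous ("c") axis and of the strided ("s") axis are looked up through the layout's order, so a shared layout with order {0, 1} is handled as correctly as one with order {1, 0}.

#include <array>
#include <cstdint>

struct TileStrides {
  int64_t cStride; // stride of the fastest-changing axis
  int64_t sStride; // stride of the slower axis
};

// Pick the per-axis strides through the layout order instead of fixed indices.
TileStrides pickStrides(const std::array<int64_t, 2> &smemStrides,
                        const std::array<unsigned, 2> &order) {
  return {smemStrides[order[0]], smemStrides[order[1]]};
}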
@@ -6202,6 +6201,7 @@ private:
dstType.getShape(), dstType.getElementType(),
triton::gpu::SharedEncodingAttr::get(mod.getContext(), dstDotOp,
srcType.getShape(),
+        getOrder(srcBlocked),
srcType.getElementType()));
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), tmpType, cvtOp.getOperand());

View File

@@ -201,7 +201,9 @@ LogicalResult LoopPipeliner::initialize() {
ty.getShape().end());
bufferShape.insert(bufferShape.begin(), numStages);
auto sharedEnc = ttg::SharedEncodingAttr::get(
-        ty.getContext(), dotOpEnc, ty.getShape(), ty.getElementType());
+        ty.getContext(), dotOpEnc, ty.getShape(),
+        triton::gpu::getOrder(ty.getEncoding()),
+        ty.getElementType());
loadsBufferType[loadOp] = RankedTensorType::get(
bufferShape, ty.getElementType(), sharedEnc);
}
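Together with the conversion-pass change above, both SharedEncodingAttr::get call sites now receive the order of the source layout explicitly instead of relying on a default. A minimal sketch of the idea, using hypothetical plain-C++ stand-ins for the encodings rather than the real MLIR attributes:

#include <vector>

// Hypothetical stand-ins for the blocked (register) and shared encodings.
struct BlockedLayout { std::vector<unsigned> order; };
struct SharedLayout  { std::vector<unsigned> order; /* plus swizzling params */ };

// Before: the shared encoding was built without an explicit order, so it fell
// back to a default even when the source tile was laid out differently.
SharedLayout makeSharedWithDefaultOrder() { return {{1, 0}}; }

// After: the source layout's order is threaded through, so the shared buffer
// and its producers/consumers agree on which axis is contiguous.
SharedLayout makeSharedFromSource(const BlockedLayout &src) { return {src.order}; }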