[OPTIMIZER] Fixed up order of shared layouts (#881)
This commit is contained in:
@@ -464,21 +464,6 @@ struct SharedMemoryObject {
|
||||
}
|
||||
}
|
||||
|
||||
// XXX(Keren): a special allocator for 3d tensors. It's a workaround for
|
||||
// now since we don't have a correct way to encoding 3d tensors in the
|
||||
// pipeline pass.
|
||||
SharedMemoryObject(Value base, ArrayRef<int64_t> shape, Location loc,
|
||||
ConversionPatternRewriter &rewriter)
|
||||
: base(base) {
|
||||
auto stride = 1;
|
||||
for (auto dim : llvm::reverse(shape)) {
|
||||
strides.emplace_back(i32_val(stride));
|
||||
offsets.emplace_back(i32_val(0));
|
||||
stride *= dim;
|
||||
}
|
||||
strides = llvm::to_vector<4>(llvm::reverse(strides));
|
||||
}
|
||||
|
||||
SmallVector<Value> getElems() const {
|
||||
SmallVector<Value> elems;
|
||||
elems.push_back(base);
|
||||
@@ -2251,8 +2236,18 @@ struct AllocTensorOpConversion
|
||||
getTypeConverter()->convertType(resultTy.getElementType());
|
||||
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
|
||||
smemBase = bitcast(smemBase, elemPtrTy);
|
||||
auto order = resultTy.getEncoding().cast<SharedEncodingAttr>().getOrder();
|
||||
// workaround for 3D tensors
|
||||
// TODO: We need to modify the pipeline pass to give a proper shared encoding to 3D tensors
|
||||
SmallVector<unsigned> newOrder;
|
||||
if (resultTy.getShape().size() == 3)
|
||||
newOrder = {1 + order[0], 1 + order[1], 0};
|
||||
else
|
||||
newOrder = SmallVector<unsigned>(order.begin(), order.end());
|
||||
|
||||
|
||||
auto smemObj =
|
||||
SharedMemoryObject(smemBase, resultTy.getShape(), loc, rewriter);
|
||||
SharedMemoryObject(smemBase, resultTy.getShape(), newOrder, loc, rewriter);
|
||||
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
|
||||
rewriter.replaceOp(op, retVal);
|
||||
return success();
|
||||
@@ -2302,6 +2297,10 @@ struct ExtractSliceOpConversion
|
||||
strideVals.emplace_back(smemObj.strides[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// llvm::outs() << "extract slice\n";
|
||||
// llvm::outs() << strideVals[0] << " " << smemObj.strides[1] << "\n";
|
||||
// llvm::outs() << strideVals[1] << " " << smemObj.strides[2] << "\n";
|
||||
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
|
||||
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
|
||||
auto resTy = op.getType().dyn_cast<RankedTensorType>();
|
||||
@@ -3262,8 +3261,8 @@ public:
|
||||
cMatShape = matShape[order[0]];
|
||||
sMatShape = matShape[order[1]];
|
||||
|
||||
cStride = smemStrides[1];
|
||||
sStride = smemStrides[0];
|
||||
cStride = smemStrides[order[0]];
|
||||
sStride = smemStrides[order[1]];
|
||||
|
||||
// rule: k must be the fast-changing axis.
|
||||
needTrans = kOrder != order[0];
|
||||
@@ -6202,6 +6201,7 @@ private:
|
||||
dstType.getShape(), dstType.getElementType(),
|
||||
triton::gpu::SharedEncodingAttr::get(mod.getContext(), dstDotOp,
|
||||
srcType.getShape(),
|
||||
getOrder(srcBlocked),
|
||||
srcType.getElementType()));
|
||||
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
|
||||
|
@@ -201,7 +201,9 @@ LogicalResult LoopPipeliner::initialize() {
|
||||
ty.getShape().end());
|
||||
bufferShape.insert(bufferShape.begin(), numStages);
|
||||
auto sharedEnc = ttg::SharedEncodingAttr::get(
|
||||
ty.getContext(), dotOpEnc, ty.getShape(), ty.getElementType());
|
||||
ty.getContext(), dotOpEnc, ty.getShape(),
|
||||
triton::gpu::getOrder(ty.getEncoding()),
|
||||
ty.getElementType());
|
||||
loadsBufferType[loadOp] = RankedTensorType::get(
|
||||
bufferShape, ty.getElementType(), sharedEnc);
|
||||
}
|
||||
|
Reference in New Issue
Block a user