[Triton-MLIR] tt.dot
operands now must have DotOperand layout; also added prefetch pass prototype (#712)
Co-authored-by: Jokeren <kerenzhou@openai.com> Co-authored-by: Phil Tillet <phil@openai.com> Co-authored-by: Superjomn <yanchunwei@outlook.com>
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
#include <numeric>
|
||||
|
||||
using ::mlir::triton::gpu::BlockedEncodingAttr;
|
||||
using ::mlir::triton::gpu::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu::getOrder;
|
||||
using ::mlir::triton::gpu::getShapePerCTA;
|
||||
using ::mlir::triton::gpu::getSizePerThread;
|
||||
@@ -26,6 +27,26 @@ namespace mlir {
|
||||
//===----------------------------------------------------------------------===//
|
||||
namespace triton {
|
||||
|
||||
/// Picks the dimension orders used on the source and destination side of a
/// layout conversion.
///
/// An mma or dot-operand layout does not have an order of its own, so for that
/// side the order is borrowed from the layout on the other side of the
/// conversion.
///
/// \param srcLayout encoding of the value being converted.
/// \param dstLayout encoding the value is converted to.
/// \returns the pair {inOrd, outOrd}: the order to use for the source side and
///          the order to use for the destination side, in that sequence.
///
/// Asserts that the conversion is not mma -> mma.
/// NOTE(review): if both sides are mma/dot-operand layouts (e.g. mma -> dot
/// operand), getOrder() ends up being called on a layout described above as
/// having no order -- confirm callers never reach that combination.
static std::pair<SmallVector<unsigned>, SmallVector<unsigned>>
getCvtOrder(const Attribute &srcLayout, const Attribute &dstLayout) {
  const bool srcIsMma = srcLayout.isa<MmaEncodingAttr>();
  const bool dstIsMma = dstLayout.isa<MmaEncodingAttr>();
  assert(!(srcIsMma && dstIsMma) &&
         "Unexpected mma -> mma layout conversion");

  // True when the corresponding side cannot provide an order itself.
  const bool srcLacksOrder =
      srcIsMma || srcLayout.isa<DotOperandEncodingAttr>();
  const bool dstLacksOrder =
      dstIsMma || dstLayout.isa<DotOperandEncodingAttr>();

  // mma or dot layout does not have an order, so the order depends on the
  // layout of the other operand.
  auto inOrd = srcLacksOrder ? getOrder(dstLayout) : getOrder(srcLayout);
  auto outOrd = dstLacksOrder ? getOrder(srcLayout) : getOrder(dstLayout);

  return {inOrd, outOrd};
}
|
||||
|
||||
SmallVector<unsigned>
|
||||
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
unsigned &outVec) {
|
||||
@@ -35,16 +56,7 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
Attribute dstLayout = dstTy.getEncoding();
|
||||
assert(srcLayout && dstLayout &&
|
||||
"Unexpect layout in getScratchConfigForCvtLayout()");
|
||||
unsigned rank = dstTy.getRank();
|
||||
SmallVector<unsigned> paddedRepShape(rank);
|
||||
auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>();
|
||||
auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
|
||||
auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>();
|
||||
auto dstMmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>();
|
||||
assert(!(srcMmaLayout && dstMmaLayout) &&
|
||||
"Unexpected mma -> mma layout conversion");
|
||||
auto inOrd = srcMmaLayout ? getOrder(dstLayout) : getOrder(srcLayout);
|
||||
auto outOrd = dstMmaLayout ? getOrder(srcLayout) : getOrder(dstLayout);
|
||||
auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
|
||||
unsigned srcContigPerThread = getSizePerThread(srcLayout)[inOrd[0]];
|
||||
unsigned dstContigPerThread = getSizePerThread(dstLayout)[outOrd[0]];
|
||||
// TODO: Fix the legacy issue that outOrd[0] == 0 always means
|
||||
@@ -55,6 +67,8 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
auto srcShapePerCTA = getShapePerCTA(srcLayout);
|
||||
auto dstShapePerCTA = getShapePerCTA(dstLayout);
|
||||
|
||||
unsigned rank = dstTy.getRank();
|
||||
SmallVector<unsigned> paddedRepShape(rank);
|
||||
unsigned pad = std::max(inVec, outVec);
|
||||
for (unsigned d = 0; d < rank; ++d) {
|
||||
paddedRepShape[d] =
|
||||
@@ -143,8 +157,6 @@ private:
|
||||
|
||||
/// Initializes temporary shared memory for a given operation.
|
||||
void getScratchValueSize(Operation *op) {
|
||||
// TODO(Keren): Add atomic ops
|
||||
// TODO(Keren): Add convert ops
|
||||
if (auto reduceOp = dyn_cast<triton::ReduceOp>(op)) {
|
||||
// TODO(Keren): Reduce with index is not supported yet.
|
||||
auto value = op->getOperand(0);
|
||||
@@ -167,7 +179,7 @@ private:
|
||||
auto dstEncoding = dstTy.getEncoding();
|
||||
if (srcEncoding.isa<SharedEncodingAttr>() ||
|
||||
dstEncoding.isa<SharedEncodingAttr>()) {
|
||||
// Only blocked -> blocked conversion requires for scratch allocation
|
||||
// Conversions from/to shared memory do not need scratch memory.
|
||||
return;
|
||||
}
|
||||
// ConvertLayoutOp with both input/output non-shared_layout
|
||||
|
Reference in New Issue
Block a user