[BACKEND] add dot conversion (mma version=2) (#672)

LLVM Conversion for Dot op.

Due to the lack of `convert_layout`, the dot op currently supports only
the following combination of operand layouts:

- `$a` in shared layout
- `$b` in shared layout
- `$c` in MMA layout (splat-like only, leaving the generic cases to
  `convert_layout`)

This PR focuses on the `mma.16816`-related logic, leaving the other
cases to follow-up PRs. A sketch of the operand check implied by the
list above follows.
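For concreteness, a minimal sketch of that check, assuming names of my own: the function and the `cIsSplatLike` flag are made up here, though the encoding attributes are the real ones touched by this commit.

#include "mlir/IR/BuiltinTypes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;

// Hedged sketch: true iff the dot operands match the supported
// combination described above. The splat-only restriction on $c is
// passed in as a flag rather than inventing a helper for it.
static bool isSupportedDot(RankedTensorType aTy, RankedTensorType bTy,
                           RankedTensorType cTy, bool cIsSplatLike) {
  return aTy.getEncoding().isa<triton::gpu::SharedEncodingAttr>() &&
         bTy.getEncoding().isa<triton::gpu::SharedEncodingAttr>() &&
         cTy.getEncoding().isa<triton::gpu::MmaEncodingAttr>() &&
         cIsSplatLike; // generic $c cases are deferred to convert_layout
}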

Co-authored-by: Philippe Tillet <phil@openai.com>
Author: Yan Chunwei
Date: 2022-09-23 11:43:54 +08:00 (committed via GitHub)
Parent: 23f424c660
Commit: 922155f1d2
7 changed files with 1033 additions and 116 deletions


@@ -43,6 +43,7 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
       return 0;
     }
   };
+  // blocked -> blocked
   if (srcLayout.isa<BlockedEncodingAttr>() &&
       dstLayout.isa<BlockedEncodingAttr>()) {
     auto srcBlockedLayout = srcLayout.cast<BlockedEncodingAttr>();
@@ -65,6 +66,14 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
     }
     paddedRepShape[outOrd[0]] += pad;
   }
+  // blocked -> shared
+  if (srcLayout.isa<BlockedEncodingAttr>() &&
+      dstLayout.isa<SharedEncodingAttr>()) {
+    auto sharedLayout = dstLayout.cast<SharedEncodingAttr>();
+    for (int v : dstTy.getShape())
+      paddedRepShape.push_back(v);
+  }
   return paddedRepShape;
 }
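Since the blocked -> shared scratch shape above is simply the destination tensor shape, the scratch buffer's element count is the product of those dims. A small self-contained sketch (the function name is assumed, not from the diff):

#include <cstdint>
#include <vector>

// Hedged sketch: element count of the scratch buffer whose shape was
// built by the blocked -> shared branch above.
static int64_t scratchElems(const std::vector<int64_t> &paddedRepShape) {
  int64_t n = 1;
  for (int64_t d : paddedRepShape)
    n *= d;
  return n; // e.g. a 128x64 destination yields 8192 elements
}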
@@ -131,9 +140,8 @@ private:
    auto dstTy = cvtLayout.result().getType().cast<RankedTensorType>();
    auto srcEncoding = srcTy.getEncoding();
    auto dstEncoding = dstTy.getEncoding();
-    if (srcEncoding.isa<SharedEncodingAttr>() ||
-        dstEncoding.isa<SharedEncodingAttr>()) {
-      // Only blocked -> blocked conversion requires for scratch allocation
+    if (srcEncoding.isa<SharedEncodingAttr>()) {
+      // only block->block and block->shared is supported now
      return;
    }
    // ConvertLayoutOp with both input/output non-shared_layout
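The net effect: only conversions whose source is already in shared layout skip scratch allocation, while blocked -> blocked and blocked -> shared both fall through and receive a buffer. A hedged summary predicate, with a name invented for illustration:

#include "mlir/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;
using namespace mlir::triton::gpu;

// Hedged sketch mirroring the early return above: shared sources need no
// scratch; blocked sources converting to blocked or shared layouts do.
static bool needsScratch(Attribute srcEnc, Attribute dstEnc) {
  if (srcEnc.isa<SharedEncodingAttr>())
    return false; // shared -> * : nothing to allocate here
  return srcEnc.isa<BlockedEncodingAttr>() &&
         (dstEnc.isa<BlockedEncodingAttr>() ||
          dstEnc.isa<SharedEncodingAttr>());
}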

(File diff suppressed because it is too large.)


@@ -177,9 +177,9 @@ unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
 }
 
 unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
-  // TODO:
-  assert(0 && "MmaEncodingAttr::getElemsPerThread not implemented");
-  return 0;
+  int threads = product(getWarpsPerCTA());
+  int numElem = product(shape);
+  return numElem / threads;
 }
 
 unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
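As written, the new body divides the total element count by the number of warps per CTA (the local variable is named `threads`). A worked example under assumed sizes:

#include <cassert>

int main() {
  // Assumed values: warpsPerCTA = [2, 2], tensor shape = [16, 16].
  int threads = 2 * 2;   // product(getWarpsPerCTA()) = 4
  int numElem = 16 * 16; // product(shape) = 256
  assert(numElem / threads == 64); // value getElemsPerThread would return
  return 0;
}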


@@ -66,7 +66,7 @@ void extractNVVMMetadata(mlir::ModuleOp module,
     // maxntid
     if (op->hasAttr(NVVMMetadataField::MaxNTid)) {
       auto attr = op->getAttr(NVVMMetadataField::MaxNTid);
-      meta.maxntidx = attr.dyn_cast<IntegerAttr>().getSInt();
+      meta.maxntidx = attr.dyn_cast<IntegerAttr>().getInt();
       hasMetadata = true;
     }
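This small change matters: in MLIR, IntegerAttr::getSInt() asserts that the attribute's integer type is explicitly signed, whereas getInt() accepts the common signless case, which is presumably what the maxntid attribute carries. A standalone illustration, assuming a standard MLIR setup:

#include <cstdint>
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext ctx;
  mlir::Builder b(&ctx);
  // getI32IntegerAttr builds a *signless* i32 attribute.
  mlir::IntegerAttr attr = b.getI32IntegerAttr(128);
  int64_t v = attr.getInt(); // fine for signless integers
  // attr.getSInt() would assert here: i32 is signless, not signed.
  return v == 128 ? 0 : 1;
}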