[BACKEND] add dot conversion (mma version=2) (#672)
LLVM conversion for the Dot op. Due to the lack of `convert_layout`, the dot currently only supports the following combination of operands:

- `$a` in shared layout
- `$b` in shared layout
- `$c` in MMA layout (but only splat-like; the generic cases are left to `convert_layout`)

This PR focuses on the `mma.16816`-related logic, leaving the other cases to a follow-up PR.

Co-authored-by: Philippe Tillet <phil@openai.com>
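For context, `mma.16816` refers to the m16n8k16 `mma.sync` tensor-core instruction shape. The sketch below is not code from this PR; it only illustrates, with hypothetical shapes, how a dot of shape M x N x K decomposes into m16n8k16 fragments:

```cpp
// Illustration only (hypothetical shapes, not code from this PR): how a dot
// of shape M x N x K tiles into m16n8k16 ("mma.16816") fragments.
#include <cstdio>

int main() {
  const int instrM = 16, instrN = 8, instrK = 16; // m16n8k16 instruction shape
  const int M = 128, N = 128, K = 64;             // hypothetical dot shape
  const int repM = M / instrM;                    // 8 tiles along M
  const int repN = N / instrN;                    // 16 tiles along N
  const int repK = K / instrK;                    // 4 tiles along K
  std::printf("m16n8k16 tiles: %d x %d x %d = %d\n", repM, repN, repK,
              repM * repN * repK);
  return 0;
}
```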
```diff
@@ -43,6 +43,7 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
       return 0;
     }
   };
   // blocked -> blocked
   if (srcLayout.isa<BlockedEncodingAttr>() &&
       dstLayout.isa<BlockedEncodingAttr>()) {
     auto srcBlockedLayout = srcLayout.cast<BlockedEncodingAttr>();
```
```diff
@@ -65,6 +66,14 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
     }
     paddedRepShape[outOrd[0]] += pad;
   }
+  // blocked -> shared
+  if (srcLayout.isa<BlockedEncodingAttr>() &&
+      dstLayout.isa<SharedEncodingAttr>()) {
+    auto sharedLayout = dstLayout.cast<SharedEncodingAttr>();
+    for (int v : dstTy.getShape())
+      paddedRepShape.push_back(v);
+  }
+
   return paddedRepShape;
 }
 
```
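Assuming the behavior shown in the hunk above, a blocked -> shared conversion reserves a scratch buffer shaped like the full destination tensor, with no extra padding. A minimal standalone sketch of that sizing, using made-up shapes:

```cpp
// Minimal sketch (hypothetical shapes, not code from this PR): for
// blocked -> shared, the scratch shape is just the destination shape.
#include <cstdio>
#include <vector>

int main() {
  const std::vector<int> dstShape = {128, 64}; // hypothetical dst tensor shape
  std::vector<int> paddedRepShape;
  for (int v : dstShape)                       // mirrors the loop in the hunk
    paddedRepShape.push_back(v);

  long elems = 1;
  for (int v : paddedRepShape)
    elems *= v;
  std::printf("scratch: %ld elements (%ld bytes for fp16)\n", elems, elems * 2);
  return 0;
}
```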
```diff
@@ -131,9 +140,8 @@ private:
     auto dstTy = cvtLayout.result().getType().cast<RankedTensorType>();
     auto srcEncoding = srcTy.getEncoding();
     auto dstEncoding = dstTy.getEncoding();
-    if (srcEncoding.isa<SharedEncodingAttr>() ||
-        dstEncoding.isa<SharedEncodingAttr>()) {
-      // Only blocked -> blocked conversion requires for scratch allocation
+    if (srcEncoding.isa<SharedEncodingAttr>()) {
+      // only block->block and block->shared is supported now
       return;
     }
+    // ConvertLayoutOp with both input/output non-shared_layout
```
(File diff suppressed because it is too large.)
```diff
@@ -177,9 +177,9 @@ unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
 }
 
 unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
-  // TODO:
-  assert(0 && "MmaEncodingAttr::getElemsPerThread not implemented");
-  return 0;
+  int threads = product(getWarpsPerCTA());
+  int numElem = product(shape);
+  return numElem / threads;
 }
 
 unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
```
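A worked example of the arithmetic added above, with made-up numbers (a 128x128 tensor and a 2x2 `warpsPerCTA`):

```cpp
// Worked example (hypothetical numbers) of the computation added above:
// elemsPerThread = product(shape) / product(getWarpsPerCTA()).
#include <cstdio>

int main() {
  const long warpsPerCTA[2] = {2, 2}; // hypothetical warp layout
  const long shape[2] = {128, 128};   // hypothetical tensor shape
  const long threads = warpsPerCTA[0] * warpsPerCTA[1]; // product(getWarpsPerCTA())
  const long numElem = shape[0] * shape[1];             // product(shape)
  std::printf("elemsPerThread = %ld\n", numElem / threads); // 16384 / 4 = 4096
  return 0;
}
```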
```diff
@@ -66,7 +66,7 @@ void extractNVVMMetadata(mlir::ModuleOp module,
   // maxntid
   if (op->hasAttr(NVVMMetadataField::MaxNTid)) {
     auto attr = op->getAttr(NVVMMetadataField::MaxNTid);
-    meta.maxntidx = attr.dyn_cast<IntegerAttr>().getSInt();
+    meta.maxntidx = attr.dyn_cast<IntegerAttr>().getInt();
     hasMetadata = true;
   }
 
```
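The one-line change above swaps `getSInt()` for `getInt()`. My reading (not stated in the PR): the `maxntid` attribute is built as a signless `IntegerAttr`, and MLIR's `IntegerAttr::getSInt()` asserts that the attribute's type is explicitly signed, while `getInt()` is the accessor for signless integers. A standalone sketch of that distinction:

```cpp
// Sketch of the IntegerAttr accessor distinction (assumption: the maxntid
// attribute is created as a signless integer, as MLIR builders do by default).
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include <cstdio>

int main() {
  mlir::MLIRContext ctx;
  mlir::Builder b(&ctx);
  mlir::IntegerAttr attr = b.getI32IntegerAttr(128); // signless i32 attribute
  // getInt() works for signless integer attributes and returns an int64_t.
  std::printf("maxntid = %lld\n", static_cast<long long>(attr.getInt()));
  // attr.getSInt() would assert here, because i32 is signless, not signed (si32).
  return 0;
}
```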