[Triton-MLIR][BACKEND] Port the mma<v1> conversion (#815)

This PR does the following:

- port the mma<v1> related code, and support dot conversion and
convert_layout[shared->dot_op<mma<v1>>]
- add a lit test for dot v1
This commit is contained in:
Yan Chunwei
2022-11-01 09:42:14 +08:00
committed by GitHub
parent cb1b87a688
commit 031c2ae77b
3 changed files with 690 additions and 92 deletions

View File

@@ -270,10 +270,23 @@ unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
/// Returns the per-thread element count for a tensor of the given 2-D `shape`
/// under this mma layout.
///
/// Mma versions 1 and 2 are handled. Any other version trips the assertion
/// below; when assertions are compiled out the function returns 0 for it.
///
/// \param shape  tensor shape; must have exactly two dimensions.
/// \return       the computed element count (see the per-version formulas).
unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
  // The rank is checked inline so no local is left unused when asserts are
  // compiled out.
  assert(shape.size() == 2 && "Unexpected rank of mma layout");
  assert((getVersion() == 1 || getVersion() == 2) &&
         "Only version 1 and 2 is supported");

  unsigned res = 0;
  if (getVersion() == 1) {
    // Number of 16x16 tiles needed along each dimension, given how many warps
    // are laid out along that dimension.
    unsigned mmasRow = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]);
    unsigned mmasCol = ceil<unsigned>(shape[1], 16 * getWarpsPerCTA()[1]);
    // Each warp-level mma884 will perform a m16xn16xk4 mma, thus get a m16xn16
    // matrix as result.
    // NOTE(review): the divisor 32 is presumably the number of threads in a
    // warp sharing that 16x16 result -- confirm.
    res = mmasRow * mmasCol * (16 * 16 / 32);
  } else if (getVersion() == 2) {
    // Dimension 0 is tiled in steps of 16 and dimension 1 in steps of 8 per
    // warp; each tile contributes a factor of 2 along each dimension.
    // NOTE(review): presumably this mirrors the m16n8 accumulator fragment
    // layout -- confirm against the mma v2 lowering.
    unsigned elemsDim0 = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
    unsigned elemsDim1 = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
    res = elemsDim0 * elemsDim1;
  }
  return res;
}
unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {