[Triton-MLIR][BACKEND] Port the mma<v1> conversion (#815)
This PR: (1) ports the mma<v1>-related code and adds support for dot conversion and convert_layout[shared->dot_op<mma<v1>>]; (2) adds a lit test for dot v1.
This commit is contained in:
@@ -270,10 +270,23 @@ unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
|
||||
unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
|
||||
size_t rank = shape.size();
|
||||
assert(rank == 2 && "Unexpected rank of mma layout");
|
||||
assert(getVersion() == 2 && "mmaLayout version = 1 is not implemented yet");
|
||||
unsigned elemsCol = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
|
||||
unsigned elemsRow = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
|
||||
return elemsCol * elemsRow;
|
||||
assert((getVersion() == 1 || getVersion() == 2) &&
|
||||
"Only version 1 and 2 is supported");
|
||||
|
||||
int res = 0;
|
||||
if (getVersion() == 1) {
|
||||
unsigned mmasRow = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]);
|
||||
unsigned mmasCol = ceil<unsigned>(shape[1], 16 * getWarpsPerCTA()[1]);
|
||||
// Each warp-level mma884 will perform a m16xn16xk4 mma, thus get a m16xn16
|
||||
// matrix as result.
|
||||
res = mmasRow * mmasCol * (16 * 16 / 32);
|
||||
} else if (getVersion() == 2) {
|
||||
unsigned elemsCol = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
|
||||
unsigned elemsRow = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
|
||||
res = elemsCol * elemsRow;
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
|
||||
|
Reference in New Issue
Block a user