[Triton-MLIR][BACKEND] Refine dot conversion (#710)

This PR does

1. Refine the dot conversion
2. some other tiny code refinement
This commit is contained in:
Yan Chunwei
2022-09-27 14:38:34 +08:00
committed by GitHub
parent 61b61755e5
commit 3a84278530
11 changed files with 439 additions and 291 deletions

View File

@@ -72,26 +72,24 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
}
}
unsigned getShapePerCTA(const Attribute &layout, unsigned d) {
SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
SmallVector<unsigned> shape;
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return blockedLayout.getSizePerThread()[d] *
blockedLayout.getThreadsPerWarp()[d] *
blockedLayout.getWarpsPerCTA()[d];
for (int d = 0, n = blockedLayout.getOrder().size(); d < n; ++d)
shape.push_back(blockedLayout.getSizePerThread()[d] *
blockedLayout.getThreadsPerWarp()[d] *
blockedLayout.getWarpsPerCTA()[d]);
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
assert(mmaLayout.getVersion() == 2 &&
"mmaLayout version = 1 is not implemented yet");
assert(d < 2 && "Unexpected usage of getShapePerCTA");
if (d == 0) {
return 16 * mmaLayout.getWarpsPerCTA()[0];
} else {
// d == 1
return 8 * mmaLayout.getWarpsPerCTA()[1];
}
return {16 * mmaLayout.getWarpsPerCTA()[0],
8 * mmaLayout.getWarpsPerCTA()[1]};
} else {
assert(0 && "Unimplemented usage of getShapePerCTA");
return 0;
}
};
return shape;
}
SmallVector<unsigned> getOrder(const Attribute &layout) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
@@ -106,7 +104,7 @@ SmallVector<unsigned> getOrder(const Attribute &layout) {
assert(0 && "Unimplemented usage of getOrder");
return {};
}
};
}
} // namespace gpu
} // namespace triton
@@ -180,16 +178,17 @@ SliceEncodingAttr BlockedEncodingAttr::squeeze(int axis) {
unsigned BlockedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
size_t rank = shape.size();
assert(rank == getSizePerThread().size() &&
auto sizePerThread = getSizePerThread();
auto warpsPerCTA = getWarpsPerCTA();
auto threadsPerWarp = getThreadsPerWarp();
assert(rank == sizePerThread.size() &&
"unexpected rank in BlockedEncodingAttr::getElemsPerThread");
SmallVector<unsigned> elemsPerThreadPerDim(rank);
SmallVector<unsigned> elemsPerThread(rank);
for (size_t i = 0; i < rank; ++i) {
unsigned t =
getSizePerThread()[i] * getThreadsPerWarp()[i] * getWarpsPerCTA()[i];
elemsPerThreadPerDim[i] =
ceil<unsigned>(shape[i], t) * getSizePerThread()[i];
unsigned t = sizePerThread[i] * threadsPerWarp[i] * warpsPerCTA[i];
elemsPerThread[i] = ceil<unsigned>(shape[i], t) * sizePerThread[i];
}
return product<unsigned>(elemsPerThreadPerDim);
return product<unsigned>(elemsPerThread);
}
unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
@@ -216,11 +215,9 @@ unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
}
unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
size_t rank = shape.size();
assert(rank == 2 && "Unexpected rank of mma layout");
unsigned elemsCol = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
unsigned elemsRow = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
return elemsCol * elemsRow;
int threads = product(getWarpsPerCTA());
int numElem = product(shape);
return numElem / threads;
}
unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {