[Triton-MLIR][BACKEND] Refine dot conversion (#710)

This PR does the following:

1. Refine the dot conversion
2. Make some other minor code refinements
This commit is contained in:
Yan Chunwei
2022-09-27 14:38:34 +08:00
committed by GitHub
parent 61b61755e5
commit 3a84278530
11 changed files with 439 additions and 291 deletions

View File

@@ -20,6 +20,18 @@ template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
// Integer division of `m` by `n`, computed as (m + n - 1) / n. For
// non-negative `m` and positive `n` this rounds the quotient up.
template <typename Int> Int ceil(Int m, Int n) {
  const auto biasedNumerator = m + n - 1;
  return biasedNumerator / n;
}
// Permute `input` according to `order`: output[i] = input[order[i]].
// `input` and `order` must have the same number of elements (asserted).
template <typename T>
SmallVector<T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) {
  const size_t numDims = order.size();
  assert(input.size() == numDims);
  SmallVector<T> permuted(numDims);
  for (size_t dst = 0; dst < numDims; ++dst)
    permuted[dst] = input[order[dst]];
  return permuted;
}
} // namespace mlir
#endif // TRITON_ANALYSIS_UTILITY_H

View File

@@ -8,6 +8,9 @@
#include <string>
namespace mlir {
class ConversionPatternRewriter;
class Location;
namespace triton {
using llvm::StringRef;
@@ -104,6 +107,31 @@ struct PTXBuilder {
// Create an empty list operand; entries can be added to it with listAppend().
Operand *newListOperand() {
  Operand *emptyList = newOperand();
  return emptyList;
}
// Create a list operand from (MLIR value, ASM constraint) pairs. One operand
// is created per pair and appended to the list in the order given.
Operand *newListOperand(ArrayRef<std::pair<mlir::Value, std::string>> items) {
  auto *operandList = newOperand();
  for (auto &entry : items) {
    auto &value = entry.first;
    auto &constraint = entry.second;
    operandList->listAppend(newOperand(value, constraint));
  }
  return operandList;
}
// Create a list operand containing `count` operands, each created from the
// same MLIR value `val` and the same ASM operand `constraint`.
Operand *newListOperand(unsigned count, mlir::Value val,
                        const std::string &constraint) {
  auto *list = newOperand();
  // The index is unsigned to match `count`: this avoids a signed/unsigned
  // comparison and signed overflow of the index for very large counts.
  for (unsigned i = 0; i < count; ++i) {
    list->listAppend(newOperand(val, constraint));
  }
  return list;
}
// Create a list operand containing `count` operands, each created from the
// ASM operand `constraint` alone (no MLIR value is passed).
Operand *newListOperand(unsigned count, const std::string &constraint) {
  auto *list = newOperand();
  // The index is unsigned to match `count`: this avoids a signed/unsigned
  // comparison and signed overflow of the index for very large counts.
  for (unsigned i = 0; i < count; ++i) {
    list->listAppend(newOperand(constraint));
  }
  return list;
}
// Create a new operand. It will not be added to the operand list.
// @value: the MLIR value bound to this operand.
// @constraint: ASM operand constraint, e.g. "=r"
@@ -131,6 +159,11 @@ struct PTXBuilder {
// Returns a string describing this builder; does not modify it (const).
// NOTE(review): presumably the assembled PTX text — confirm against the
// definition, which is not visible here.
std::string dump() const;
// Produces an mlir::Value from this builder without modifying it (const).
// NOTE(review): presumably emits the built PTX as an inline-asm operation via
// `rewriter` at `loc`, with `resTy` as the result type — confirm against the
// definition, which is not visible here.
// @rewriter: the conversion pattern rewriter passed through by the caller.
// @loc: the source location passed through by the caller.
// @resTy: a Type; presumably the type of the returned value — TODO confirm.
// @hasSideEffect: defaults to true.
// @isAlignStack: defaults to false.
// @attrs: optional attributes; defaults to an empty list.
mlir::Value launch(ConversionPatternRewriter &rewriter, Location loc,
Type resTy, bool hasSideEffect = true,
bool isAlignStack = false,
ArrayRef<Attribute> attrs = {}) const;
private:
Operand *newOperand() {
argArchive.emplace_back(std::make_unique<Operand>());

View File

@@ -24,7 +24,7 @@ unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape);
SmallVector<unsigned> getSizePerThread(Attribute layout);
unsigned getShapePerCTA(const Attribute &layout, unsigned d);
SmallVector<unsigned> getShapePerCTA(const Attribute &layout);
SmallVector<unsigned> getOrder(const Attribute &layout);