[Triton-MLIR][BACKEND] Refine dot conversion (#710)
This PR (1) refines the dot conversion and (2) makes some other minor code refinements.
This commit is contained in:
@@ -20,6 +20,18 @@ template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
|
||||
|
||||
// Integer division of m by n, computed as (m + n - 1) / n.
// For non-negative m and positive n this is the quotient rounded up.
template <typename Int> Int ceil(Int m, Int n) {
  // Keep the promoted type of the sum so the division matches the one-line form.
  const auto biased = m + n - 1;
  return biased / n;
}
|
||||
|
||||
// output[i] = input[order[i]]
|
||||
template <typename T>
|
||||
SmallVector<T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) {
|
||||
size_t rank = order.size();
|
||||
assert(input.size() == rank);
|
||||
SmallVector<T> result(rank);
|
||||
for (auto it : llvm::enumerate(order)) {
|
||||
result[it.index()] = input[it.value()];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_ANALYSIS_UTILITY_H
|
||||
|
@@ -8,6 +8,9 @@
|
||||
#include <string>
|
||||
|
||||
namespace mlir {
|
||||
class ConversionPatternRewriter;
|
||||
class Location;
|
||||
|
||||
namespace triton {
|
||||
using llvm::StringRef;
|
||||
|
||||
@@ -104,6 +107,31 @@ struct PTXBuilder {
|
||||
// Create a list of operands.
|
||||
Operand *newListOperand() { return newOperand(); }
|
||||
|
||||
// Create a list operand holding one operand per (value, constraint) pair in
// `items`, appended in iteration order.
Operand *newListOperand(ArrayRef<std::pair<mlir::Value, std::string>> items) {
  Operand *result = newOperand();
  for (const auto &entry : items)
    result->listAppend(newOperand(entry.first, entry.second));
  return result;
}
|
||||
|
||||
// Create a list operand containing `count` operands, each created from the
// same MLIR value `val` with the ASM constraint string `constraint`.
Operand *newListOperand(unsigned count, mlir::Value val,
                        const std::string &constraint) {
  auto *list = newOperand();
  // The index is unsigned to match `count`; a signed index would be compared
  // against an unsigned bound and could overflow for very large counts.
  for (unsigned i = 0; i < count; ++i) {
    list->listAppend(newOperand(val, constraint));
  }
  return list;
}
|
||||
|
||||
// Create a list operand containing `count` operands, each created from the
// ASM constraint string `constraint` alone (no MLIR value is passed).
Operand *newListOperand(unsigned count, const std::string &constraint) {
  auto *list = newOperand();
  // The index is unsigned to match `count`; a signed index would be compared
  // against an unsigned bound and could overflow for very large counts.
  for (unsigned i = 0; i < count; ++i) {
    list->listAppend(newOperand(constraint));
  }
  return list;
}
|
||||
|
||||
// Create a new operand. It will not be added to the operand list.
|
||||
// @value: the MLIR value bound to this operand.
|
||||
// @constraint: ASM operand constraint, e.g. "=r"
|
||||
@@ -131,6 +159,11 @@ struct PTXBuilder {
|
||||
|
||||
std::string dump() const;
|
||||
|
||||
// Declaration only; the definition is not part of this view.
// NOTE(review): presumably emits the inline-asm call assembled by this builder
// and returns its result as a value of type `resTy` -- confirm against the
// definition.
// Defaults: hasSideEffect = true, isAlignStack = false, attrs = {} (empty).
// Declared const, so it may be called on a const PTXBuilder.
mlir::Value launch(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Type resTy, bool hasSideEffect = true,
|
||||
bool isAlignStack = false,
|
||||
ArrayRef<Attribute> attrs = {}) const;
|
||||
|
||||
private:
|
||||
Operand *newOperand() {
|
||||
argArchive.emplace_back(std::make_unique<Operand>());
|
||||
|
@@ -24,7 +24,7 @@ unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape);
|
||||
|
||||
SmallVector<unsigned> getSizePerThread(Attribute layout);
|
||||
|
||||
unsigned getShapePerCTA(const Attribute &layout, unsigned d);
|
||||
SmallVector<unsigned> getShapePerCTA(const Attribute &layout);
|
||||
|
||||
SmallVector<unsigned> getOrder(const Attribute &layout);
|
||||
|
||||
|
Reference in New Issue
Block a user