[Triton-MLIR][BACKEND] Refine dot conversion (#710)
This PR (1) refines the dot conversion and (2) makes some other small code refinements.
This commit is contained in:
@@ -20,6 +20,18 @@ template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
|
|||||||
|
|
||||||
// Integer division of m by n, rounded up (for m >= 0 and n > 0).
template <typename Int> Int ceil(Int m, Int n) {
  const auto numerator = m + n - 1;
  return numerator / n;
}
|
// Integer division of m by n, rounded up (for m >= 0 and n > 0).
template <typename Int> Int ceil(Int m, Int n) {
  const auto numerator = m + n - 1;
  return numerator / n;
}
|
||||||
|
|
||||||
|
// output[i] = input[order[i]]
|
||||||
|
template <typename T>
|
||||||
|
SmallVector<T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) {
|
||||||
|
size_t rank = order.size();
|
||||||
|
assert(input.size() == rank);
|
||||||
|
SmallVector<T> result(rank);
|
||||||
|
for (auto it : llvm::enumerate(order)) {
|
||||||
|
result[it.index()] = input[it.value()];
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
#endif // TRITON_ANALYSIS_UTILITY_H
|
#endif // TRITON_ANALYSIS_UTILITY_H
|
||||||
|
@@ -8,6 +8,9 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
class ConversionPatternRewriter;
|
||||||
|
class Location;
|
||||||
|
|
||||||
namespace triton {
|
namespace triton {
|
||||||
using llvm::StringRef;
|
using llvm::StringRef;
|
||||||
|
|
||||||
@@ -104,6 +107,31 @@ struct PTXBuilder {
|
|||||||
// Create a list of operands.
|
// Create a list of operands.
|
||||||
Operand *newListOperand() { return newOperand(); }
|
Operand *newListOperand() { return newOperand(); }
|
||||||
|
|
||||||
|
// Create a list operand and append one new operand per entry of `items`,
// in order. Each entry supplies the MLIR value (first) and the ASM
// constraint string (second) for its operand.
Operand *newListOperand(ArrayRef<std::pair<mlir::Value, std::string>> items) {
  Operand *listOpr = newOperand();
  for (const auto &valueAndConstraint : items)
    listOpr->listAppend(
        newOperand(valueAndConstraint.first, valueAndConstraint.second));
  return listOpr;
}
|
||||||
|
|
||||||
|
// Create a list operand holding `count` new operands, each created from the
// same MLIR value `val` and the same ASM constraint string `constraint`.
Operand *newListOperand(unsigned count, mlir::Value val,
                        const std::string &constraint) {
  auto *list = newOperand();
  // The index is unsigned to match `count`; a signed index would trigger a
  // signed/unsigned comparison in the loop condition.
  for (unsigned i = 0; i < count; ++i) {
    list->listAppend(newOperand(val, constraint));
  }
  return list;
}
|
||||||
|
|
||||||
|
// Create a list operand holding `count` new operands, each created from the
// ASM constraint string `constraint` only (no MLIR value is passed).
Operand *newListOperand(unsigned count, const std::string &constraint) {
  auto *list = newOperand();
  // The index is unsigned to match `count`; a signed index would trigger a
  // signed/unsigned comparison in the loop condition.
  for (unsigned i = 0; i < count; ++i) {
    list->listAppend(newOperand(constraint));
  }
  return list;
}
|
||||||
|
|
||||||
// Create a new operand. It will not add to operand list.
|
// Create a new operand. It will not add to operand list.
|
||||||
// @value: the MLIR value bound to this operand.
|
// @value: the MLIR value bound to this operand.
|
||||||
// @constraint: ASM operand constraint, e.g. "=r"
|
// @constraint: ASM operand constraint, e.g. "=r"
|
||||||
@@ -131,6 +159,11 @@ struct PTXBuilder {
|
|||||||
|
|
||||||
std::string dump() const;
|
std::string dump() const;
|
||||||
|
|
||||||
|
// Create an LLVM::InlineAsmOp at `loc` from this builder and return the op's
// result.
// @resTy: result type given to the inline-asm op.
// @hasSideEffect: passed as the op's has_side_effects flag.
// @isAlignStack: passed as the op's is_align_stack flag.
// @attrs: wrapped in an ArrayAttr and passed as the op's operand_attrs.
mlir::Value launch(ConversionPatternRewriter &rewriter, Location loc,
|
||||||
|
Type resTy, bool hasSideEffect = true,
|
||||||
|
bool isAlignStack = false,
|
||||||
|
ArrayRef<Attribute> attrs = {}) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Operand *newOperand() {
|
Operand *newOperand() {
|
||||||
argArchive.emplace_back(std::make_unique<Operand>());
|
argArchive.emplace_back(std::make_unique<Operand>());
|
||||||
|
@@ -24,7 +24,7 @@ unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape);
|
|||||||
|
|
||||||
SmallVector<unsigned> getSizePerThread(Attribute layout);
|
SmallVector<unsigned> getSizePerThread(Attribute layout);
|
||||||
|
|
||||||
unsigned getShapePerCTA(const Attribute &layout, unsigned d);
|
SmallVector<unsigned> getShapePerCTA(const Attribute &layout);
|
||||||
|
|
||||||
SmallVector<unsigned> getOrder(const Attribute &layout);
|
SmallVector<unsigned> getOrder(const Attribute &layout);
|
||||||
|
|
||||||
|
@@ -56,11 +56,14 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
|||||||
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
|
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
|
||||||
outVec = outOrd[0] == 0 ? 1 : dstContigPerThread;
|
outVec = outOrd[0] == 0 ? 1 : dstContigPerThread;
|
||||||
|
|
||||||
|
auto srcShapePerCTA = getShapePerCTA(srcLayout);
|
||||||
|
auto dstShapePerCTA = getShapePerCTA(dstLayout);
|
||||||
|
|
||||||
unsigned pad = std::max(inVec, outVec);
|
unsigned pad = std::max(inVec, outVec);
|
||||||
for (unsigned d = 0; d < rank; ++d) {
|
for (unsigned d = 0; d < rank; ++d) {
|
||||||
paddedRepShape[d] = std::max(
|
paddedRepShape[d] =
|
||||||
std::min<unsigned>(srcTy.getShape()[d], getShapePerCTA(srcLayout, d)),
|
std::max(std::min<unsigned>(srcTy.getShape()[d], srcShapePerCTA[d]),
|
||||||
std::min<unsigned>(dstTy.getShape()[d], getShapePerCTA(dstLayout, d)));
|
std::min<unsigned>(dstTy.getShape()[d], dstShapePerCTA[d]));
|
||||||
}
|
}
|
||||||
unsigned paddedDim = 1;
|
unsigned paddedDim = 1;
|
||||||
if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
|
if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
|
||||||
|
@@ -65,7 +65,7 @@ AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) {
|
|||||||
DimVectorT retContiguity;
|
DimVectorT retContiguity;
|
||||||
DimVectorT retDivisibility;
|
DimVectorT retDivisibility;
|
||||||
DimVectorT retConstancy;
|
DimVectorT retConstancy;
|
||||||
for (size_t d = 0; d < lhs.getRank(); d++) {
|
for (size_t d = 0; d < lhs.getRank(); ++d) {
|
||||||
retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
|
retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
|
||||||
retDivisibility.push_back(
|
retDivisibility.push_back(
|
||||||
gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
|
gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
|
||||||
@@ -87,7 +87,7 @@ AxisInfo AxisInfoAnalysis::visitBinaryOp(
|
|||||||
AxisInfo::DimVectorT newContiguity;
|
AxisInfo::DimVectorT newContiguity;
|
||||||
AxisInfo::DimVectorT newDivisibility;
|
AxisInfo::DimVectorT newDivisibility;
|
||||||
AxisInfo::DimVectorT newConstancy;
|
AxisInfo::DimVectorT newConstancy;
|
||||||
for (size_t d = 0; d < rank; d++) {
|
for (size_t d = 0; d < rank; ++d) {
|
||||||
newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d));
|
newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d));
|
||||||
newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d));
|
newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d));
|
||||||
newConstancy.push_back(getConstancy(lhsInfo, rhsInfo, d));
|
newConstancy.push_back(getConstancy(lhsInfo, rhsInfo, d));
|
||||||
@@ -166,7 +166,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
|||||||
AxisInfo::DimVectorT contiguity;
|
AxisInfo::DimVectorT contiguity;
|
||||||
AxisInfo::DimVectorT divisibility;
|
AxisInfo::DimVectorT divisibility;
|
||||||
AxisInfo::DimVectorT constancy;
|
AxisInfo::DimVectorT constancy;
|
||||||
for (size_t d = 0; d < retTy.getRank(); d++) {
|
for (size_t d = 0; d < retTy.getRank(); ++d) {
|
||||||
contiguity.push_back(1);
|
contiguity.push_back(1);
|
||||||
divisibility.push_back(opInfo.getDivisibility(0));
|
divisibility.push_back(opInfo.getDivisibility(0));
|
||||||
constancy.push_back(retTy.getShape()[d]);
|
constancy.push_back(retTy.getShape()[d]);
|
||||||
@@ -202,7 +202,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
|||||||
AxisInfo::DimVectorT contiguity;
|
AxisInfo::DimVectorT contiguity;
|
||||||
AxisInfo::DimVectorT divisibility;
|
AxisInfo::DimVectorT divisibility;
|
||||||
AxisInfo::DimVectorT constancy;
|
AxisInfo::DimVectorT constancy;
|
||||||
for (size_t d = 0; d < retTy.getRank(); d++) {
|
for (size_t d = 0; d < retTy.getRank(); ++d) {
|
||||||
contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
|
contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
|
||||||
divisibility.push_back(opInfo.getDivisibility(d));
|
divisibility.push_back(opInfo.getDivisibility(d));
|
||||||
constancy.push_back(opShape[d] == 1 ? retShape[d] : 1);
|
constancy.push_back(opShape[d] == 1 ? retShape[d] : 1);
|
||||||
|
@@ -1,4 +1,6 @@
|
|||||||
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
|
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
|
||||||
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
#include <sstream> // unify to llvm::raw_string_ostream ?
|
#include <sstream> // unify to llvm::raw_string_ostream ?
|
||||||
|
|
||||||
@@ -10,7 +12,7 @@ std::string strJoin(llvm::ArrayRef<std::string> strs,
|
|||||||
llvm::StringRef delimiter) {
|
llvm::StringRef delimiter) {
|
||||||
std::string osStr;
|
std::string osStr;
|
||||||
llvm::raw_string_ostream os(osStr);
|
llvm::raw_string_ostream os(osStr);
|
||||||
for (size_t i = 0; !strs.empty() && i < strs.size() - 1; i++)
|
for (size_t i = 0; !strs.empty() && i < strs.size() - 1; ++i)
|
||||||
os << strs[i] << delimiter;
|
os << strs[i] << delimiter;
|
||||||
if (!strs.empty())
|
if (!strs.empty())
|
||||||
os << strs.back();
|
os << strs.back();
|
||||||
@@ -74,6 +76,25 @@ SmallVector<PTXBuilder::Operand *, 4> PTXBuilder::getAllArgs() const {
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Materialize this builder as one LLVM::InlineAsmOp created at `loc` and
// return that op's result. Operands come from getAllMLIRArgs(), the asm
// string from dump(), and the constraint string from getConstraints(); the
// asm dialect is always AD_ATT.
mlir::Value PTXBuilder::launch(ConversionPatternRewriter &rewriter,
                               Location loc, Type resTy, bool hasSideEffect,
                               bool isAlignStack,
                               ArrayRef<Attribute> attrs) const {
  auto *context = rewriter.getContext();
  auto asmDialect =
      LLVM::AsmDialectAttr::get(context, LLVM::AsmDialect::AD_ATT);
  auto operandAttrs = ArrayAttr::get(context, attrs);

  auto asmOp = rewriter.create<LLVM::InlineAsmOp>(
      loc, resTy,
      /*operands=*/getAllMLIRArgs(),
      /*asm_string=*/dump(),
      /*constraints=*/getConstraints(),
      /*has_side_effects=*/hasSideEffect,
      /*is_align_stack=*/isAlignStack,
      /*asm_dialect=*/asmDialect,
      /*operand_attrs=*/operandAttrs);

  return asmOp.getRes();
}
|
||||||
|
|
||||||
std::string PTXInstr::Operand::dump() const {
|
std::string PTXInstr::Operand::dump() const {
|
||||||
if (repr)
|
if (repr)
|
||||||
return repr(idx);
|
return repr(idx);
|
||||||
@@ -151,5 +172,6 @@ PTXInstrExecution::getArgList() const {
|
|||||||
}
|
}
|
||||||
return args;
|
return args;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace triton
|
} // namespace triton
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
File diff suppressed because it is too large
Load Diff
@@ -72,26 +72,24 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned getShapePerCTA(const Attribute &layout, unsigned d) {
|
SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
|
||||||
|
SmallVector<unsigned> shape;
|
||||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||||
return blockedLayout.getSizePerThread()[d] *
|
for (int d = 0, n = blockedLayout.getOrder().size(); d < n; ++d)
|
||||||
blockedLayout.getThreadsPerWarp()[d] *
|
shape.push_back(blockedLayout.getSizePerThread()[d] *
|
||||||
blockedLayout.getWarpsPerCTA()[d];
|
blockedLayout.getThreadsPerWarp()[d] *
|
||||||
|
blockedLayout.getWarpsPerCTA()[d]);
|
||||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||||
assert(mmaLayout.getVersion() == 2 &&
|
assert(mmaLayout.getVersion() == 2 &&
|
||||||
"mmaLayout version = 1 is not implemented yet");
|
"mmaLayout version = 1 is not implemented yet");
|
||||||
assert(d < 2 && "Unexpected usage of getShapePerCTA");
|
return {16 * mmaLayout.getWarpsPerCTA()[0],
|
||||||
if (d == 0) {
|
8 * mmaLayout.getWarpsPerCTA()[1]};
|
||||||
return 16 * mmaLayout.getWarpsPerCTA()[0];
|
|
||||||
} else {
|
|
||||||
// d == 1
|
|
||||||
return 8 * mmaLayout.getWarpsPerCTA()[1];
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
assert(0 && "Unimplemented usage of getShapePerCTA");
|
assert(0 && "Unimplemented usage of getShapePerCTA");
|
||||||
return 0;
|
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
return shape;
|
||||||
|
}
|
||||||
|
|
||||||
SmallVector<unsigned> getOrder(const Attribute &layout) {
|
SmallVector<unsigned> getOrder(const Attribute &layout) {
|
||||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||||
@@ -106,7 +104,7 @@ SmallVector<unsigned> getOrder(const Attribute &layout) {
|
|||||||
assert(0 && "Unimplemented usage of getOrder");
|
assert(0 && "Unimplemented usage of getOrder");
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace triton
|
} // namespace triton
|
||||||
@@ -180,16 +178,17 @@ SliceEncodingAttr BlockedEncodingAttr::squeeze(int axis) {
|
|||||||
|
|
||||||
unsigned BlockedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
|
unsigned BlockedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
|
||||||
size_t rank = shape.size();
|
size_t rank = shape.size();
|
||||||
assert(rank == getSizePerThread().size() &&
|
auto sizePerThread = getSizePerThread();
|
||||||
|
auto warpsPerCTA = getWarpsPerCTA();
|
||||||
|
auto threadsPerWarp = getThreadsPerWarp();
|
||||||
|
assert(rank == sizePerThread.size() &&
|
||||||
"unexpected rank in BlockedEncodingAttr::getElemsPerThread");
|
"unexpected rank in BlockedEncodingAttr::getElemsPerThread");
|
||||||
SmallVector<unsigned> elemsPerThreadPerDim(rank);
|
SmallVector<unsigned> elemsPerThread(rank);
|
||||||
for (size_t i = 0; i < rank; ++i) {
|
for (size_t i = 0; i < rank; ++i) {
|
||||||
unsigned t =
|
unsigned t = sizePerThread[i] * threadsPerWarp[i] * warpsPerCTA[i];
|
||||||
getSizePerThread()[i] * getThreadsPerWarp()[i] * getWarpsPerCTA()[i];
|
elemsPerThread[i] = ceil<unsigned>(shape[i], t) * sizePerThread[i];
|
||||||
elemsPerThreadPerDim[i] =
|
|
||||||
ceil<unsigned>(shape[i], t) * getSizePerThread()[i];
|
|
||||||
}
|
}
|
||||||
return product<unsigned>(elemsPerThreadPerDim);
|
return product<unsigned>(elemsPerThread);
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
|
unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
|
||||||
@@ -216,11 +215,9 @@ unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
|
unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
|
||||||
size_t rank = shape.size();
|
int threads = product(getWarpsPerCTA());
|
||||||
assert(rank == 2 && "Unexpected rank of mma layout");
|
int numElem = product(shape);
|
||||||
unsigned elemsCol = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
|
return numElem / threads;
|
||||||
unsigned elemsRow = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
|
|
||||||
return elemsCol * elemsRow;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
|
unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
|
||||||
|
@@ -1,5 +1,5 @@
|
|||||||
add_triton_ut(
|
add_triton_ut(
|
||||||
NAME TritonAnalysisTests
|
NAME TestTritonAnalysis
|
||||||
SRCS UtilityTest.cpp
|
SRCS UtilityTest.cpp
|
||||||
LIBS TritonAnalysis
|
LIBS TritonAnalysis
|
||||||
)
|
)
|
||||||
|
@@ -4,11 +4,26 @@
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "triton/Analysis/Utility.h"
|
#include "triton/Analysis/Utility.h"
|
||||||
#include <gmock/gmock.h>
|
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
||||||
TEST(UtilityTest, DummyTest) { EXPECT_EQ(true, true); }
|
// Checks reorder() against two permutations of a three-element input:
// result[i] must equal input[order[i]].
TEST(Analysis, reorder) {
  SmallVector<int> dims({10, 20, 30});

  // Full reversal.
  {
    SmallVector<unsigned> perm({2, 1, 0});
    auto permuted = reorder<int>(dims, perm);
    EXPECT_EQ(permuted[0], 30);
    EXPECT_EQ(permuted[1], 20);
    EXPECT_EQ(permuted[2], 10);
  }

  // Swap of the first two entries only.
  {
    SmallVector<unsigned> perm({1, 0, 2});
    auto permuted = reorder<int>(dims, perm);
    EXPECT_EQ(permuted[0], 20);
    EXPECT_EQ(permuted[1], 10);
    EXPECT_EQ(permuted[2], 30);
  }
}
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
@@ -1,5 +1,5 @@
|
|||||||
add_triton_ut(
|
add_triton_ut(
|
||||||
NAME PtxAsmFormatTest
|
NAME TestPtxAsmFormat
|
||||||
SRCS PtxAsmFormatTest.cpp
|
SRCS PtxAsmFormatTest.cpp
|
||||||
LIBS TritonGPUToLLVM
|
LIBS TritonGPUToLLVM
|
||||||
)
|
)
|
||||||
|
Reference in New Issue
Block a user