[Triton-MLIR][BACKEND] Refine dot conversion (#710)

This PR does the following:

1. Refines the dot conversion.
2. Makes some other tiny code refinements.
Author: Yan Chunwei
Date: 2022-09-27 14:38:34 +08:00
Committed by: GitHub
Parent: 61b61755e5
Commit: 3a84278530
11 changed files with 439 additions and 291 deletions

View File

@@ -20,6 +20,18 @@ template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
 template <typename Int> Int ceil(Int m, Int n) { return (m + n - 1) / n; }

+// output[i] = input[order[i]]
+template <typename T>
+SmallVector<T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) {
+  size_t rank = order.size();
+  assert(input.size() == rank);
+  SmallVector<T> result(rank);
+  for (auto it : llvm::enumerate(order)) {
+    result[it.index()] = input[it.value()];
+  }
+  return result;
+}
+
 } // namespace mlir

 #endif // TRITON_ANALYSIS_UTILITY_H
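A minimal usage sketch of the new reorder helper, with made-up values (not part of the commit):

  SmallVector<int> in{10, 20, 30};
  SmallVector<unsigned> order{2, 0, 1};
  // output[i] = input[order[i]]
  SmallVector<int> out = reorder<int>(in, order); // out == {30, 10, 20}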

View File

@@ -8,6 +8,9 @@
 #include <string>

 namespace mlir {
+class ConversionPatternRewriter;
+class Location;
+
 namespace triton {

 using llvm::StringRef;
@@ -104,6 +107,31 @@ struct PTXBuilder {
   // Create a list of operands.
   Operand *newListOperand() { return newOperand(); }

+  Operand *newListOperand(ArrayRef<std::pair<mlir::Value, std::string>> items) {
+    auto *list = newOperand();
+    for (auto &item : items) {
+      list->listAppend(newOperand(item.first, item.second));
+    }
+    return list;
+  }
+
+  Operand *newListOperand(unsigned count, mlir::Value val,
+                          const std::string &constraint) {
+    auto *list = newOperand();
+    for (int i = 0; i < count; ++i) {
+      list->listAppend(newOperand(val, constraint));
+    }
+    return list;
+  }
+
+  Operand *newListOperand(unsigned count, const std::string &constraint) {
+    auto *list = newOperand();
+    for (int i = 0; i < count; ++i) {
+      list->listAppend(newOperand(constraint));
+    }
+    return list;
+  }
+
   // Create a new operand. It will not add to operand list.
   // @value: the MLIR value bind to this operand.
   // @constraint: ASM operand constraint, .e.g. "=r"
@@ -131,6 +159,11 @@ struct PTXBuilder {
   std::string dump() const;

+  mlir::Value launch(ConversionPatternRewriter &rewriter, Location loc,
+                     Type resTy, bool hasSideEffect = true,
+                     bool isAlignStack = false,
+                     ArrayRef<Attribute> attrs = {}) const;
+
 private:
   Operand *newOperand() {
     argArchive.emplace_back(std::make_unique<Operand>());
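A hedged usage sketch of the new list-operand overloads (ha0 and ha1 are placeholder Values, mirroring the mma call sites later in this PR):

  PTXBuilder builder;
  // Four "=r" outputs in one call, instead of a manual listAppend loop:
  auto *retArgs = builder.newListOperand(4, "=r");
  // Value/constraint pairs in one call:
  auto *aArgs = builder.newListOperand({{ha0, "r"}, {ha1, "r"}});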

View File

@@ -24,7 +24,7 @@ unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape);
 SmallVector<unsigned> getSizePerThread(Attribute layout);
-unsigned getShapePerCTA(const Attribute &layout, unsigned d);
+SmallVector<unsigned> getShapePerCTA(const Attribute &layout);
 SmallVector<unsigned> getOrder(const Attribute &layout);

View File

@@ -56,11 +56,14 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
   inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
   outVec = outOrd[0] == 0 ? 1 : dstContigPerThread;

+  auto srcShapePerCTA = getShapePerCTA(srcLayout);
+  auto dstShapePerCTA = getShapePerCTA(dstLayout);
+
   unsigned pad = std::max(inVec, outVec);
   for (unsigned d = 0; d < rank; ++d) {
-    paddedRepShape[d] = std::max(
-        std::min<unsigned>(srcTy.getShape()[d], getShapePerCTA(srcLayout, d)),
-        std::min<unsigned>(dstTy.getShape()[d], getShapePerCTA(dstLayout, d)));
+    paddedRepShape[d] =
+        std::max(std::min<unsigned>(srcTy.getShape()[d], srcShapePerCTA[d]),
+                 std::min<unsigned>(dstTy.getShape()[d], dstShapePerCTA[d]));
   }
   unsigned paddedDim = 1;
   if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
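To make the repadding concrete, a hedged numeric example with made-up shapes:

  // Along some dimension d, suppose:
  //   srcTy.getShape()[d] = 128, srcShapePerCTA[d] = 64,
  //   dstTy.getShape()[d] = 128, dstShapePerCTA[d] = 32.
  // Then paddedRepShape[d] = max(min(128, 64), min(128, 32)) = 64.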

View File

@@ -65,7 +65,7 @@ AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) {
   DimVectorT retContiguity;
   DimVectorT retDivisibility;
   DimVectorT retConstancy;
-  for (size_t d = 0; d < lhs.getRank(); d++) {
+  for (size_t d = 0; d < lhs.getRank(); ++d) {
     retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
     retDivisibility.push_back(
         gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
@@ -87,7 +87,7 @@ AxisInfo AxisInfoAnalysis::visitBinaryOp(
   AxisInfo::DimVectorT newContiguity;
   AxisInfo::DimVectorT newDivisibility;
   AxisInfo::DimVectorT newConstancy;
-  for (size_t d = 0; d < rank; d++) {
+  for (size_t d = 0; d < rank; ++d) {
     newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d));
     newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d));
     newConstancy.push_back(getConstancy(lhsInfo, rhsInfo, d));
@@ -166,7 +166,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(
     AxisInfo::DimVectorT contiguity;
     AxisInfo::DimVectorT divisibility;
     AxisInfo::DimVectorT constancy;
-    for (size_t d = 0; d < retTy.getRank(); d++) {
+    for (size_t d = 0; d < retTy.getRank(); ++d) {
      contiguity.push_back(1);
      divisibility.push_back(opInfo.getDivisibility(0));
      constancy.push_back(retTy.getShape()[d]);
@@ -202,7 +202,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(
     AxisInfo::DimVectorT contiguity;
     AxisInfo::DimVectorT divisibility;
     AxisInfo::DimVectorT constancy;
-    for (size_t d = 0; d < retTy.getRank(); d++) {
+    for (size_t d = 0; d < retTy.getRank(); ++d) {
      contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
      divisibility.push_back(opInfo.getDivisibility(d));
      constancy.push_back(opShape[d] == 1 ? retShape[d] : 1);

View File

@@ -1,4 +1,6 @@
 #include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/raw_ostream.h"
 #include <sstream> // unify to llvm::raw_string_ostream ?
@@ -10,7 +12,7 @@ std::string strJoin(llvm::ArrayRef<std::string> strs,
                     llvm::StringRef delimiter) {
   std::string osStr;
   llvm::raw_string_ostream os(osStr);
-  for (size_t i = 0; !strs.empty() && i < strs.size() - 1; i++)
+  for (size_t i = 0; !strs.empty() && i < strs.size() - 1; ++i)
     os << strs[i] << delimiter;
   if (!strs.empty())
     os << strs.back();
@@ -74,6 +76,25 @@ SmallVector<PTXBuilder::Operand *, 4> PTXBuilder::getAllArgs() const {
   return res;
 }

+mlir::Value PTXBuilder::launch(ConversionPatternRewriter &rewriter,
+                               Location loc, Type resTy, bool hasSideEffect,
+                               bool isAlignStack,
+                               ArrayRef<Attribute> attrs) const {
+  auto *ctx = rewriter.getContext();
+  auto inlineAsm = rewriter.create<LLVM::InlineAsmOp>(
+      loc, resTy, getAllMLIRArgs(), // operands
+      dump(),                       // asm_string
+      getConstraints(),             // constraints
+      hasSideEffect,                // has_side_effects
+      isAlignStack,                 // is_align_stack
+      LLVM::AsmDialectAttr::get(ctx,
+                                LLVM::AsmDialect::AD_ATT), // asm_dialect
+      ArrayAttr::get(ctx, attrs)                           // operand_attrs
+  );
+
+  return inlineAsm.getRes();
+}
+
 std::string PTXInstr::Operand::dump() const {
   if (repr)
     return repr(idx);
@@ -151,5 +172,6 @@ PTXInstrExecution::getArgList() const {
   }
   return args;
 }
+
 } // namespace triton
 } // namespace mlir
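A hedged sketch of what launch saves at each call site (ptxBuilder, rewriter, loc, and retTy are assumed in scope):

  // Before this PR every caller spelled out rewriter.create<LLVM::InlineAsmOp>(...)
  // with eight arguments; now one call suffices:
  Value ret = ptxBuilder.launch(rewriter, loc, retTy);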

View File

@@ -50,10 +50,28 @@ Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) {
                                           IntegerAttr::get(i32ty, v));
 }

+// Create a index type constant.
+Value createIndexConstant(OpBuilder &builder, Location loc,
+                          TypeConverter *converter, int64_t value) {
+  Type ty = converter->convertType(builder.getIndexType());
+  return builder.create<LLVM::ConstantOp>(loc, ty,
+                                          builder.getIntegerAttr(ty, value));
+}
+
+// Create an integer constant of \param width bits.
+Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
+                                int64_t value) {
+  Type ty = builder.getIntegerType(width);
+  return builder.create<LLVM::ConstantOp>(loc, ty,
+                                          builder.getIntegerAttr(ty, value));
+}
+
 // Add other specification if needed...

 } // namespace

+// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
 #define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
 #define urem(...) rewriter.create<LLVM::URemOp>(loc, __VA_ARGS__)
 #define add(...) rewriter.create<LLVM::AddOp>(loc, __VA_ARGS__)
@@ -68,10 +86,19 @@ Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) {
   rewriter.create<LLVM::InsertElementOp>(loc, __VA_ARGS__)
 #define extract_element(...)                                                   \
   rewriter.create<LLVM::ExtractElementOp>(loc, __VA_ARGS__)
+#define load(...) rewriter.create<LLVM::LoadOp>(loc, __VA_ARGS__)
+#define store(val, ptr) rewriter.create<LLVM::StoreOp>(loc, val, ptr)
 #define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
+#define i32_ty rewriter.getIntegerType(32)
+#define vec_ty(type, num) VectorType::get(num, type)
+
+// Creator for constant
 #define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__)
-#define i32_ty() rewriter.getIntegerType(32)
+#define int_val(width, val)                                                    \
+  LLVM::createLLVMIntegerConstant(rewriter, loc, width, val)
+#define idx_val(...)                                                           \
+  LLVM::createIndexConstant(rewriter, loc, this->getTypeConverter(),           \
+                            __VA_ARGS__)

 } // namespace LLVM
 } // namespace mlir
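How the new constant-creator macros expand (they assume rewriter, loc, and, for idx_val, a pattern with getTypeConverter() in scope):

  Value warpSize = idx_val(32); // index-typed constant 32
  Value pred = int_val(1, 1);   // i1 constant true
  Value eight = i32_val(8);     // i32 constant 8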
@@ -215,7 +242,7 @@ struct FuncOpConversion : public FuncOpConversionBase {
     // Set an attribute for maxntidx, it could be used in latter LLVM codegen
     // for `nvvm.annotation` metadata.
     newFuncOp->setAttr(NVVMMetadataField::MaxNTid,
-                       rewriter.getIntegerAttr(i32_ty(), 32 * NumWarps));
+                       rewriter.getIntegerAttr(i32_ty, 32 * NumWarps));

     rewriter.eraseOp(funcOp);
     return success();
@@ -247,19 +274,6 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
   }
 };

-static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
-                                     Type resultType, int64_t value) {
-  return builder.create<LLVM::ConstantOp>(
-      loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
-}
-
-static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc,
-                                       LLVMTypeConverter *converter, Type ty,
-                                       int64_t value) {
-  return builder.create<LLVM::ConstantOp>(loc, converter->convertType(ty),
-                                          builder.getIntegerAttr(ty, value));
-}
-
 Value getStructFromElements(Location loc, ValueRange resultVals,
                             ConversionPatternRewriter &rewriter,
                             Type structType) {
@@ -272,42 +286,36 @@ Value getStructFromElements(Location loc, ValueRange resultVals,
 }

 template <typename T>
-static SmallVector<T> getMultiDimIndex(T linear_index, ArrayRef<T> shape) {
-  // sizes {a, b, c, d} -> acc_mul {b*c*d, c*d, d, 1}
+static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape) {
+  // shape: {a, b, c, d} -> accMul: {b*c*d, c*d, d, 1}
   size_t rank = shape.size();
-  T acc_mul = 1;
-  for (size_t i = 1; i < rank; ++i) {
-    acc_mul *= shape[i];
-  }
-  T linear_remain = linear_index;
-  SmallVector<T> multidim_index(rank);
+  T accMul = product(shape.drop_front());
+  T linearRemain = linearIndex;
+  SmallVector<T> multiDimIndex(rank);
   for (size_t i = 0; i < rank; ++i) {
-    multidim_index[i] = linear_remain / acc_mul;
-    linear_remain = linear_remain % acc_mul;
+    multiDimIndex[i] = linearRemain / accMul;
+    linearRemain = linearRemain % accMul;
     if (i != (rank - 1)) {
-      acc_mul = acc_mul / shape[i + 1];
+      accMul = accMul / shape[i + 1];
     }
   }
-  return multidim_index;
+  return multiDimIndex;
 }

 template <typename T>
-static T getLinearIndex(ArrayRef<T> multidim_index, ArrayRef<T> shape) {
-  assert(multidim_index.size() == shape.size());
-  // sizes {a, b, c, d} -> acc_mul {b*c*d, c*d, d, 1}
+static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
+  assert(multiDimIndex.size() == shape.size());
+  // shape: {a, b, c, d} -> accMul: {b*c*d, c*d, d, 1}
   size_t rank = shape.size();
-  T acc_mul = 1;
-  for (size_t i = 1; i < rank; ++i) {
-    acc_mul *= shape[i];
-  }
-  T linear_index = 0;
+  T accMul = product(shape.drop_front());
+  T linearIndex = 0;
   for (size_t i = 0; i < rank; ++i) {
-    linear_index += multidim_index[i] * acc_mul;
+    linearIndex += multiDimIndex[i] * accMul;
     if (i != (rank - 1)) {
-      acc_mul = acc_mul / shape[i + 1];
+      accMul = accMul / shape[i + 1];
     }
   }
-  return linear_index;
+  return linearIndex;
 }

 struct ConvertTritonGPUOpToLLVMPatternBase {
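A small worked example of the index helpers above, with illustrative values:

  SmallVector<int> shape{2, 3};              // accMul starts at product({3}) = 3
  auto md = getMultiDimIndex<int>(5, shape); // {5 / 3, 5 % 3} == {1, 2}
  int lin = getLinearIndex<int>(md, shape);  // 1 * 3 + 2 == 5, round-trips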
@@ -352,23 +360,15 @@ public:
     return threadId;
   }

-  Value createIndexConst(ConversionPatternRewriter &rewriter, Location loc,
-                         int64_t value) const {
-    return rewriter.create<LLVM::ConstantOp>(
-        loc, this->getTypeConverter()->getIndexType(),
-        rewriter.getIntegerAttr(rewriter.getIndexType(), value));
-  }
-
+  // Convert an \param index to a multi-dim coordinate given \param shape and
+  // \param order.
   SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
                                  Location loc, Value linear,
                                  ArrayRef<unsigned> shape,
                                  ArrayRef<unsigned> order) const {
     unsigned rank = shape.size();
     assert(rank == order.size());
-    SmallVector<unsigned> reordered(rank);
-    for (unsigned i = 0; i < rank; ++i) {
-      reordered[i] = shape[order[i]];
-    }
+    auto reordered = reorder(shape, order);
     auto reorderedMultiDim = delinearize(rewriter, loc, linear, reordered);
     SmallVector<Value> multiDim(rank);
     for (unsigned i = 0; i < rank; ++i) {
@@ -388,9 +388,7 @@ public:
     } else {
       Value remained = linear;
       for (auto &&en : llvm::enumerate(llvm::reverse(shape.drop_front()))) {
-        Value dimSize = createIndexAttrConstant(
-            rewriter, loc, this->getTypeConverter()->getIndexType(),
-            en.value());
+        Value dimSize = idx_val(en.value());
         multiDim[rank - 1 - en.index()] = urem(remained, dimSize);
         remained = udiv(remained, dimSize);
       }
@@ -402,20 +400,19 @@ public:
   Value linearize(ConversionPatternRewriter &rewriter, Location loc,
                   ArrayRef<Value> multiDim, ArrayRef<unsigned> shape) const {
     int rank = multiDim.size();
-    Value linear = createIndexAttrConstant(
-        rewriter, loc, this->getTypeConverter()->getIndexType(), 0);
+    Value linear = idx_val(0);
     if (rank > 0) {
       linear = multiDim.front();
-      for (auto &&z : llvm::zip(multiDim.drop_front(), shape.drop_front())) {
-        Value dimSize = createIndexAttrConstant(
-            rewriter, loc, this->getTypeConverter()->getIndexType(),
-            std::get<1>(z));
-        linear = add(mul(linear, dimSize), std::get<0>(z));
+      for (auto [dim, shape] :
+           llvm::zip(multiDim.drop_front(), shape.drop_front())) {
+        Value dimSize = idx_val(shape);
+        linear = add(mul(linear, dimSize), dim);
       }
     }
     return linear;
   }
+  // Get an index-base for each dimension for a \param blocked_layout.
   SmallVector<Value>
   emitBaseIndexForBlockedLayout(Location loc,
                                 ConversionPatternRewriter &rewriter,
@@ -423,7 +420,7 @@ public:
                                 ArrayRef<int64_t> shape) const {
     auto llvmIndexTy = this->getTypeConverter()->getIndexType();
     Value threadId = getThreadId(rewriter, loc);
-    Value warpSize = createIndexAttrConstant(rewriter, loc, llvmIndexTy, 32);
+    Value warpSize = idx_val(32);
     Value laneId = urem(threadId, warpSize);
     Value warpId = udiv(threadId, warpSize);
     auto sizePerThread = blocked_layout.getSizePerThread();
@@ -444,19 +441,13 @@ public:
       unsigned maxWarps =
           ceil<unsigned>(shape[k], sizePerThread[k] * threadsPerWarp[k]);
       unsigned maxThreads = ceil<unsigned>(shape[k], sizePerThread[k]);
-      multiDimWarpId[k] =
-          urem(multiDimWarpId[k],
-               createIndexAttrConstant(rewriter, loc, llvmIndexTy, maxWarps));
-      multiDimThreadId[k] =
-          urem(multiDimThreadId[k],
-               createIndexAttrConstant(rewriter, loc, llvmIndexTy, maxThreads));
+      multiDimWarpId[k] = urem(multiDimWarpId[k], idx_val(maxWarps));
+      multiDimThreadId[k] = urem(multiDimThreadId[k], idx_val(maxThreads));
       // multiDimBase[k] = (multiDimThreadId[k] +
       //                    multiDimWarpId[k] * threadsPerWarp[k]) *
       //                   sizePerThread[k];
-      Value threadsPerWarpK = createIndexAttrConstant(
-          rewriter, loc, llvmIndexTy, threadsPerWarp[k]);
-      Value sizePerThreadK =
-          createIndexAttrConstant(rewriter, loc, llvmIndexTy, sizePerThread[k]);
+      Value threadsPerWarpK = idx_val(threadsPerWarp[k]);
+      Value sizePerThreadK = idx_val(sizePerThread[k]);
       multiDimBase[k] =
           mul(sizePerThreadK, add(multiDimThreadId[k],
                                   mul(multiDimWarpId[k], threadsPerWarpK)));
@@ -496,25 +487,22 @@ public:
     if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
       SmallVector<int64_t> paddedShape(rank + 1);
       for (unsigned d = 0; d < rank + 1; ++d) {
-        if (d < dim) {
+        if (d < dim)
           paddedShape[d] = shape[d];
-        } else if (d == dim) {
+        else if (d == dim)
           paddedShape[d] = 1;
-        } else {
+        else
           paddedShape[d] = shape[d - 1];
-        }
       }
       auto paddedIndices = emitIndicesForBlockedLayout(
           loc, rewriter, blockedParent, paddedShape);
       unsigned numIndices = paddedIndices.size();
       SmallVector<SmallVector<Value>> resultIndices(numIndices);
-      for (unsigned i = 0; i < numIndices; ++i) {
-        for (unsigned d = 0; d < rank + 1; ++d) {
-          if (d != dim) {
+      for (unsigned i = 0; i < numIndices; ++i)
+        for (unsigned d = 0; d < rank + 1; ++d)
+          if (d != dim)
             resultIndices[i].push_back(paddedIndices[i][d]);
-          }
-        }
-      }
       return resultIndices;
     } else if (auto sliceParent = parent.dyn_cast<SliceEncodingAttr>()) {
@@ -529,7 +517,8 @@ public:
     }
   }

-  // Emit indices calculation within each ConversionPattern
+  // Emit indices calculation within each ConversionPattern, and returns a
+  // [elemsPerThread X rank] index matrix.
   // TODO: [goostavz] Double confirm the redundant indices calculations will
   // be eliminated in the consequent MLIR/LLVM optimization. We might
   // implement a indiceCache if necessary.
@@ -542,23 +531,16 @@ public:
     auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
     auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
     unsigned rank = shape.size();
-    SmallVector<unsigned> shapePerCTA(rank);
-    for (unsigned k = 0; k < rank; ++k) {
-      shapePerCTA[k] = sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k];
-    }
+    SmallVector<unsigned> shapePerCTA = getShapePerCTA(blockedLayout);

     // step 1, delinearize threadId to get the base index
     auto multiDimBase =
         emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);

     // step 2, get offset of each element
-    unsigned elemsPerThread = 1;
+    unsigned elemsPerThread = blockedLayout.getElemsPerThread(shape);
     SmallVector<SmallVector<unsigned>> offset(rank);
-    SmallVector<unsigned> multiDimElemsPerThread(rank);
     for (unsigned k = 0; k < rank; ++k) {
-      multiDimElemsPerThread[k] =
-          ceil<unsigned>(shape[k], shapePerCTA[k]) * sizePerThread[k];
-      elemsPerThread *= multiDimElemsPerThread[k];
       // 1 block in minimum if shape[k] is less than shapePerCTA[k]
       for (unsigned blockOffset = 0;
            blockOffset < ceil<unsigned>(shape[k], shapePerCTA[k]);
@@ -574,34 +556,29 @@ public:
                 threadsPerWarp[k] +
             threadOffset * sizePerThread[k] + elemOffset);
     }

-    // step 3, add offset to base, and reorder the sequence of indices,
-    // to guarantee that elems in a same sizePerThread are adjacent in
-    // order
-    SmallVector<SmallVector<Value>> multiDimIdx(elemsPerThread);
-    unsigned accumSizePerThread =
-        std::accumulate(sizePerThread.begin(), sizePerThread.end(), 1,
-                        std::multiplies<unsigned>());
+    // step 3, add offset to base, and reorder the sequence of indices to
+    // guarantee that elems in the same sizePerThread are adjacent in order
+    SmallVector<SmallVector<Value>> multiDimIdx(elemsPerThread,
+                                                SmallVector<Value>(rank));
+    unsigned totalSizePerThread = product<unsigned>(sizePerThread);

     SmallVector<unsigned> threadsPerDim(rank);
-    for (unsigned k = 0; k < rank; ++k) {
+    for (unsigned k = 0; k < rank; ++k)
      threadsPerDim[k] = ceil<unsigned>(shape[k], sizePerThread[k]);
-    }

     for (unsigned n = 0; n < elemsPerThread; ++n) {
-      unsigned linearNanoTileId = n / accumSizePerThread;
-      unsigned linearElemsInNanoTileId = n % accumSizePerThread;
+      unsigned linearNanoTileId = n / totalSizePerThread;
+      unsigned linearNanoTileElemId = n % totalSizePerThread;
       SmallVector<unsigned> multiDimNanoTileId =
           getMultiDimIndex<unsigned>(linearNanoTileId, threadsPerDim);
-      SmallVector<unsigned> multiElemsInNanoTileId =
-          getMultiDimIndex<unsigned>(linearElemsInNanoTileId, sizePerThread);
-      multiDimIdx[n].resize(rank);
+      SmallVector<unsigned> multiDimNanoTileElemId =
+          getMultiDimIndex<unsigned>(linearNanoTileElemId, sizePerThread);
       for (unsigned k = 0; k < rank; ++k) {
         unsigned reorderedMultiDimId =
             multiDimNanoTileId[k] *
                 (sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]) +
-            multiElemsInNanoTileId[k];
+            multiDimNanoTileElemId[k];
         multiDimIdx[n][k] =
-            add(multiDimBase[k],
-                createIndexAttrConstant(rewriter, loc, llvmIndexTy,
-                                        offset[k][reorderedMultiDimId]));
+            add(multiDimBase[k], idx_val(offset[k][reorderedMultiDimId]));
       }
     }
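A hedged arithmetic sketch of the nano-tile split in step 3, with made-up sizes:

  // sizePerThread = {2, 2}  ->  totalSizePerThread = 4.
  // For element n = 5: linearNanoTileId = 5 / 4 = 1,
  //                    linearNanoTileElemId = 5 % 4 = 1,
  // i.e. the 2nd element inside the 2nd nano tile owned by this thread.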
@@ -617,7 +594,7 @@ public:
     assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
     size_t offset = allocation->getOffset(bufferId);
     auto llvmIndexTy = this->getTypeConverter()->getIndexType();
-    Value offVal = createIndexAttrConstant(rewriter, loc, llvmIndexTy, offset);
+    Value offVal = idx_val(offset);
     Value base = gep(ptrTy, smem, offVal);
     return base;
   }
@@ -636,19 +613,17 @@ protected:
 Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
                          TypeConverter *typeConverter,
                          ConversionPatternRewriter &rewriter, Location loc) {
   auto tensorTy = resType.cast<RankedTensorType>();
   auto layout = tensorTy.getEncoding();
   auto srcType = typeConverter->convertType(elemType);
   auto llSrc = bit_cast(srcType, constVal);
-  size_t numElemsPerThread = getElemsPerThread(layout, tensorTy.getShape());
-  llvm::SmallVector<Value, 4> elems(numElemsPerThread, llSrc);
+  size_t elemsPerThread = getElemsPerThread(layout, tensorTy.getShape());
+  llvm::SmallVector<Value, 4> elems(elemsPerThread, llSrc);
   llvm::SmallVector<Type, 4> elemTypes(elems.size(), srcType);
   auto structTy =
       LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
-  auto llStruct = getStructFromElements(loc, elems, rewriter, structTy);
-  return llStruct;
+
+  return getStructFromElements(loc, elems, rewriter, structTy);
 }

 struct SplatOpConversion
@@ -745,7 +720,7 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
     assert(layout && "unexpected layout in getLayout");
     auto shape = ty.getShape();
     unsigned valueElems = layout.getElemsPerThread(shape);
-    return std::make_tuple(layout, valueElems);
+    return {layout, valueElems};
   }

   unsigned getAlignment(Value val, const BlockedEncodingAttr &layout) const {
@@ -864,14 +839,14 @@ struct StoreOpConversion
       llvm::SmallVector<std::string> asmArgs;

       Type valArgTy = IntegerType::get(ctx, width);
-      auto wordTy = VectorType::get(wordNElems, valueElemTy);
+      auto wordTy = vec_ty(valueElemTy, wordNElems);

       auto *asmArgList = ptxBuilder.newListOperand();
-      for (int wordIdx = 0; wordIdx < nWords; wordIdx++) {
+      for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) {
         // llWord is a width-len composition
         Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
         // Insert each value element to the composition
-        for (int elemIdx = 0; elemIdx < wordNElems; elemIdx++) {
+        for (int elemIdx = 0; elemIdx < wordNElems; ++elemIdx) {
           const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx;
           assert(elemOffset < valueElems.size());
           Value elem = valueElems[elemOffset];
@@ -894,10 +869,7 @@ struct StoreOpConversion
       // TODO(Superjomn) Need to check masks before vectorize the load for all
       // the values share one predicate? Here assume all the mask values are
       // the same.
-      Value maskVal =
-          llMask ? maskElems[vecStart]
-                 : createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
-                                             rewriter.getIntegerType(1), 1);
+      Value maskVal = llMask ? maskElems[vecStart] : int_val(1, 1);

       ptxStoreInstr.global().b(width).v(nWords);

       auto *asmAddr =
@@ -906,22 +878,12 @@ struct StoreOpConversion
       ptxStoreInstr(asmAddr, asmArgList).predicate(maskVal, "b");

       Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
       llvm::SmallVector<Type> argTys({boolTy, ptr.getType()});
-      for (int i = 0; i < nWords; i++)
+      for (int i = 0; i < nWords; ++i)
         argTys.push_back(valArgTy);

       auto ASMReturnTy = LLVM::LLVMVoidType::get(ctx);

-      auto inlineAsm = rewriter.create<LLVM::InlineAsmOp>(
-          loc, ASMReturnTy, ptxBuilder.getAllMLIRArgs(), // operands
-          ptxBuilder.dump(),                             // asm_string
-          ptxBuilder.getConstraints(),                   // constraints
-          // TODO(Superjomn) determine the side effect.
-          true,  // has_side_effects
-          false, // is_align_stack
-          LLVM::AsmDialectAttr::get(ctx,
-                                    LLVM::AsmDialect::AD_ATT), // asm_dialect
-          ArrayAttr::get(ctx, {})                              // operand_attrs
-      );
+      ptxBuilder.launch(rewriter, loc, ASMReturnTy);
     }
     rewriter.eraseOp(op);
     return success();
@@ -1183,10 +1145,7 @@ struct LoadOpConversion
       // TODO(Superjomn) Need to check masks before vectorize the load for all
       // the values share one predicate? Here assume all the mask values are
      // the same.
-      Value pred =
-          mask ? maskElems[vecStart]
-               : createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
-                                           rewriter.getIntegerType(1), 1);
+      Value pred = mask ? maskElems[vecStart] : int_val(1, 1);

       const std::string readConstraint =
           (width == 64) ? "l" : ((width == 32) ? "r" : "c");
@@ -1195,7 +1154,7 @@ struct LoadOpConversion

       // prepare asm operands
       auto *dstsOpr = ptxBuilder.newListOperand();
-      for (int wordIdx = 0; wordIdx < nWords; wordIdx++) {
+      for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) {
         auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations
         dstsOpr->listAppend(opr);
       }
@@ -1228,7 +1187,7 @@ struct LoadOpConversion
       SmallVector<Value> others;
       if (other) {
-        for (size_t ii = 0; ii < nWords; ii++) {
+        for (size_t ii = 0; ii < nWords; ++ii) {
           PTXInstr &mov = *ptxBuilder.create<>("mov");
           mov.o("u", width);
@@ -1236,7 +1195,7 @@ struct LoadOpConversion
           auto vecTy = LLVM::getFixedVectorType(valueElemTy, size);
           Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
-          for (size_t s = 0; s < size; s++) {
+          for (size_t s = 0; s < size; ++s) {
             Value falseVal = otherElems[vecStart + ii * size + s];
             Value sVal = createIndexAttrConstant(
                 rewriter, loc, this->getTypeConverter()->getIndexType(), s);
@@ -1267,20 +1226,13 @@ struct LoadOpConversion
       // TODO: if (has_l2_evict_policy)
       auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
-      auto inlineAsmOp = rewriter.create<LLVM::InlineAsmOp>(
-          loc, retTy, /*operands=*/ptxBuilder.getAllMLIRArgs(),
-          /*asm_string=*/ptxBuilder.dump(),
-          /*constraints=*/ptxBuilder.getConstraints(),
-          /*has_side_effects=*/true,
-          /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
-          /*operand_attrs=*/ArrayAttr());
-      Value ret = inlineAsmOp.getResult(0);
+      Value ret = ptxBuilder.launch(rewriter, loc, retTy);

       // ---
       // extract and store return values
       // ---
       SmallVector<Value> rets;
-      for (unsigned int ii = 0; ii < nWords; ii++) {
+      for (unsigned int ii = 0; ii < nWords; ++ii) {
         Value curr;
         if (retTy.isa<LLVM::LLVMStructType>()) {
           curr = extract_val(IntegerType::get(getContext(), width), ret,
@@ -1293,8 +1245,8 @@ struct LoadOpConversion
                              curr);
         rets.push_back(curr);
       }
-      int tmp = (width / valueElemNbits);
-      for (size_t ii = 0; ii < vec; ii++) {
+      int tmp = width / valueElemNbits;
+      for (size_t ii = 0; ii < vec; ++ii) {
         Value vecIdx = createIndexAttrConstant(
             rewriter, loc, this->getTypeConverter()->getIndexType(), ii % tmp);
         Value loaded = extract_element(valueElemTy, rets[ii / tmp], vecIdx);
@@ -1480,6 +1432,7 @@ public:
         (!dstLayout.isa<BlockedEncodingAttr>() &&
          !dstLayout.isa<MmaEncodingAttr>())) {
       // TODO: to be implemented
+      llvm::errs() << "Unsupported ConvertLayout found";
       return failure();
     }
     auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
@@ -1494,11 +1447,11 @@ public:
     SmallVector<unsigned> outNumCTAsEachRep(rank);
     SmallVector<unsigned> inNumCTAs(rank);
     SmallVector<unsigned> outNumCTAs(rank);
+    auto srcShapePerCTA = getShapePerCTA(srcLayout);
+    auto dstShapePerCTA = getShapePerCTA(dstLayout);
     for (unsigned d = 0; d < rank; ++d) {
-      unsigned inPerCTA =
-          std::min(unsigned(shape[d]), getShapePerCTA(srcLayout, d));
-      unsigned outPerCTA =
-          std::min(unsigned(shape[d]), getShapePerCTA(dstLayout, d));
+      unsigned inPerCTA = std::min<unsigned>(shape[d], srcShapePerCTA[d]);
+      unsigned outPerCTA = std::min<unsigned>(shape[d], dstShapePerCTA[d]);
       unsigned maxPerCTA = std::max(inPerCTA, outPerCTA);
       numReplicates[d] = ceil<unsigned>(shape[d], maxPerCTA);
       inNumCTAsEachRep[d] = maxPerCTA / inPerCTA;
@@ -1579,9 +1532,8 @@ private:
     auto accumSizePerThread = product<unsigned>(sizePerThread);
     auto llvmIndexTy = getTypeConverter()->getIndexType();
     SmallVector<unsigned> numCTAs(rank);
-    SmallVector<unsigned> shapePerCTA(rank);
+    auto shapePerCTA = getShapePerCTA(layout);
     for (unsigned d = 0; d < rank; ++d) {
-      shapePerCTA[d] = getShapePerCTA(layout, d);
       numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
     }
     auto llvmElemTy = getTypeConverter()->convertType(type.getElementType());
@@ -1603,16 +1555,15 @@ private:
       Value warpSize = createIndexAttrConstant(
           rewriter, loc, this->getTypeConverter()->getIndexType(), 32);
       Value laneId = rewriter.create<LLVM::URemOp>(loc, threadId, warpSize);
-      Value fourVal = createIndexConst(rewriter, loc, 4);
+      Value fourVal = idx_val(4);
       mmaGrpId = rewriter.create<LLVM::UDivOp>(loc, laneId, fourVal);
-      mmaGrpIdP8 = rewriter.create<LLVM::AddOp>(
-          loc, mmaGrpId, createIndexConst(rewriter, loc, 8));
+      mmaGrpIdP8 = rewriter.create<LLVM::AddOp>(loc, mmaGrpId, idx_val(8));
       Value mmaThreadIdInGrp =
           rewriter.create<LLVM::URemOp>(loc, laneId, fourVal);
-      mmaThreadIdInGrpM2 = rewriter.create<LLVM::MulOp>(
-          loc, mmaThreadIdInGrp, createIndexConst(rewriter, loc, 2));
-      mmaThreadIdInGrpM2P1 = rewriter.create<LLVM::AddOp>(
-          loc, mmaThreadIdInGrpM2, createIndexConst(rewriter, loc, 1));
+      mmaThreadIdInGrpM2 =
+          rewriter.create<LLVM::MulOp>(loc, mmaThreadIdInGrp, idx_val(2));
+      mmaThreadIdInGrpM2P1 =
+          rewriter.create<LLVM::AddOp>(loc, mmaThreadIdInGrpM2, idx_val(1));
     }
     for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
       auto multiDimCTAInRepId =
@@ -1654,7 +1605,7 @@ private:
                      reorder<unsigned>(paddedRepShape, outOrd));
       auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
       Value ptr = gep(elemPtrTy, smemBase, offset);
-      auto vecTy = VectorType::get(vec, llvmElemTy);
+      auto vecTy = vec_ty(llvmElemTy, vec);
       ptr = bit_cast(LLVM::LLVMPointerType::get(vecTy, 3), ptr);
       if (stNotRd) {
         Value valVec = rewriter.create<LLVM::UndefOp>(loc, vecTy);
@@ -1665,9 +1616,9 @@ private:
               vecTy, valVec,
               vals[elemId + linearCTAId * accumSizePerThread + v], vVal);
         }
-        rewriter.create<LLVM::StoreOp>(loc, valVec, ptr);
+        store(valVec, ptr);
       } else {
-        Value valVec = rewriter.create<LLVM::LoadOp>(loc, ptr);
+        Value valVec = load(ptr);
         for (unsigned v = 0; v < vec; ++v) {
           Value vVal = createIndexAttrConstant(
               rewriter, loc, getTypeConverter()->getIndexType(), v);
@@ -1682,6 +1633,7 @@

 /// ====================== dot codegen begin ==========================

+// Data loader for mma.16816 instruction.
 class MMA16816SmemLoader {
 public:
   MMA16816SmemLoader(int wpt, ArrayRef<uint32_t> order, int kOrder,
@@ -1689,8 +1641,10 @@ public:
                      ArrayRef<int> matShape, int perPhase, int maxPhase,
                      int elemBytes, ConversionPatternRewriter &rewriter,
                      TypeConverter *typeConverter, const Location &loc)
-      : wpt(wpt), order(order), kOrder(kOrder), tileShape(tileShape),
-        instrShape(instrShape), matShape(matShape), perPhase(perPhase),
+      : wpt(wpt), order(order.begin(), order.end()), kOrder(kOrder),
+        tileShape(tileShape.begin(), tileShape.end()),
+        instrShape(instrShape.begin(), instrShape.end()),
+        matShape(matShape.begin(), matShape.end()), perPhase(perPhase),
         maxPhase(maxPhase), elemBytes(elemBytes), rewriter(rewriter),
         typeConverter(typeConverter), loc(loc), ctx(rewriter.getContext()) {
     cMatShape = matShape[order[0]];
@@ -1722,7 +1676,7 @@ public:
     loadStrideInMat[kOrder] =
         2; // instrShape[kOrder] / matShape[kOrder], always 2
     loadStrideInMat[kOrder ^ 1] =
-        wpt * (instrShape[order[1]] / matShape[order[1]]);
+        wpt * (instrShape[kOrder ^ 1] / matShape[kOrder ^ 1]);
     pLoadStrideInMat = loadStrideInMat[order[0]];
     sMatStride =
@@ -1753,8 +1707,6 @@ public:
   // Compute the offset to the matrix this thread(indexed by warpOff and lane)
   // mapped to.
   SmallVector<Value> computeLdmatrixMatOffs(Value warpId, Value lane) {
-    MLIRContext *ctx = warpId.getContext();
-
     // 4x4 matrices
     Value c = urem(lane, i32_val(8));
     Value s = udiv(lane, i32_val(8)); // sub-warp-id
@@ -1895,6 +1847,7 @@ public:
     int k = matIdx[kOrder];

     int ptrIdx{-1};
+
     if (canUseLdmatrix)
       ptrIdx = matIdx[order[0]] / (instrShape[order[0]] / matShape[order[0]]);
     else if (elemBytes == 4 && needTrans) // tf32 & trans
@@ -1904,7 +1857,9 @@ public:
     else
       llvm::report_fatal_error("unsupported mma type found");

-    // prefetch logic removed here.
+    // The main difference with the original triton code is we removed the
+    // prefetch-related logic here for the upstream optimizer phase should take
+    // care with it, and that is transparent in dot conversion.
     auto getPtr = [&](int idx) { return ptrs[idx]; };
     Value ptr = getPtr(ptrIdx);
@@ -1915,11 +1870,8 @@ public:
           matIdx[order[1]] * sMatStride * sMatShape * sTileStride * elemBytes;

       PTXBuilder builder;
-      auto resArgs = builder.newListOperand();
-
       // ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a thread.
-      for (int i = 0; i < 4; i++)
-        resArgs->listAppend(builder.newOperand("=r"));
+      auto resArgs = builder.newListOperand(4, "=r");
       auto addrArg = builder.newAddrOperand(ptr, "r", sOffset);

       auto ldmatrix = builder.create("ldmatrix.sync.aligned.m8n8.x4")
@@ -1927,46 +1879,127 @@ public:
                           .o("shared.b16");
       ldmatrix(resArgs, addrArg);

-      auto inlineAsm = rewriter.create<LLVM::InlineAsmOp>(
-          loc, ldmatrixRetTy, builder.getAllMLIRArgs(), // operands
-          builder.dump(),                               // asm_string
-          builder.getConstraints(),                     // constraints
-          true,                                         // has_side_effects
-          false,                                        // is_align_stack
-          LLVM::AsmDialectAttr::get(ctx,
-                                    LLVM::AsmDialect::AD_ATT), // asm_dialect
-          ArrayAttr::get(ctx, {})                              // operand_attrs
-      );
+      // The result type is 4xi32, each i32 is composed of 2xf16
+      // elements(adjacent two columns in a row)
+      Value resV4 = builder.launch(rewriter, loc, ldmatrixRetTy);

       auto getIntAttr = [&](int v) {
-        return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty(), v)});
+        return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
       };

-      Value resV4 = inlineAsm.getRes(); // 4xi32, each is composed of 2xf16
-                                        // elements(adjacent columns in a row)
-      Type fp16x2Ty = VectorType::get({2}, type::f16Ty(ctx));
-      return std::make_tuple(extract_val(fp16x2Ty, resV4, getIntAttr(0)),
-                             extract_val(fp16x2Ty, resV4, getIntAttr(1)),
-                             extract_val(fp16x2Ty, resV4, getIntAttr(2)),
-                             extract_val(fp16x2Ty, resV4, getIntAttr(3)));
+      Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2);
+
+      return {extract_val(fp16x2Ty, resV4, getIntAttr(0)),
+              extract_val(fp16x2Ty, resV4, getIntAttr(1)),
+              extract_val(fp16x2Ty, resV4, getIntAttr(2)),
+              extract_val(fp16x2Ty, resV4, getIntAttr(3))};
     } else if (elemBytes == 4 &&
                needTrans) { // Use lds.32 to load tf32 matrices
-      assert(false && "Not implemented yet");
-    } else if (elemBytes == 1 && needTrans) {
-      assert(false && "Not implemented yet");
-    }
-    return std::make_tuple(Value{}, Value{}, Value{}, Value{});
+      Value ptr2 = getPtr(ptrIdx + 1);
+      assert(sMatStride == 1);
+      int sOffsetElem =
+          matIdx[order[1]] * (sMatStride * sMatShape) * sTileStride;
+      int sOffsetArrElem = 1 * (sMatStride * sMatShape) * sTileStride;
+
+      Value elems[4];
+      Type elemTy = type::f32Ty(ctx);
+      if (kOrder == 1) {
+        elems[0] = load(gep(elemTy, ptr, i32_val(sOffsetElem)));
+        elems[1] = load(gep(elemTy, ptr2, i32_val(sOffsetElem)));
+        elems[2] =
+            load(gep(elemTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
+        elems[3] =
+            load(gep(elemTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
+      } else {
+        elems[0] = load(gep(elemTy, ptr, i32_val(sOffsetElem)));
+        elems[2] = load(gep(elemTy, ptr2, i32_val(sOffsetElem)));
+        elems[1] =
+            load(gep(elemTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
+        elems[3] =
+            load(gep(elemTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
+      }
+
+      return {elems[0], elems[1], elems[2], elems[3]};
+    } else if (elemBytes == 1 && needTrans) {
+      std::array<std::array<Value, 4>, 2> ptrs;
+      ptrs[0] = {
+          getPtr(ptrIdx),
+          getPtr(ptrIdx + 1),
+          getPtr(ptrIdx + 2),
+          getPtr(ptrIdx + 3),
+      };
+
+      ptrs[1] = {
+          getPtr(ptrIdx + 4),
+          getPtr(ptrIdx + 5),
+          getPtr(ptrIdx + 6),
+          getPtr(ptrIdx + 7),
+      };
+
+      assert(sMatStride == 1);
+      int sOffsetElem =
+          matIdx[order[1]] * (sMatStride * sMatShape) * sTileStride;
+      int sOffsetArrElem = 1 * (sMatStride * sMatShape) * sTileStride;
+
+      std::array<Value, 4> i8v4Elems;
+      std::array<Value, 4> i32Elems;
+      i8v4Elems.fill(
+          rewriter.create<LLVM::UndefOp>(loc, vec_ty(type::i8Ty(ctx), 4)));
+
+      Value i8Elems[4][4];
+      Type elemTy = type::i8Ty(ctx);
+      if (kOrder == 1) {
+        Value offset = i32_val(sOffsetElem);
+
+        for (int i = 0; i < 2; ++i)
+          for (int j = 0; j < 4; ++j)
+            i8Elems[i][j] = load(gep(elemTy, ptrs[i][j], offset));
+
+        offset = i32_val(sOffsetElem + sOffsetArrElem);
+        for (int i = 2; i < 4; ++i)
+          for (int j = 0; j < 4; ++j)
+            i8Elems[i][j] = load(gep(elemTy, ptrs[i - 2][j], offset));
+
+        for (int m = 0; m < 4; ++m) {
+          for (int e = 0; e < 4; ++e)
+            i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m],
+                                          i8Elems[m][e], i32_val(e));
+          i32Elems[m] = bit_cast(i32_ty, i8v4Elems[m]);
+        }
+      } else { // k first
+        Value offset = i32_val(sOffsetElem);
+        for (int j = 0; j < 4; ++j)
+          i8Elems[0][j] = load(gep(elemTy, ptrs[0][j], offset));
+        for (int j = 0; j < 4; ++j)
+          i8Elems[2][j] = load(gep(elemTy, ptrs[1][j], offset));
+        offset = i32_val(sOffsetElem + sOffsetArrElem);
+        for (int j = 0; j < 4; ++j)
+          i8Elems[1][j] = load(gep(elemTy, ptrs[0][j], offset));
+        for (int j = 0; j < 4; ++j)
+          i8Elems[3][j] = load(gep(elemTy, ptrs[1][j], offset));
+
+        for (int m = 0; m < 4; ++m) {
+          for (int e = 0; e < 4; ++e)
+            i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m],
+                                          i8Elems[m][e], i32_val(e));
+          i32Elems[m] = bit_cast(i32_ty, i8v4Elems[m]);
+        }
+      }
+
+      return {i32Elems[0], i32Elems[1], i32Elems[2], i32Elems[3]};
+    }
+
+    assert(false && "Invalid smem load");
+    return {Value{}, Value{}, Value{}, Value{}};
   }

 private:
   int wpt;
-  ArrayRef<uint32_t> order;
+  SmallVector<uint32_t> order;
   int kOrder;
-  ArrayRef<int64_t> tileShape;
-  ArrayRef<int> instrShape;
-  ArrayRef<int> matShape;
+  SmallVector<int64_t> tileShape;
+  SmallVector<int> instrShape;
+  SmallVector<int> matShape;
   int perPhase;
   int maxPhase;
   int elemBytes;
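Conceptually, the new elemBytes == 1 path packs four transposed byte loads into one 32-bit mma operand. A plain-C++ analogy of the pack step (illustrative only, not the MLIR code above):

  #include <cstdint>

  // Mirrors insert_element on a <4 x i8> vector followed by a bitcast to i32,
  // with lane e landing in bits [8*e, 8*e + 7] (little-endian lane order).
  uint32_t packI8x4(uint8_t e0, uint8_t e1, uint8_t e2, uint8_t e3) {
    return uint32_t(e0) | (uint32_t(e1) << 8) | (uint32_t(e2) << 16) |
           (uint32_t(e3) << 24);
  }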
@@ -2157,8 +2190,8 @@ struct DotOpConversionHelper {
   // The type of a matrix that loaded by either a ldmatrix or composed lds.
   Type getMatType() const {
     Type fp32Ty = type::f32Ty(ctx);
-    Type fp16x2Ty = VectorType::get({2}, type::f16Ty(ctx));
-    Type bf16x2Ty = VectorType::get({2}, type::bf16Ty(ctx));
+    Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2);
+    Type bf16x2Ty = vec_ty(type::bf16Ty(ctx), 2);
     // floating point types
     Type fp16x2Pack4Ty =
         LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, fp16x2Ty));
@@ -2167,7 +2200,7 @@ struct DotOpConversionHelper {
     Type fp32Pack4Ty =
         LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, fp32Ty));
     // integer types
-    Type i8x4Ty = VectorType::get({4}, type::i8Ty(ctx));
+    Type i8x4Ty = vec_ty(type::i8Ty(ctx), 4);
     Type i8x4Pack4Ty =
         LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, i8x4Ty));
     Type i32Pack4Ty = LLVM::LLVMStructType::getLiteral(
@@ -2189,6 +2222,23 @@ struct DotOpConversionHelper {
     return Type{};
   }

+  Type getLoadElemTy() {
+    switch (mmaType) {
+    case TensorCoreType::FP32_FP16_FP16_FP32:
+      return vec_ty(type::f16Ty(ctx), 2);
+    case TensorCoreType::FP32_BF16_BF16_FP32:
+      return vec_ty(type::bf16Ty(ctx), 2);
+    case TensorCoreType::FP32_TF32_TF32_FP32:
+      return type::f32Ty(ctx);
+    case TensorCoreType::INT32_INT8_INT8_INT32:
+      return type::i32Ty(ctx);
+    default:
+      llvm::report_fatal_error("Unsupported mma type found");
+    }
+
+    return Type{};
+  }
+
   Type getMmaRetType() const {
     Type fp32Ty = type::f32Ty(ctx);
     Type i32Ty = type::i32Ty(ctx);
@@ -2375,9 +2425,10 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
   const int numRepN = std::max<int>(dShape[1] / (wpt[1] * mmaInstrN), 1);
   const int numRepK = std::max<int>(NK / mmaInstrK, 1);

-  Value head = getThreadId(rewriter, loc);
-  Value lane = urem(head, i32_val(32));
-  Value warp = udiv(head, i32_val(32));
+  Value _32 = i32_val(32);
+  Value thread = getThreadId(rewriter, loc);
+  Value lane = urem(thread, _32);
+  Value warp = udiv(thread, _32);
   Value warpMN = udiv(warp, i32_val(wpt[0]));
   Value warpM = urem(warp, i32_val(wpt[0]));
   Value warpN = urem(warpMN, i32_val(wpt[1]));
@@ -2389,7 +2440,7 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
   std::map<std::pair<unsigned, unsigned>, Value> hb;

   // the original register_lds2, but discard the prefetch logic.
-  auto ld2 = [&](decltype(ha) &vals, int mn, int k, Value val) {
+  auto ld2 = [](decltype(ha) &vals, int mn, int k, Value val) {
     vals[{mn, k}] = val;
   };
@@ -2405,6 +2456,7 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
     const int perPhase = sharedLayout.getPerPhase();
     const int maxPhase = sharedLayout.getMaxPhase();
     const int elemBytes = tensorTy.getElementTypeBitWidth() / 8;
+    auto order = sharedLayout.getOrder();

     MMA16816SmemLoader loader(wpt, sharedLayout.getOrder(), kOrder,
                               tensorTy.getShape() /*tileShape*/, instrShape,
@@ -2417,34 +2469,56 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
     Type smemPtrTy = helper.getShemPtrTy();

     auto smemBase = getSharedMemoryBase(loc, rewriter, tensor);
-    for (int i = 0; i < numPtrs; i++) {
+    for (int i = 0; i < numPtrs; ++i) {
       ptrs[i] = bit_cast(
           smemPtrTy, gep(smemBase.getType(), smemBase, ValueRange({offs[i]})));
     }

+    bool needTrans = kOrder != order[0];
+
     // (a, b) is the coordinate.
-    auto load = [&, loader, ptrs, offs](int a, int b) {
+    auto load = [&, loader, ptrs, offs, needTrans](int a, int b) {
       auto [ha0, ha1, ha2, ha3] = loader.loadX4(
           (kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? b : a /*mat1*/, offs,
           ptrs, helper.getMatType(), helper.getShemPtrTy());
-      ld2(vals, a, b, ha0);
-      ld2(vals, a + 1, b, ha1);
-      ld2(vals, a, b + 1, ha2);
-      ld2(vals, a + 1, b + 1, ha3);
+      if (!needTrans) {
+        ld2(vals, a, b, ha0);
+        ld2(vals, a + 1, b, ha1);
+        ld2(vals, a, b + 1, ha2);
+        ld2(vals, a + 1, b + 1, ha3);
+      } else {
+        ld2(vals, a, b, ha0);
+        ld2(vals, a + 1, b, ha2);
+        ld2(vals, a, b + 1, ha1);
+        ld2(vals, a + 1, b + 1, ha3);
+      }
     };

     return load;
   };

-  std::function<void(int, int)> loadA = getLoadMatrixFn(
-      A, mmaLayout.getWarpsPerCTA()[0] /*wpt*/, 1 /*kOrder*/,
-      {mmaInstrM, mmaInstrK} /*instrShpae*/,
-      {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
+  std::function<void(int, int)> loadA;
   std::function<void(int, int)> loadB = getLoadMatrixFn(
       B, mmaLayout.getWarpsPerCTA()[1] /*wpt*/, 0 /*kOrder*/,
       {mmaInstrK, mmaInstrN} /*instrShpae*/,
       {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);

+  if (aTensorTy.getEncoding()
+          .dyn_cast<SharedEncodingAttr>()) { // load from smem
+    loadA = getLoadMatrixFn(A, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
+                            1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/,
+                            {matShapeM, matShapeK} /*matShape*/,
+                            warpM /*warpId*/, ha /*vals*/);
+  } else if (auto blockedLayout =
+                 aTensorTy.getEncoding()
+                     .dyn_cast<BlockedEncodingAttr>()) { // load from registers,
+                                                         // used in gemm fuse
+    // TODO(Superjomn) Port the logic.
+    assert(false && "Loading A from register is not supported yet.");
+  } else {
+    assert(false && "A's layout is not supported.");
+  }
const unsigned mStride = numRepN * 2; const unsigned mStride = numRepN * 2;
SmallVector<Value> fc(numRepM * mStride + numRepN * 2); SmallVector<Value> fc(numRepM * mStride + numRepN * 2);
auto callMma = [&](unsigned m, unsigned n, unsigned k) { auto callMma = [&](unsigned m, unsigned n, unsigned k) {
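The swap in the needTrans branch above is the subtle part of this hunk: loadX4 yields four 8x8 sub-matrices in shared-memory order, and when the fastest-varying dimension of the shared layout is not the k dimension, the two off-diagonal fragments come back transposed. A minimal standalone model of that placement (toy types; Frag and place are stand-ins, not the real MLIR values):

#include <cassert>
#include <map>
#include <utility>

using Frag = int; // stands in for an MLIR Value holding one 8x8 fragment

// Mirrors the if (!needTrans) / else placement in the load lambda above.
void place(std::map<std::pair<int, int>, Frag> &vals, int a, int b, Frag f0,
           Frag f1, Frag f2, Frag f3, bool needTrans) {
  vals[{a, b}] = f0;
  vals[{a + 1, b}] = needTrans ? f2 : f1; // the off-diagonal pair swaps
  vals[{a, b + 1}] = needTrans ? f1 : f2; // when transposed
  vals[{a + 1, b + 1}] = f3;
}

int main() {
  std::map<std::pair<int, int>, Frag> vals;
  place(vals, 0, 0, 10, 11, 12, 13, /*needTrans=*/true);
  assert((vals[{1, 0}] == 12 && vals[{0, 1}] == 11));
  return 0;
}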
@@ -2452,44 +2526,36 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
     auto &mma = *builder.create(helper.getMmaInstr().str());
-    auto retArgs = builder.newListOperand();
-    for (int i = 0; i < 4; ++i)
-      retArgs->listAppend(builder.newOperand("=r"));
-    auto aArg0 = builder.newOperand(ha[{m, k}], "r");
-    auto aArg1 = builder.newOperand(ha[{m + 1, k}], "r");
-    auto aArg2 = builder.newOperand(ha[{m, k + 1}], "r");
-    auto aArg3 = builder.newOperand(ha[{m + 1, k}], "r");
-    auto bArg0 = builder.newOperand(ha[{n, k}], "r");
-    auto bArg1 = builder.newOperand(ha[{n, k + 1}], "r");
+    auto retArgs = builder.newListOperand(4, "=r");
+    auto aArgs = builder.newListOperand({
+        {ha[{m, k}], "r"},
+        {ha[{m + 1, k}], "r"},
+        {ha[{m, k + 1}], "r"},
+        {ha[{m + 1, k + 1}], "r"},
+    });
+    auto bArgs =
+        builder.newListOperand({{hb[{n, k}], "r"}, {hb[{n, k + 1}], "r"}});
     // Currently, we only support a SplatLike C. For the other cases, e.g., C in
     // shared layout or blocked layout, we will support them by expanding
     // convert_layout.
     auto hc = helper.loadSplatLikeC(C, loc, rewriter);
     assert(hc.size() == 4UL && "Only splat-like C is supported now");
-    auto cArg0 = builder.newOperand(hc[0], "0"); // reuse the output registers
-    auto cArg1 = builder.newOperand(hc[1], "1");
-    auto cArg2 = builder.newOperand(hc[2], "2");
-    auto cArg3 = builder.newOperand(hc[3], "3");
-    mma({retArgs, aArg0, aArg1, aArg2, aArg3, bArg0, bArg1, cArg0, cArg1, cArg2,
-         cArg3});
-    auto inlineAsm = rewriter.create<LLVM::InlineAsmOp>(
-        loc, helper.getMmaRetType(), builder.getAllMLIRArgs(), // operands
-        builder.dump(),                                        // asm_string
-        builder.getConstraints(),                              // constraints
-        true,                                                  // has_side_effects
-        false,                                                 // is_align_stack
-        LLVM::AsmDialectAttr::get(ctx,
-                                  LLVM::AsmDialect::AD_ATT), // asm_dialect
-        ArrayAttr::get(ctx, {})                              // operand_attrs
-    );
-    auto mmaOut = inlineAsm.getRes();
+    auto cArgs = builder.newListOperand();
+    for (int i = 0; i < hc.size(); ++i) {
+      cArgs->listAppend(builder.newOperand(
+          hc[i], std::to_string(i))); // reuse the output registers
+    }
+    mma(retArgs, aArgs, bArgs, cArgs);
+    Value mmaOut = builder.launch(rewriter, loc, helper.getMmaRetType());
     auto getIntAttr = [&](int v) {
-      return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty(), v)});
+      return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
     };
     fc[(m + 0) * mStride + (n * 2 + 0)] =
@@ -2504,13 +2570,13 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
   // Main program
-  for (unsigned k = 0; k < numRepK; k++) {
-    for (unsigned m = 0; m < numRepM; m++)
+  for (unsigned k = 0; k < numRepK; ++k) {
+    for (unsigned m = 0; m < numRepM; ++m)
       loadA(2 * m, 2 * k);
     for (unsigned n = 0; n < numRepN; n += 2)
       loadB(n, 2 * k);
-    for (unsigned m = 0; m < numRepM; m++)
-      for (unsigned n = 0; n < numRepN; n++) {
+    for (unsigned m = 0; m < numRepM; ++m)
+      for (unsigned n = 0; n < numRepN; ++n) {
         callMma(2 * m, n, 2 * k);
       }
   }
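The operand-list refactor in the hunks above is the heart of this file's change: newListOperand plus launch replace a dozen lines of hand-rolled InlineAsmOp setup at every call site. A self-contained toy of the rendering contract (this is not the real PTXBuilder; the names and the <mma-instr> placeholder are illustrative):

#include <iostream>
#include <string>
#include <vector>

// Each element of a list operand claims the next $N placeholder and records
// one constraint, so constraints accumulate in declaration order.
struct ToyBuilder {
  std::vector<std::string> constraints;
  int next = 0;

  std::string listOperand(int count, const std::string &constraint) {
    std::string group = "{";
    for (int i = 0; i < count; ++i) {
      group += (i ? ",$" : "$") + std::to_string(next++);
      constraints.push_back(constraint);
    }
    return group + "}";
  }

  // Like the cArgs loop above: element i ties to output register i.
  std::string listTiedOperand(int count) {
    std::string group = "{";
    for (int i = 0; i < count; ++i) {
      group += (i ? ",$" : "$") + std::to_string(next++);
      constraints.push_back(std::to_string(i));
    }
    return group + "}";
  }
};

int main() {
  ToyBuilder b;
  std::string ret = b.listOperand(4, "=r"); // retArgs
  std::string a = b.listOperand(4, "r");    // aArgs
  std::string bv = b.listOperand(2, "r");   // bArgs
  std::string c = b.listTiedOperand(4);     // cArgs reuse the outputs
  std::cout << "<mma-instr> " << ret << ", " << a << ", " << bv << ", " << c
            << ";\n"; // {$0,$1,$2,$3}, {$4,$5,$6,$7}, {$8,$9}, {$10,...,$13}
  for (std::size_t i = 0; i < b.constraints.size(); ++i)
    std::cout << (i ? "," : "") << b.constraints[i];
  std::cout << "\n"; // =r,=r,=r,=r,r,r,r,r,r,r,0,1,2,3
  return 0;
}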

@@ -72,26 +72,24 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
   }
 }
-unsigned getShapePerCTA(const Attribute &layout, unsigned d) {
+SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
+  SmallVector<unsigned> shape;
   if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
-    return blockedLayout.getSizePerThread()[d] *
-           blockedLayout.getThreadsPerWarp()[d] *
-           blockedLayout.getWarpsPerCTA()[d];
+    for (int d = 0, n = blockedLayout.getOrder().size(); d < n; ++d)
+      shape.push_back(blockedLayout.getSizePerThread()[d] *
+                      blockedLayout.getThreadsPerWarp()[d] *
+                      blockedLayout.getWarpsPerCTA()[d]);
   } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
     assert(mmaLayout.getVersion() == 2 &&
            "mmaLayout version = 1 is not implemented yet");
-    assert(d < 2 && "Unexpected usage of getShapePerCTA");
-    if (d == 0) {
-      return 16 * mmaLayout.getWarpsPerCTA()[0];
-    } else {
-      // d == 1
-      return 8 * mmaLayout.getWarpsPerCTA()[1];
-    }
+    return {16 * mmaLayout.getWarpsPerCTA()[0],
+            8 * mmaLayout.getWarpsPerCTA()[1]};
   } else {
     assert(0 && "Unimplemented usage of getShapePerCTA");
-    return 0;
   }
-};
+  return shape;
+}
 SmallVector<unsigned> getOrder(const Attribute &layout) {
   if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
@@ -106,7 +104,7 @@ SmallVector<unsigned> getOrder(const Attribute &layout) {
     assert(0 && "Unimplemented usage of getOrder");
     return {};
   }
-};
+}
 } // namespace gpu
 } // namespace triton
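The getShapePerCTA change above (one SmallVector instead of a per-dimension query) is easy to sanity-check, since the per-CTA tile is just a per-dimension product. A standalone sketch with made-up layout values, plain std types standing in for the LLVM containers:

#include <cassert>
#include <cstddef>
#include <vector>

// shapePerCTA[d] = sizePerThread[d] * threadsPerWarp[d] * warpsPerCTA[d]
std::vector<unsigned> shapePerCTA(const std::vector<unsigned> &sizePerThread,
                                  const std::vector<unsigned> &threadsPerWarp,
                                  const std::vector<unsigned> &warpsPerCTA) {
  std::vector<unsigned> shape;
  for (std::size_t d = 0; d < sizePerThread.size(); ++d)
    shape.push_back(sizePerThread[d] * threadsPerWarp[d] * warpsPerCTA[d]);
  return shape;
}

int main() {
  // 2 elems/thread x 16 threads x 4 warps = 128 rows; 2 x 2 x 1 = 4 cols.
  auto shape = shapePerCTA({2, 2}, {16, 2}, {4, 1});
  assert(shape[0] == 128 && shape[1] == 4);
  // For the mma v2 branch the result is fixed per warp tile:
  // {16 * warpsPerCTA[0], 8 * warpsPerCTA[1]}.
  return 0;
}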
@@ -180,16 +178,17 @@ SliceEncodingAttr BlockedEncodingAttr::squeeze(int axis) {
 unsigned BlockedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
   size_t rank = shape.size();
-  assert(rank == getSizePerThread().size() &&
+  auto sizePerThread = getSizePerThread();
+  auto warpsPerCTA = getWarpsPerCTA();
+  auto threadsPerWarp = getThreadsPerWarp();
+  assert(rank == sizePerThread.size() &&
          "unexpected rank in BlockedEncodingAttr::getElemsPerThread");
-  SmallVector<unsigned> elemsPerThreadPerDim(rank);
+  SmallVector<unsigned> elemsPerThread(rank);
   for (size_t i = 0; i < rank; ++i) {
-    unsigned t =
-        getSizePerThread()[i] * getThreadsPerWarp()[i] * getWarpsPerCTA()[i];
-    elemsPerThreadPerDim[i] =
-        ceil<unsigned>(shape[i], t) * getSizePerThread()[i];
+    unsigned t = sizePerThread[i] * threadsPerWarp[i] * warpsPerCTA[i];
+    elemsPerThread[i] = ceil<unsigned>(shape[i], t) * sizePerThread[i];
   }
-  return product<unsigned>(elemsPerThreadPerDim);
+  return product<unsigned>(elemsPerThread);
 }
 unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
@@ -216,11 +215,9 @@ unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
 }
 unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
-  size_t rank = shape.size();
-  assert(rank == 2 && "Unexpected rank of mma layout");
-  unsigned elemsCol = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
-  unsigned elemsRow = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
-  return elemsCol * elemsRow;
+  int threads = product(getWarpsPerCTA());
+  int numElem = product(shape);
+  return numElem / threads;
 }
 unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
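For the blocked-layout getElemsPerThread rewrite above, the arithmetic is worth a worked example: each dimension is tiled by t = sizePerThread * threadsPerWarp * warpsPerCTA, every thread owns sizePerThread elements per tile, and the per-dimension counts multiply. A standalone sketch, with illustrative numbers rather than values from a real kernel:

#include <cassert>

unsigned ceilDiv(unsigned m, unsigned n) { return (m + n - 1) / n; }

// Same formula as the rewritten BlockedEncodingAttr::getElemsPerThread,
// specialized to rank 2 for brevity.
unsigned elemsPerThread2D(const unsigned shape[2],
                          const unsigned sizePerThread[2],
                          const unsigned threadsPerWarp[2],
                          const unsigned warpsPerCTA[2]) {
  unsigned total = 1;
  for (int i = 0; i < 2; ++i) {
    unsigned t = sizePerThread[i] * threadsPerWarp[i] * warpsPerCTA[i];
    total *= ceilDiv(shape[i], t) * sizePerThread[i];
  }
  return total;
}

int main() {
  const unsigned shape[2] = {256, 128};
  const unsigned sizePerThread[2] = {2, 2};
  const unsigned threadsPerWarp[2] = {8, 4}; // 32 threads per warp
  const unsigned warpsPerCTA[2] = {2, 2};
  // dim 0: t = 2*8*2 = 32, ceil(256/32) * 2 = 16 elements
  // dim 1: t = 2*4*2 = 16, ceil(128/16) * 2 = 16 elements -> 256 total
  assert(elemsPerThread2D(shape, sizePerThread, threadsPerWarp,
                          warpsPerCTA) == 256);
  return 0;
}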

@@ -1,5 +1,5 @@
 add_triton_ut(
-  NAME TritonAnalysisTests
+  NAME TestTritonAnalysis
   SRCS UtilityTest.cpp
   LIBS TritonAnalysis
 )

@@ -4,11 +4,26 @@
 //===----------------------------------------------------------------------===//
 #include "triton/Analysis/Utility.h"
+#include <gmock/gmock.h>
 #include <gtest/gtest.h>
 namespace mlir {
-TEST(UtilityTest, DummyTest) { EXPECT_EQ(true, true); }
+TEST(Analysis, reorder) {
+  SmallVector<int> shape({10, 20, 30});
+  {
+    SmallVector<unsigned> order({2, 1, 0});
+    auto reordered = reorder<int>(shape, order);
+    EXPECT_EQ(reordered[0], 30);
+    EXPECT_EQ(reordered[1], 20);
+    EXPECT_EQ(reordered[2], 10);
+  }
+  {
+    SmallVector<unsigned> order({1, 0, 2});
+    auto reordered = reorder<int>(shape, order);
+    EXPECT_EQ(reordered[0], 20);
+    EXPECT_EQ(reordered[1], 10);
+    EXPECT_EQ(reordered[2], 30);
+  }
+}
 } // namespace mlir

@@ -1,5 +1,5 @@
 add_triton_ut(
-  NAME PtxAsmFormatTest
+  NAME TestPtxAsmFormat
   SRCS PtxAsmFormatTest.cpp
   LIBS TritonGPUToLLVM
 )