@@ -50,10 +50,28 @@ Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) {
                                          IntegerAttr::get(i32ty, v));
 }
 
+// Create an index type constant.
+Value createIndexConstant(OpBuilder &builder, Location loc,
+                          TypeConverter *converter, int64_t value) {
+  Type ty = converter->convertType(builder.getIndexType());
+  return builder.create<LLVM::ConstantOp>(loc, ty,
+                                          builder.getIntegerAttr(ty, value));
+}
+
+// Create an integer constant of \param width bits.
+Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
+                                int64_t value) {
+  Type ty = builder.getIntegerType(width);
+  return builder.create<LLVM::ConstantOp>(loc, ty,
+                                          builder.getIntegerAttr(ty, value));
+}
+
 // Add other specification if needed...
 
 } // namespace
 
 // Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
 #define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
 #define urem(...) rewriter.create<LLVM::URemOp>(loc, __VA_ARGS__)
 #define add(...) rewriter.create<LLVM::AddOp>(loc, __VA_ARGS__)
@@ -68,10 +86,19 @@ Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) {
   rewriter.create<LLVM::InsertElementOp>(loc, __VA_ARGS__)
 #define extract_element(...)                                                   \
   rewriter.create<LLVM::ExtractElementOp>(loc, __VA_ARGS__)
+#define load(...) rewriter.create<LLVM::LoadOp>(loc, __VA_ARGS__)
+#define store(val, ptr) rewriter.create<LLVM::StoreOp>(loc, val, ptr)
 #define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
+#define i32_ty rewriter.getIntegerType(32)
+#define vec_ty(type, num) VectorType::get(num, type)
+
+// Creator for constant
 #define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__)
-#define i32_ty() rewriter.getIntegerType(32)
+#define int_val(width, val)                                                    \
+  LLVM::createLLVMIntegerConstant(rewriter, loc, width, val)
+#define idx_val(...)                                                           \
+  LLVM::createIndexConstant(rewriter, loc, this->getTypeConverter(),           \
+                            __VA_ARGS__)
 
 } // namespace LLVM
 } // namespace mlir
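The shortcut macros above assume a `loc` and a `rewriter` are already in scope at every expansion site. A hedged before/after sketch of what one expansion buys (the `linear` and `dimSize` names are illustrative stand-ins, not identifiers from this file):

// Without the shortcuts:
Value q = rewriter.create<LLVM::UDivOp>(loc, linear, dimSize);
Value r = rewriter.create<LLVM::URemOp>(loc, linear, dimSize);
// With them, the same two ops:
Value q2 = udiv(linear, dimSize);
Value r2 = urem(linear, dimSize);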
@@ -215,7 +242,7 @@ struct FuncOpConversion : public FuncOpConversionBase {
     // Set an attribute for maxntidx, it could be used in later LLVM codegen
     // for `nvvm.annotation` metadata.
     newFuncOp->setAttr(NVVMMetadataField::MaxNTid,
-                       rewriter.getIntegerAttr(i32_ty(), 32 * NumWarps));
+                       rewriter.getIntegerAttr(i32_ty, 32 * NumWarps));
 
     rewriter.eraseOp(funcOp);
     return success();
@@ -247,19 +274,6 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
   }
 };
 
-static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
-                                     Type resultType, int64_t value) {
-  return builder.create<LLVM::ConstantOp>(
-      loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
-}
-
-static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc,
-                                       LLVMTypeConverter *converter, Type ty,
-                                       int64_t value) {
-  return builder.create<LLVM::ConstantOp>(loc, converter->convertType(ty),
-                                          builder.getIntegerAttr(ty, value));
-}
-
 Value getStructFromElements(Location loc, ValueRange resultVals,
                             ConversionPatternRewriter &rewriter,
                             Type structType) {
@@ -272,42 +286,36 @@ Value getStructFromElements(Location loc, ValueRange resultVals,
 }
 
 template <typename T>
-static SmallVector<T> getMultiDimIndex(T linear_index, ArrayRef<T> shape) {
-  // sizes {a, b, c, d} -> acc_mul {b*c*d, c*d, d, 1}
+static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape) {
+  // shape: {a, b, c, d} -> accMul: {b*c*d, c*d, d, 1}
   size_t rank = shape.size();
-  T acc_mul = 1;
-  for (size_t i = 1; i < rank; ++i) {
-    acc_mul *= shape[i];
-  }
-  T linear_remain = linear_index;
-  SmallVector<T> multidim_index(rank);
+  T accMul = product(shape.drop_front());
+  T linearRemain = linearIndex;
+  SmallVector<T> multiDimIndex(rank);
   for (size_t i = 0; i < rank; ++i) {
-    multidim_index[i] = linear_remain / acc_mul;
-    linear_remain = linear_remain % acc_mul;
+    multiDimIndex[i] = linearRemain / accMul;
+    linearRemain = linearRemain % accMul;
     if (i != (rank - 1)) {
-      acc_mul = acc_mul / shape[i + 1];
+      accMul = accMul / shape[i + 1];
     }
   }
-  return multidim_index;
+  return multiDimIndex;
 }
 
 template <typename T>
-static T getLinearIndex(ArrayRef<T> multidim_index, ArrayRef<T> shape) {
-  assert(multidim_index.size() == shape.size());
-  // sizes {a, b, c, d} -> acc_mul {b*c*d, c*d, d, 1}
+static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
+  assert(multiDimIndex.size() == shape.size());
+  // shape: {a, b, c, d} -> accMul: {b*c*d, c*d, d, 1}
   size_t rank = shape.size();
-  T acc_mul = 1;
-  for (size_t i = 1; i < rank; ++i) {
-    acc_mul *= shape[i];
-  }
-  T linear_index = 0;
+  T accMul = product(shape.drop_front());
+  T linearIndex = 0;
   for (size_t i = 0; i < rank; ++i) {
-    linear_index += multidim_index[i] * acc_mul;
+    linearIndex += multiDimIndex[i] * accMul;
     if (i != (rank - 1)) {
-      acc_mul = acc_mul / shape[i + 1];
+      accMul = accMul / shape[i + 1];
     }
   }
-  return linear_index;
+  return linearIndex;
 }
 
 struct ConvertTritonGPUOpToLLVMPatternBase {
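For reference, a minimal standalone sketch of the row-major index math that getMultiDimIndex implements, with plain std::vector standing in for ArrayRef and the accMul product written out; the shape and index values are only illustrative:

#include <cassert>
#include <cstdio>
#include <vector>

// For shape {2, 3, 4}, accMul starts at 3 * 4 = 12, so linear index 23
// decomposes as 23 = 1 * 12 + 2 * 4 + 3 -> {1, 2, 3}.
static std::vector<int> getMultiDimIndexRef(int linearIndex,
                                            const std::vector<int> &shape) {
  int accMul = 1;
  for (size_t i = 1; i < shape.size(); ++i)
    accMul *= shape[i];
  int remain = linearIndex;
  std::vector<int> multiDim(shape.size());
  for (size_t i = 0; i < shape.size(); ++i) {
    multiDim[i] = remain / accMul;
    remain %= accMul;
    if (i + 1 != shape.size())
      accMul /= shape[i + 1];
  }
  return multiDim;
}

int main() {
  std::vector<int> idx = getMultiDimIndexRef(23, {2, 3, 4});
  assert(idx[0] == 1 && idx[1] == 2 && idx[2] == 3);
  std::printf("{%d, %d, %d}\n", idx[0], idx[1], idx[2]);
}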
@@ -352,23 +360,15 @@ public:
     return threadId;
   }
 
-  Value createIndexConst(ConversionPatternRewriter &rewriter, Location loc,
-                         int64_t value) const {
-    return rewriter.create<LLVM::ConstantOp>(
-        loc, this->getTypeConverter()->getIndexType(),
-        rewriter.getIntegerAttr(rewriter.getIndexType(), value));
-  }
-
+  // Convert an \param index to a multi-dim coordinate given \param shape and
+  // \param order.
   SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
                                  Location loc, Value linear,
                                  ArrayRef<unsigned> shape,
                                  ArrayRef<unsigned> order) const {
     unsigned rank = shape.size();
     assert(rank == order.size());
-    SmallVector<unsigned> reordered(rank);
-    for (unsigned i = 0; i < rank; ++i) {
-      reordered[i] = shape[order[i]];
-    }
+    auto reordered = reorder(shape, order);
     auto reorderedMultiDim = delinearize(rewriter, loc, linear, reordered);
     SmallVector<Value> multiDim(rank);
     for (unsigned i = 0; i < rank; ++i) {
@@ -388,9 +388,7 @@ public:
     } else {
       Value remained = linear;
       for (auto &&en : llvm::enumerate(llvm::reverse(shape.drop_front()))) {
-        Value dimSize = createIndexAttrConstant(
-            rewriter, loc, this->getTypeConverter()->getIndexType(),
-            en.value());
+        Value dimSize = idx_val(en.value());
         multiDim[rank - 1 - en.index()] = urem(remained, dimSize);
         remained = udiv(remained, dimSize);
       }
@@ -402,20 +400,19 @@ public:
   Value linearize(ConversionPatternRewriter &rewriter, Location loc,
                   ArrayRef<Value> multiDim, ArrayRef<unsigned> shape) const {
     int rank = multiDim.size();
-    Value linear = createIndexAttrConstant(
-        rewriter, loc, this->getTypeConverter()->getIndexType(), 0);
+    Value linear = idx_val(0);
     if (rank > 0) {
       linear = multiDim.front();
-      for (auto &&z : llvm::zip(multiDim.drop_front(), shape.drop_front())) {
-        Value dimSize = createIndexAttrConstant(
-            rewriter, loc, this->getTypeConverter()->getIndexType(),
-            std::get<1>(z));
-        linear = add(mul(linear, dimSize), std::get<0>(z));
+      for (auto [dim, shape] :
+           llvm::zip(multiDim.drop_front(), shape.drop_front())) {
+        Value dimSize = idx_val(shape);
+        linear = add(mul(linear, dimSize), dim);
       }
     }
     return linear;
   }
 
+  // Get an index-base for each dimension for a \param blocked_layout.
   SmallVector<Value>
   emitBaseIndexForBlockedLayout(Location loc,
                                 ConversionPatternRewriter &rewriter,
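linearize() emits the inverse computation as LLVM ops: Horner's scheme over the trailing dimensions. A standalone integer sketch, inverse of the getMultiDimIndex example earlier (values illustrative):

#include <cassert>
#include <vector>

// {1, 2, 3} in shape {2, 3, 4} -> (1 * 3 + 2) * 4 + 3 = 23.
static int linearizeRef(const std::vector<int> &multiDim,
                        const std::vector<int> &shape) {
  if (multiDim.empty())
    return 0;
  int linear = multiDim.front();
  for (size_t i = 1; i < multiDim.size(); ++i)
    linear = linear * shape[i] + multiDim[i];
  return linear;
}

int main() { assert(linearizeRef({1, 2, 3}, {2, 3, 4}) == 23); }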
@@ -423,7 +420,7 @@ public:
                                 ArrayRef<int64_t> shape) const {
     auto llvmIndexTy = this->getTypeConverter()->getIndexType();
     Value threadId = getThreadId(rewriter, loc);
-    Value warpSize = createIndexAttrConstant(rewriter, loc, llvmIndexTy, 32);
+    Value warpSize = idx_val(32);
     Value laneId = urem(threadId, warpSize);
     Value warpId = udiv(threadId, warpSize);
     auto sizePerThread = blocked_layout.getSizePerThread();
@@ -444,19 +441,13 @@ public:
       unsigned maxWarps =
           ceil<unsigned>(shape[k], sizePerThread[k] * threadsPerWarp[k]);
       unsigned maxThreads = ceil<unsigned>(shape[k], sizePerThread[k]);
-      multiDimWarpId[k] =
-          urem(multiDimWarpId[k],
-               createIndexAttrConstant(rewriter, loc, llvmIndexTy, maxWarps));
-      multiDimThreadId[k] =
-          urem(multiDimThreadId[k],
-               createIndexAttrConstant(rewriter, loc, llvmIndexTy, maxThreads));
+      multiDimWarpId[k] = urem(multiDimWarpId[k], idx_val(maxWarps));
+      multiDimThreadId[k] = urem(multiDimThreadId[k], idx_val(maxThreads));
       // multiDimBase[k] = (multiDimThreadId[k] +
       //                    multiDimWarpId[k] * threadsPerWarp[k]) *
       //                    sizePerThread[k];
-      Value threadsPerWarpK = createIndexAttrConstant(
-          rewriter, loc, llvmIndexTy, threadsPerWarp[k]);
-      Value sizePerThreadK =
-          createIndexAttrConstant(rewriter, loc, llvmIndexTy, sizePerThread[k]);
+      Value threadsPerWarpK = idx_val(threadsPerWarp[k]);
+      Value sizePerThreadK = idx_val(sizePerThread[k]);
       multiDimBase[k] =
           mul(sizePerThreadK, add(multiDimThreadId[k],
                                   mul(multiDimWarpId[k], threadsPerWarpK)));
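To make the commented formula concrete, a hedged numeric check (the constants are chosen for illustration, not taken from any particular layout):

#include <cassert>

// base = sizePerThread * (threadId + warpId * threadsPerWarp) along dim k.
int main() {
  int sizePerThread = 2, threadsPerWarp = 4;
  int multiDimThreadId = 3, multiDimWarpId = 1;
  int base =
      sizePerThread * (multiDimThreadId + multiDimWarpId * threadsPerWarp);
  assert(base == 14); // thread 3 of warp 1 starts at element 14
  return 0;
}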
@@ -496,25 +487,22 @@ public:
     if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
       SmallVector<int64_t> paddedShape(rank + 1);
       for (unsigned d = 0; d < rank + 1; ++d) {
-        if (d < dim) {
+        if (d < dim)
           paddedShape[d] = shape[d];
-        } else if (d == dim) {
+        else if (d == dim)
           paddedShape[d] = 1;
-        } else {
+        else
           paddedShape[d] = shape[d - 1];
-        }
       }
       auto paddedIndices = emitIndicesForBlockedLayout(
           loc, rewriter, blockedParent, paddedShape);
       unsigned numIndices = paddedIndices.size();
       SmallVector<SmallVector<Value>> resultIndices(numIndices);
-      for (unsigned i = 0; i < numIndices; ++i) {
-        for (unsigned d = 0; d < rank + 1; ++d) {
-          if (d != dim) {
+      for (unsigned i = 0; i < numIndices; ++i)
+        for (unsigned d = 0; d < rank + 1; ++d)
+          if (d != dim)
             resultIndices[i].push_back(paddedIndices[i][d]);
-          }
-        }
-      }
       return resultIndices;
 
     } else if (auto sliceParent = parent.dyn_cast<SliceEncodingAttr>()) {
@@ -529,7 +517,8 @@ public:
       }
     }
   }
 
-  // Emit indices calculation within each ConversionPattern
+  // Emit indices calculation within each ConversionPattern, and return a
+  // [elemsPerThread X rank] index matrix.
   // TODO: [goostavz] Double confirm the redundant indices calculations will
   // be eliminated in the consequent MLIR/LLVM optimization. We might
   // implement an indiceCache if necessary.
@@ -542,23 +531,16 @@ public:
     auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
     auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
     unsigned rank = shape.size();
-    SmallVector<unsigned> shapePerCTA(rank);
-    for (unsigned k = 0; k < rank; ++k) {
-      shapePerCTA[k] = sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k];
-    }
+    SmallVector<unsigned> shapePerCTA = getShapePerCTA(blockedLayout);
 
     // step 1, delinearize threadId to get the base index
     auto multiDimBase =
         emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
 
     // step 2, get offset of each element
-    unsigned elemsPerThread = 1;
+    unsigned elemsPerThread = blockedLayout.getElemsPerThread(shape);
     SmallVector<SmallVector<unsigned>> offset(rank);
-    SmallVector<unsigned> multiDimElemsPerThread(rank);
     for (unsigned k = 0; k < rank; ++k) {
-      multiDimElemsPerThread[k] =
-          ceil<unsigned>(shape[k], shapePerCTA[k]) * sizePerThread[k];
-      elemsPerThread *= multiDimElemsPerThread[k];
       // at least 1 block if shape[k] is less than shapePerCTA[k]
       for (unsigned blockOffset = 0;
            blockOffset < ceil<unsigned>(shape[k], shapePerCTA[k]);
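What getShapePerCTA and getElemsPerThread fold together, as plain integers; the layout numbers below are illustrative, not from any real kernel:

#include <cassert>

int main() {
  unsigned sizePerThread[2] = {2, 2}, threadsPerWarp[2] = {8, 4},
           warpsPerCTA[2] = {2, 2}, shape[2] = {64, 16};
  unsigned elems = 1;
  for (int k = 0; k < 2; ++k) {
    // One CTA covers sizePerThread * threadsPerWarp * warpsPerCTA elements.
    unsigned perCTA = sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k];
    // ceil(shape / perCTA) replicas, each owning sizePerThread elements.
    elems *= ((shape[k] + perCTA - 1) / perCTA) * sizePerThread[k];
  }
  assert(elems == 8); // a 64x16 tensor on a 32x16-per-CTA layout
  return 0;
}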
@@ -574,34 +556,29 @@ public:
                               threadsPerWarp[k] +
                           threadOffset * sizePerThread[k] + elemOffset);
     }
 
-    // step 3, add offset to base, and reorder the sequence of indices,
-    // to guarantee that elems in a same sizePerThread are adjacent in
-    // order
-    SmallVector<SmallVector<Value>> multiDimIdx(elemsPerThread,
-                                                SmallVector<Value>(rank));
-    unsigned accumSizePerThread =
-        std::accumulate(sizePerThread.begin(), sizePerThread.end(), 1,
-                        std::multiplies<unsigned>());
+    // step 3, add offset to base, and reorder the sequence of indices to
+    // guarantee that elems in the same sizePerThread are adjacent in order
+    SmallVector<SmallVector<Value>> multiDimIdx(elemsPerThread);
+    unsigned totalSizePerThread = product<unsigned>(sizePerThread);
     SmallVector<unsigned> threadsPerDim(rank);
-    for (unsigned k = 0; k < rank; ++k) {
+    for (unsigned k = 0; k < rank; ++k)
       threadsPerDim[k] = ceil<unsigned>(shape[k], sizePerThread[k]);
-    }
     for (unsigned n = 0; n < elemsPerThread; ++n) {
-      unsigned linearNanoTileId = n / accumSizePerThread;
-      unsigned linearElemsInNanoTileId = n % accumSizePerThread;
+      unsigned linearNanoTileId = n / totalSizePerThread;
+      unsigned linearNanoTileElemId = n % totalSizePerThread;
       SmallVector<unsigned> multiDimNanoTileId =
          getMultiDimIndex<unsigned>(linearNanoTileId, threadsPerDim);
-      SmallVector<unsigned> multiElemsInNanoTileId =
-          getMultiDimIndex<unsigned>(linearElemsInNanoTileId, sizePerThread);
+      SmallVector<unsigned> multiDimNanoTileElemId =
+          getMultiDimIndex<unsigned>(linearNanoTileElemId, sizePerThread);
+      multiDimIdx[n].resize(rank);
       for (unsigned k = 0; k < rank; ++k) {
        unsigned reorderedMultiDimId =
            multiDimNanoTileId[k] *
                (sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]) +
-            multiElemsInNanoTileId[k];
+            multiDimNanoTileElemId[k];
        multiDimIdx[n][k] =
-            add(multiDimBase[k],
-                createIndexAttrConstant(rewriter, loc, llvmIndexTy,
-                                        offset[k][reorderedMultiDimId]));
+            add(multiDimBase[k], idx_val(offset[k][reorderedMultiDimId]));
       }
     }
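The step-3 split of a linear element id into a nano-tile id and a slot inside the tile, in plain integers (values illustrative):

#include <cassert>

int main() {
  int n = 6, totalSizePerThread = 4;
  assert(n / totalSizePerThread == 1); // linearNanoTileId
  assert(n % totalSizePerThread == 2); // linearNanoTileElemId
  return 0;
}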
@@ -617,7 +594,7 @@ public:
     assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
     size_t offset = allocation->getOffset(bufferId);
     auto llvmIndexTy = this->getTypeConverter()->getIndexType();
-    Value offVal = createIndexAttrConstant(rewriter, loc, llvmIndexTy, offset);
+    Value offVal = idx_val(offset);
     Value base = gep(ptrTy, smem, offVal);
     return base;
   }
@@ -636,19 +613,17 @@ protected:
 Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
                          TypeConverter *typeConverter,
                          ConversionPatternRewriter &rewriter, Location loc) {
   auto tensorTy = resType.cast<RankedTensorType>();
   auto layout = tensorTy.getEncoding();
 
   auto srcType = typeConverter->convertType(elemType);
   auto llSrc = bit_cast(srcType, constVal);
-  size_t numElemsPerThread = getElemsPerThread(layout, tensorTy.getShape());
-  llvm::SmallVector<Value, 4> elems(numElemsPerThread, llSrc);
+  size_t elemsPerThread = getElemsPerThread(layout, tensorTy.getShape());
+  llvm::SmallVector<Value, 4> elems(elemsPerThread, llSrc);
   llvm::SmallVector<Type, 4> elemTypes(elems.size(), srcType);
   auto structTy =
       LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
 
-  auto llStruct = getStructFromElements(loc, elems, rewriter, structTy);
-  return llStruct;
+  return getStructFromElements(loc, elems, rewriter, structTy);
 }
 
 struct SplatOpConversion
@@ -745,7 +720,7 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
     assert(layout && "unexpected layout in getLayout");
     auto shape = ty.getShape();
     unsigned valueElems = layout.getElemsPerThread(shape);
-    return std::make_tuple(layout, valueElems);
+    return {layout, valueElems};
   }
 
   unsigned getAlignment(Value val, const BlockedEncodingAttr &layout) const {
@@ -864,14 +839,14 @@ struct StoreOpConversion
       llvm::SmallVector<std::string> asmArgs;
 
       Type valArgTy = IntegerType::get(ctx, width);
-      auto wordTy = VectorType::get(wordNElems, valueElemTy);
+      auto wordTy = vec_ty(valueElemTy, wordNElems);
 
       auto *asmArgList = ptxBuilder.newListOperand();
-      for (int wordIdx = 0; wordIdx < nWords; wordIdx++) {
+      for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) {
         // llWord is a width-len composition
         Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
         // Insert each value element into the composition
-        for (int elemIdx = 0; elemIdx < wordNElems; elemIdx++) {
+        for (int elemIdx = 0; elemIdx < wordNElems; ++elemIdx) {
           const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx;
           assert(elemOffset < valueElems.size());
           Value elem = valueElems[elemOffset];
@@ -894,10 +869,7 @@ struct StoreOpConversion
       // TODO(Superjomn) Need to check the masks before vectorizing the load:
      // can all the values share one predicate? Here we assume all the mask
      // values are the same.
-      Value maskVal =
-          llMask ? maskElems[vecStart]
-                 : createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
-                                             rewriter.getIntegerType(1), 1);
+      Value maskVal = llMask ? maskElems[vecStart] : int_val(1, 1);
 
       ptxStoreInstr.global().b(width).v(nWords);
 
       auto *asmAddr =
@@ -906,22 +878,12 @@ struct StoreOpConversion
       ptxStoreInstr(asmAddr, asmArgList).predicate(maskVal, "b");
 
       Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
       llvm::SmallVector<Type> argTys({boolTy, ptr.getType()});
-      for (int i = 0; i < nWords; i++)
+      for (int i = 0; i < nWords; ++i)
         argTys.push_back(valArgTy);
 
       auto ASMReturnTy = LLVM::LLVMVoidType::get(ctx);
 
-      auto inlineAsm = rewriter.create<LLVM::InlineAsmOp>(
-          loc, ASMReturnTy, ptxBuilder.getAllMLIRArgs(), // operands
-          ptxBuilder.dump(),                             // asm_string
-          ptxBuilder.getConstraints(),                   // constraints
-          // TODO(Superjomn) determine the side effect.
-          true,  // has_side_effects
-          false, // is_align_stack
-          LLVM::AsmDialectAttr::get(ctx,
-                                    LLVM::AsmDialect::AD_ATT), // asm_dialect
-          ArrayAttr::get(ctx, {})                              // operand_attrs
-      );
+      ptxBuilder.launch(rewriter, loc, ASMReturnTy);
     }
 
     rewriter.eraseOp(op);
     return success();
@@ -1183,10 +1145,7 @@ struct LoadOpConversion
       // TODO(Superjomn) Need to check the masks before vectorizing the load:
      // can all the values share one predicate? Here we assume all the mask
      // values are the same.
-      Value pred =
-          mask ? maskElems[vecStart]
-               : createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
-                                           rewriter.getIntegerType(1), 1);
+      Value pred = mask ? maskElems[vecStart] : int_val(1, 1);
 
       const std::string readConstraint =
           (width == 64) ? "l" : ((width == 32) ? "r" : "c");
@@ -1195,7 +1154,7 @@ struct LoadOpConversion
 
       // prepare asm operands
       auto *dstsOpr = ptxBuilder.newListOperand();
-      for (int wordIdx = 0; wordIdx < nWords; wordIdx++) {
+      for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) {
         auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations
         dstsOpr->listAppend(opr);
       }
@@ -1228,7 +1187,7 @@ struct LoadOpConversion
 
       SmallVector<Value> others;
       if (other) {
-        for (size_t ii = 0; ii < nWords; ii++) {
+        for (size_t ii = 0; ii < nWords; ++ii) {
           PTXInstr &mov = *ptxBuilder.create<>("mov");
           mov.o("u", width);
@@ -1236,7 +1195,7 @@ struct LoadOpConversion
           auto vecTy = LLVM::getFixedVectorType(valueElemTy, size);
           Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
-          for (size_t s = 0; s < size; s++) {
+          for (size_t s = 0; s < size; ++s) {
             Value falseVal = otherElems[vecStart + ii * size + s];
             Value sVal = createIndexAttrConstant(
                 rewriter, loc, this->getTypeConverter()->getIndexType(), s);
@@ -1267,20 +1226,13 @@ struct LoadOpConversion
       // TODO: if (has_l2_evict_policy)
       auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                       LLVM::AsmDialect::AD_ATT);
-      auto inlineAsmOp = rewriter.create<LLVM::InlineAsmOp>(
-          loc, retTy, /*operands=*/ptxBuilder.getAllMLIRArgs(),
-          /*asm_string=*/ptxBuilder.dump(),
-          /*constraints=*/ptxBuilder.getConstraints(),
-          /*has_side_effects=*/true,
-          /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
-          /*operand_attrs=*/ArrayAttr());
-      Value ret = inlineAsmOp.getResult(0);
+      Value ret = ptxBuilder.launch(rewriter, loc, retTy);
 
       // ---
       // extract and store return values
       // ---
       SmallVector<Value> rets;
-      for (unsigned int ii = 0; ii < nWords; ii++) {
+      for (unsigned int ii = 0; ii < nWords; ++ii) {
         Value curr;
         if (retTy.isa<LLVM::LLVMStructType>()) {
           curr = extract_val(IntegerType::get(getContext(), width), ret,
@@ -1293,8 +1245,8 @@ struct LoadOpConversion
                              curr);
         rets.push_back(curr);
       }
-      int tmp = (width / valueElemNbits);
-      for (size_t ii = 0; ii < vec; ii++) {
+      int tmp = width / valueElemNbits;
+      for (size_t ii = 0; ii < vec; ++ii) {
         Value vecIdx = createIndexAttrConstant(
             rewriter, loc, this->getTypeConverter()->getIndexType(), ii % tmp);
         Value loaded = extract_element(valueElemTy, rets[ii / tmp], vecIdx);
@@ -1480,6 +1432,7 @@ public:
         (!dstLayout.isa<BlockedEncodingAttr>() &&
          !dstLayout.isa<MmaEncodingAttr>())) {
       // TODO: to be implemented
+      llvm::errs() << "Unsupported ConvertLayout found";
       return failure();
     }
     auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
@@ -1494,11 +1447,11 @@ public:
     SmallVector<unsigned> outNumCTAsEachRep(rank);
     SmallVector<unsigned> inNumCTAs(rank);
     SmallVector<unsigned> outNumCTAs(rank);
+    auto srcShapePerCTA = getShapePerCTA(srcLayout);
+    auto dstShapePerCTA = getShapePerCTA(dstLayout);
     for (unsigned d = 0; d < rank; ++d) {
-      unsigned inPerCTA =
-          std::min(unsigned(shape[d]), getShapePerCTA(srcLayout, d));
-      unsigned outPerCTA =
-          std::min(unsigned(shape[d]), getShapePerCTA(dstLayout, d));
+      unsigned inPerCTA = std::min<unsigned>(shape[d], srcShapePerCTA[d]);
+      unsigned outPerCTA = std::min<unsigned>(shape[d], dstShapePerCTA[d]);
       unsigned maxPerCTA = std::max(inPerCTA, outPerCTA);
       numReplicates[d] = ceil<unsigned>(shape[d], maxPerCTA);
       inNumCTAsEachRep[d] = maxPerCTA / inPerCTA;
@@ -1579,9 +1532,8 @@ private:
     auto accumSizePerThread = product<unsigned>(sizePerThread);
     auto llvmIndexTy = getTypeConverter()->getIndexType();
     SmallVector<unsigned> numCTAs(rank);
-    SmallVector<unsigned> shapePerCTA(rank);
+    auto shapePerCTA = getShapePerCTA(layout);
     for (unsigned d = 0; d < rank; ++d) {
-      shapePerCTA[d] = getShapePerCTA(layout, d);
       numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
     }
     auto llvmElemTy = getTypeConverter()->convertType(type.getElementType());
@@ -1603,16 +1555,15 @@ private:
       Value warpSize = createIndexAttrConstant(
           rewriter, loc, this->getTypeConverter()->getIndexType(), 32);
       Value laneId = rewriter.create<LLVM::URemOp>(loc, threadId, warpSize);
-      Value fourVal = createIndexConst(rewriter, loc, 4);
+      Value fourVal = idx_val(4);
       mmaGrpId = rewriter.create<LLVM::UDivOp>(loc, laneId, fourVal);
-      mmaGrpIdP8 = rewriter.create<LLVM::AddOp>(
-          loc, mmaGrpId, createIndexConst(rewriter, loc, 8));
+      mmaGrpIdP8 = rewriter.create<LLVM::AddOp>(loc, mmaGrpId, idx_val(8));
       Value mmaThreadIdInGrp =
           rewriter.create<LLVM::URemOp>(loc, laneId, fourVal);
-      mmaThreadIdInGrpM2 = rewriter.create<LLVM::MulOp>(
-          loc, mmaThreadIdInGrp, createIndexConst(rewriter, loc, 2));
-      mmaThreadIdInGrpM2P1 = rewriter.create<LLVM::AddOp>(
-          loc, mmaThreadIdInGrpM2, createIndexConst(rewriter, loc, 1));
+      mmaThreadIdInGrpM2 =
+          rewriter.create<LLVM::MulOp>(loc, mmaThreadIdInGrp, idx_val(2));
+      mmaThreadIdInGrpM2P1 =
+          rewriter.create<LLVM::AddOp>(loc, mmaThreadIdInGrpM2, idx_val(1));
     }
     for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
       auto multiDimCTAInRepId =
@@ -1654,7 +1605,7 @@ private:
                                 reorder<unsigned>(paddedRepShape, outOrd));
       auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
       Value ptr = gep(elemPtrTy, smemBase, offset);
-      auto vecTy = VectorType::get(vec, llvmElemTy);
+      auto vecTy = vec_ty(llvmElemTy, vec);
       ptr = bit_cast(LLVM::LLVMPointerType::get(vecTy, 3), ptr);
       if (stNotRd) {
         Value valVec = rewriter.create<LLVM::UndefOp>(loc, vecTy);
@@ -1665,9 +1616,9 @@ private:
               vecTy, valVec,
               vals[elemId + linearCTAId * accumSizePerThread + v], vVal);
         }
-        rewriter.create<LLVM::StoreOp>(loc, valVec, ptr);
+        store(valVec, ptr);
       } else {
-        Value valVec = rewriter.create<LLVM::LoadOp>(loc, ptr);
+        Value valVec = load(ptr);
         for (unsigned v = 0; v < vec; ++v) {
           Value vVal = createIndexAttrConstant(
               rewriter, loc, getTypeConverter()->getIndexType(), v);
@@ -1682,6 +1633,7 @@ private:
 
 /// ====================== dot codegen begin ==========================
 
+// Data loader for mma.16816 instruction.
 class MMA16816SmemLoader {
 public:
   MMA16816SmemLoader(int wpt, ArrayRef<uint32_t> order, int kOrder,
@@ -1689,8 +1641,10 @@ public:
                      ArrayRef<int> matShape, int perPhase, int maxPhase,
                      int elemBytes, ConversionPatternRewriter &rewriter,
                      TypeConverter *typeConverter, const Location &loc)
-      : wpt(wpt), order(order), kOrder(kOrder), tileShape(tileShape),
-        instrShape(instrShape), matShape(matShape), perPhase(perPhase),
+      : wpt(wpt), order(order.begin(), order.end()), kOrder(kOrder),
+        tileShape(tileShape.begin(), tileShape.end()),
+        instrShape(instrShape.begin(), instrShape.end()),
+        matShape(matShape.begin(), matShape.end()), perPhase(perPhase),
         maxPhase(maxPhase), elemBytes(elemBytes), rewriter(rewriter),
         typeConverter(typeConverter), loc(loc), ctx(rewriter.getContext()) {
     cMatShape = matShape[order[0]];
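The member-type change in this constructor reads like a lifetime fix: an ArrayRef member is a non-owning view, and call sites pass temporaries such as the {mmaInstrM, mmaInstrK} initializer lists, which die at the end of the full expression. A minimal sketch of the hazard with a hypothetical View type (not the loader itself):

#include <cstddef>
#include <vector>

struct View { const int *data; size_t size; }; // mimics ArrayRef
struct Dangling { View shape; };               // keeps only a view
struct Owning { std::vector<int> shape; };     // copies, like SmallVector

Dangling make() {
  std::vector<int> tmp{16, 8};
  return Dangling{{tmp.data(), tmp.size()}};   // view dangles after return
}

int main() { Owning ok{{16, 8}}; return ok.shape[0] == 16 ? 0 : 1; }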
@@ -1722,7 +1676,7 @@ public:
     loadStrideInMat[kOrder] =
         2; // instrShape[kOrder] / matShape[kOrder], always 2
     loadStrideInMat[kOrder ^ 1] =
-        wpt * (instrShape[order[1]] / matShape[order[1]]);
+        wpt * (instrShape[kOrder ^ 1] / matShape[kOrder ^ 1]);
 
     pLoadStrideInMat = loadStrideInMat[order[0]];
     sMatStride =
@@ -1753,8 +1707,6 @@ public:
   // Compute the offset to the matrix this thread (indexed by warpOff and
   // lane) mapped to.
   SmallVector<Value> computeLdmatrixMatOffs(Value warpId, Value lane) {
-    MLIRContext *ctx = warpId.getContext();
-
     // 4x4 matrices
     Value c = urem(lane, i32_val(8));
     Value s = udiv(lane, i32_val(8)); // sub-warp-id
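The lane decomposition computeLdmatrixMatOffs starts from, as plain integers: ldmatrix views the 32-lane warp as four rows of eight lanes each:

#include <cassert>

int main() {
  int lane = 19;
  assert(lane % 8 == 3); // c: position within the 8-lane row
  assert(lane / 8 == 2); // s: sub-warp id, i.e. which 8-lane row
  return 0;
}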
@@ -1895,6 +1847,7 @@ public:
     int k = matIdx[kOrder];
+
     int ptrIdx{-1};
 
     if (canUseLdmatrix)
       ptrIdx = matIdx[order[0]] / (instrShape[order[0]] / matShape[order[0]]);
     else if (elemBytes == 4 && needTrans) // tf32 & trans
@@ -1904,7 +1857,9 @@ public:
     else
       llvm::report_fatal_error("unsupported mma type found");
 
-    // prefetch logic removed here.
+    // The main difference from the original Triton code is that the
+    // prefetch-related logic is removed here; the upstream optimizer phase
+    // should take care of it, transparently to the dot conversion.
     auto getPtr = [&](int idx) { return ptrs[idx]; };
     Value ptr = getPtr(ptrIdx);
@@ -1915,11 +1870,8 @@ public:
           matIdx[order[1]] * sMatStride * sMatShape * sTileStride * elemBytes;
 
       PTXBuilder builder;
-      auto resArgs = builder.newListOperand();
-
       // ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a thread.
-      for (int i = 0; i < 4; i++)
-        resArgs->listAppend(builder.newOperand("=r"));
+      auto resArgs = builder.newListOperand(4, "=r");
       auto addrArg = builder.newAddrOperand(ptr, "r", sOffset);
 
       auto ldmatrix = builder.create("ldmatrix.sync.aligned.m8n8.x4")
@@ -1927,46 +1879,127 @@ public:
                           .o("shared.b16");
       ldmatrix(resArgs, addrArg);
 
-      auto inlineAsm = rewriter.create<LLVM::InlineAsmOp>(
-          loc, ldmatrixRetTy, builder.getAllMLIRArgs(), // operands
-          builder.dump(),                               // asm_string
-          builder.getConstraints(),                     // constraints
-          true,                                         // has_side_effects
-          false,                                        // is_align_stack
-          LLVM::AsmDialectAttr::get(ctx,
-                                    LLVM::AsmDialect::AD_ATT), // asm_dialect
-          ArrayAttr::get(ctx, {})                              // operand_attrs
-      );
+      // The result type is 4xi32, each i32 is composed of 2xf16
+      // elements (adjacent two columns in a row)
+      Value resV4 = builder.launch(rewriter, loc, ldmatrixRetTy);
 
       auto getIntAttr = [&](int v) {
-        return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty(), v)});
+        return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
       };
 
-      Value resV4 = inlineAsm.getRes(); // 4xi32, each is composed of 2xf16
-      // elements(adjacent columns in a row)
-      Type fp16x2Ty = VectorType::get({2}, type::f16Ty(ctx));
-      return std::make_tuple(extract_val(fp16x2Ty, resV4, getIntAttr(0)),
-                             extract_val(fp16x2Ty, resV4, getIntAttr(1)),
-                             extract_val(fp16x2Ty, resV4, getIntAttr(2)),
-                             extract_val(fp16x2Ty, resV4, getIntAttr(3)));
+      Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2);
+      return {extract_val(fp16x2Ty, resV4, getIntAttr(0)),
+              extract_val(fp16x2Ty, resV4, getIntAttr(1)),
+              extract_val(fp16x2Ty, resV4, getIntAttr(2)),
+              extract_val(fp16x2Ty, resV4, getIntAttr(3))};
     } else if (elemBytes == 4 &&
                needTrans) { // Use lds.32 to load tf32 matrices
-      assert(false && "Not implemented yet");
-    } else if (elemBytes == 1 && needTrans) {
-      assert(false && "Not implemented yet");
-    }
-
-    return std::make_tuple(Value{}, Value{}, Value{}, Value{});
+      Value ptr2 = getPtr(ptrIdx + 1);
+      assert(sMatStride == 1);
+      int sOffsetElem =
+          matIdx[order[1]] * (sMatStride * sMatShape) * sTileStride;
+      int sOffsetArrElem = 1 * (sMatStride * sMatShape) * sTileStride;
+
+      Value elems[4];
+      Type elemTy = type::f32Ty(ctx);
+      if (kOrder == 1) {
+        elems[0] = load(gep(elemTy, ptr, i32_val(sOffsetElem)));
+        elems[1] = load(gep(elemTy, ptr2, i32_val(sOffsetElem)));
+        elems[2] =
+            load(gep(elemTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
+        elems[3] =
+            load(gep(elemTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
+      } else {
+        elems[0] = load(gep(elemTy, ptr, i32_val(sOffsetElem)));
+        elems[2] = load(gep(elemTy, ptr2, i32_val(sOffsetElem)));
+        elems[1] =
+            load(gep(elemTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
+        elems[3] =
+            load(gep(elemTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
+      }
+
+      return {elems[0], elems[1], elems[2], elems[3]};
+    } else if (elemBytes == 1 && needTrans) {
+      std::array<std::array<Value, 4>, 2> ptrs;
+      ptrs[0] = {
+          getPtr(ptrIdx),
+          getPtr(ptrIdx + 1),
+          getPtr(ptrIdx + 2),
+          getPtr(ptrIdx + 3),
+      };
+
+      ptrs[1] = {
+          getPtr(ptrIdx + 4),
+          getPtr(ptrIdx + 5),
+          getPtr(ptrIdx + 6),
+          getPtr(ptrIdx + 7),
+      };
+
+      assert(sMatStride == 1);
+      int sOffsetElem =
+          matIdx[order[1]] * (sMatStride * sMatShape) * sTileStride;
+      int sOffsetArrElem = 1 * (sMatStride * sMatShape) * sTileStride;
+
+      std::array<Value, 4> i8v4Elems;
+      std::array<Value, 4> i32Elems;
+      i8v4Elems.fill(
+          rewriter.create<LLVM::UndefOp>(loc, vec_ty(type::i8Ty(ctx), 4)));
+
+      Value i8Elems[4][4];
+      Type elemTy = type::i8Ty(ctx);
+      if (kOrder == 1) {
+        Value offset = i32_val(sOffsetElem);
+
+        for (int i = 0; i < 2; ++i)
+          for (int j = 0; j < 4; ++j)
+            i8Elems[i][j] = load(gep(elemTy, ptrs[i][j], offset));
+
+        offset = i32_val(sOffsetElem + sOffsetArrElem);
+        for (int i = 2; i < 4; ++i)
+          for (int j = 0; j < 4; ++j)
+            i8Elems[i][j] = load(gep(elemTy, ptrs[i - 2][j], offset));
+
+        for (int m = 0; m < 4; ++m) {
+          for (int e = 0; e < 4; ++e)
+            i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m],
+                                          i8Elems[m][e], i32_val(e));
+          i32Elems[m] = bit_cast(i32_ty, i8v4Elems[m]);
+        }
+      } else { // k first
+        Value offset = i32_val(sOffsetElem);
+        for (int j = 0; j < 4; ++j)
+          i8Elems[0][j] = load(gep(elemTy, ptrs[0][j], offset));
+        for (int j = 0; j < 4; ++j)
+          i8Elems[2][j] = load(gep(elemTy, ptrs[1][j], offset));
+        offset = i32_val(sOffsetElem + sOffsetArrElem);
+        for (int j = 0; j < 4; ++j)
+          i8Elems[1][j] = load(gep(elemTy, ptrs[0][j], offset));
+        for (int j = 0; j < 4; ++j)
+          i8Elems[3][j] = load(gep(elemTy, ptrs[1][j], offset));
+
+        for (int m = 0; m < 4; ++m) {
+          for (int e = 0; e < 4; ++e)
+            i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m],
+                                          i8Elems[m][e], i32_val(e));
+          i32Elems[m] = bit_cast(i32_ty, i8v4Elems[m]);
+        }
+      }
+
+      return {i32Elems[0], i32Elems[1], i32Elems[2], i32Elems[3]};
+    }
+
+    assert(false && "Invalid smem load");
+    return {Value{}, Value{}, Value{}, Value{}};
   }
 
 private:
   int wpt;
-  ArrayRef<uint32_t> order;
+  SmallVector<uint32_t> order;
   int kOrder;
-  ArrayRef<int64_t> tileShape;
-  ArrayRef<int> instrShape;
-  ArrayRef<int> matShape;
+  SmallVector<int64_t> tileShape;
+  SmallVector<int> instrShape;
+  SmallVector<int> matShape;
   int perPhase;
   int maxPhase;
   int elemBytes;
@@ -2157,8 +2190,8 @@ struct DotOpConversionHelper {
   // The type of a matrix that is loaded by either an ldmatrix or composed lds.
   Type getMatType() const {
     Type fp32Ty = type::f32Ty(ctx);
-    Type fp16x2Ty = VectorType::get({2}, type::f16Ty(ctx));
-    Type bf16x2Ty = VectorType::get({2}, type::bf16Ty(ctx));
+    Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2);
+    Type bf16x2Ty = vec_ty(type::bf16Ty(ctx), 2);
     // floating point types
     Type fp16x2Pack4Ty =
         LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, fp16x2Ty));
@@ -2167,7 +2200,7 @@ struct DotOpConversionHelper {
     Type fp32Pack4Ty =
         LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, fp32Ty));
     // integer types
-    Type i8x4Ty = VectorType::get({4}, type::i8Ty(ctx));
+    Type i8x4Ty = vec_ty(type::i8Ty(ctx), 4);
     Type i8x4Pack4Ty =
         LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, i8x4Ty));
     Type i32Pack4Ty = LLVM::LLVMStructType::getLiteral(
@@ -2189,6 +2222,23 @@ struct DotOpConversionHelper {
     return Type{};
   }
 
+  Type getLoadElemTy() {
+    switch (mmaType) {
+    case TensorCoreType::FP32_FP16_FP16_FP32:
+      return vec_ty(type::f16Ty(ctx), 2);
+    case TensorCoreType::FP32_BF16_BF16_FP32:
+      return vec_ty(type::bf16Ty(ctx), 2);
+    case TensorCoreType::FP32_TF32_TF32_FP32:
+      return type::f32Ty(ctx);
+    case TensorCoreType::INT32_INT8_INT8_INT32:
+      return type::i32Ty(ctx);
+    default:
+      llvm::report_fatal_error("Unsupported mma type found");
+    }
+
+    return Type{};
+  }
+
   Type getMmaRetType() const {
     Type fp32Ty = type::f32Ty(ctx);
     Type i32Ty = type::i32Ty(ctx);
@@ -2375,9 +2425,10 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
   const int numRepN = std::max<int>(dShape[1] / (wpt[1] * mmaInstrN), 1);
   const int numRepK = std::max<int>(NK / mmaInstrK, 1);
 
-  Value head = getThreadId(rewriter, loc);
-  Value lane = urem(head, i32_val(32));
-  Value warp = udiv(head, i32_val(32));
+  Value _32 = i32_val(32);
+  Value thread = getThreadId(rewriter, loc);
+  Value lane = urem(thread, _32);
+  Value warp = udiv(thread, _32);
   Value warpMN = udiv(warp, i32_val(wpt[0]));
   Value warpM = urem(warp, i32_val(wpt[0]));
   Value warpN = urem(warpMN, i32_val(wpt[1]));
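A hedged check of this warp-id decomposition (the wpt values are illustrative): with wpt = {4, 2}, linear warp id 6 maps to warpM = 2 and warpN = 1:

#include <cassert>

int main() {
  int warp = 6, wptM = 4, wptN = 2;
  int warpMN = warp / wptM;   // 1
  assert(warp % wptM == 2);   // warpM
  assert(warpMN % wptN == 1); // warpN
  return 0;
}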
@@ -2389,7 +2440,7 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
   std::map<std::pair<unsigned, unsigned>, Value> hb;

   // The original register_lds2, but with the prefetch logic discarded.
-  auto ld2 = [&](decltype(ha) &vals, int mn, int k, Value val) {
+  auto ld2 = [](decltype(ha) &vals, int mn, int k, Value val) {
     vals[{mn, k}] = val;
   };

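The capture list can be empty because the lambda touches nothing outside its parameters: vals arrives as the explicit first argument, so one stateless helper serves both ha and hb. E.g. ld2(ha, 0, 0, frag) files a packed 32-bit fragment of A under the key {0, 0} (frag standing in for a value produced by the loader below).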
@@ -2405,6 +2456,7 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
     const int perPhase = sharedLayout.getPerPhase();
     const int maxPhase = sharedLayout.getMaxPhase();
     const int elemBytes = tensorTy.getElementTypeBitWidth() / 8;
+    auto order = sharedLayout.getOrder();

     MMA16816SmemLoader loader(wpt, sharedLayout.getOrder(), kOrder,
                               tensorTy.getShape() /*tileShape*/, instrShape,
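getOrder() lists dimensions from fastest- to slowest-varying, so order[0] is the axis that sits contiguously in shared memory; a row-major 2-D tile has order = {1, 0}. The hoisted order value feeds the needTrans decision in the next hunk.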
@@ -2417,34 +2469,56 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,

     Type smemPtrTy = helper.getShemPtrTy();
     auto smemBase = getSharedMemoryBase(loc, rewriter, tensor);
-    for (int i = 0; i < numPtrs; i++) {
+    for (int i = 0; i < numPtrs; ++i) {
       ptrs[i] = bit_cast(
           smemPtrTy, gep(smemBase.getType(), smemBase, ValueRange({offs[i]})));
     }

+    bool needTrans = kOrder != order[0];
+
     // (a, b) is the coordinate.
-    auto load = [&, loader, ptrs, offs](int a, int b) {
+    auto load = [&, loader, ptrs, offs, needTrans](int a, int b) {
       auto [ha0, ha1, ha2, ha3] = loader.loadX4(
           (kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? b : a /*mat1*/, offs,
           ptrs, helper.getMatType(), helper.getShemPtrTy());
+      if (!needTrans) {
       ld2(vals, a, b, ha0);
       ld2(vals, a + 1, b, ha1);
       ld2(vals, a, b + 1, ha2);
       ld2(vals, a + 1, b + 1, ha3);
+      } else {
+        ld2(vals, a, b, ha0);
+        ld2(vals, a + 1, b, ha2);
+        ld2(vals, a, b + 1, ha1);
+        ld2(vals, a + 1, b + 1, ha3);
+      }
     };

     return load;
   };

-  std::function<void(int, int)> loadA = getLoadMatrixFn(
-      A, mmaLayout.getWarpsPerCTA()[0] /*wpt*/, 1 /*kOrder*/,
-      {mmaInstrM, mmaInstrK} /*instrShape*/,
-      {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
+  std::function<void(int, int)> loadA;
   std::function<void(int, int)> loadB = getLoadMatrixFn(
       B, mmaLayout.getWarpsPerCTA()[1] /*wpt*/, 0 /*kOrder*/,
       {mmaInstrK, mmaInstrN} /*instrShape*/,
       {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);

+  if (aTensorTy.getEncoding()
+          .dyn_cast<SharedEncodingAttr>()) { // load from smem
+    loadA = getLoadMatrixFn(A, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
+                            1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShape*/,
+                            {matShapeM, matShapeK} /*matShape*/,
+                            warpM /*warpId*/, ha /*vals*/);
+  } else if (auto blockedLayout =
+                 aTensorTy.getEncoding()
+                     .dyn_cast<BlockedEncodingAttr>()) { // load from registers,
+                                                         // used in gemm fusion
+    // TODO(Superjomn) Port the logic.
+    assert(false && "Loading A from register is not supported yet.");
+  } else {
+    assert(false && "A's layout is not supported.");
+  }
+
   const unsigned mStride = numRepN * 2;
   SmallVector<Value> fc(numRepM * mStride + numRepN * 2);
   auto callMma = [&](unsigned m, unsigned n, unsigned k) {
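When kOrder != order[0], the operand's k axis is not the contiguous axis in shared memory, so the loader has to use the transposed flavor of ldmatrix; the four destination registers of an x4 load then describe the 2x2 grid of 8x8 matrices with its off-diagonal entries exchanged, which the swapped ld2 calls undo. As an illustrative register-to-coordinate map (assuming the x4 tiling used here):

    // ldmatrix.m8n8.x4       : ha0 -> (a, b),   ha1 -> (a+1, b),
    //                          ha2 -> (a, b+1), ha3 -> (a+1, b+1)
    // ldmatrix.m8n8.x4.trans : ha1 and ha2 trade places, so the else
    //                          branch stores ha2 at (a+1, b) and ha1
    //                          at (a, b+1).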
@@ -2452,44 +2526,36 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,

     auto &mma = *builder.create(helper.getMmaInstr().str());

-    auto retArgs = builder.newListOperand();
-    for (int i = 0; i < 4; ++i)
-      retArgs->listAppend(builder.newOperand("=r"));
-    auto aArg0 = builder.newOperand(ha[{m, k}], "r");
-    auto aArg1 = builder.newOperand(ha[{m + 1, k}], "r");
-    auto aArg2 = builder.newOperand(ha[{m, k + 1}], "r");
-    auto aArg3 = builder.newOperand(ha[{m + 1, k}], "r");
-
-    auto bArg0 = builder.newOperand(ha[{n, k}], "r");
-    auto bArg1 = builder.newOperand(ha[{n, k + 1}], "r");
+    auto retArgs = builder.newListOperand(4, "=r");
+    auto aArgs = builder.newListOperand({
+        {ha[{m, k}], "r"},
+        {ha[{m + 1, k}], "r"},
+        {ha[{m, k + 1}], "r"},
+        {ha[{m + 1, k + 1}], "r"},
+    });
+    auto bArgs =
+        builder.newListOperand({{hb[{n, k}], "r"}, {hb[{n, k + 1}], "r"}});

     // Currently, we only support a SplatLike C. For the other cases, e.g., C in
     // shared layout or blocked layout, we will support them by expanding
     // convert_layout.
     auto hc = helper.loadSplatLikeC(C, loc, rewriter);
     assert(hc.size() == 4UL && "Only splat-like C is supported now");
-    auto cArg0 = builder.newOperand(hc[0], "0"); // reuse the output registers
-    auto cArg1 = builder.newOperand(hc[1], "1");
-    auto cArg2 = builder.newOperand(hc[2], "2");
-    auto cArg3 = builder.newOperand(hc[3], "3");
-
-    mma({retArgs, aArg0, aArg1, aArg2, aArg3, bArg0, bArg1, cArg0, cArg1, cArg2,
-         cArg3});
-    auto inlineAsm = rewriter.create<LLVM::InlineAsmOp>(
-        loc, helper.getMmaRetType(), builder.getAllMLIRArgs(), // operands
-        builder.dump(),               // asm_string
-        builder.getConstraints(),     // constraints
-        true,                         // has_side_effects
-        false,                        // is_align_stack
-        LLVM::AsmDialectAttr::get(ctx,
-                                  LLVM::AsmDialect::AD_ATT), // asm_dialect
-        ArrayAttr::get(ctx, {})       // operand_attrs
-    );
-    auto mmaOut = inlineAsm.getRes();
+    auto cArgs = builder.newListOperand();
+    for (int i = 0; i < hc.size(); ++i) {
+      cArgs->listAppend(builder.newOperand(
+          hc[i], std::to_string(i))); // reuse the output registers
+    }
+
+    mma(retArgs, aArgs, bArgs, cArgs);
+    Value mmaOut = builder.launch(rewriter, loc, helper.getMmaRetType());

     auto getIntAttr = [&](int v) {
-      return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty(), v)});
+      return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
     };

     fc[(m + 0) * mStride + (n * 2 + 0)] =
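For reference, a sketch of the instruction the PTXBuilder should assemble on the FP16 path; the exact text comes from helper.getMmaInstr() and the operand lists above, so this rendering is illustrative rather than authoritative. The constraints pair four "=r" outputs with four A registers, two B registers, and four C operands tied back to the outputs via "0".."3":

    mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
        {$0, $1, $2, $3},     // d: four f32 accumulators (outputs)
        {$4, $5, $6, $7},     // a: four 32-bit regs, each holding <2 x f16>
        {$8, $9},             // b: two 32-bit regs
        {$10, $11, $12, $13}; // c: reuses the output registers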
@@ -2504,13 +2570,13 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,

   // Main program
-  for (unsigned k = 0; k < numRepK; k++) {
-    for (unsigned m = 0; m < numRepM; m++)
+  for (unsigned k = 0; k < numRepK; ++k) {
+    for (unsigned m = 0; m < numRepM; ++m)
       loadA(2 * m, 2 * k);
     for (unsigned n = 0; n < numRepN; n += 2)
       loadB(n, 2 * k);
-    for (unsigned m = 0; m < numRepM; m++)
-      for (unsigned n = 0; n < numRepN; n++) {
+    for (unsigned m = 0; m < numRepM; ++m)
+      for (unsigned n = 0; n < numRepN; ++n) {
         callMma(2 * m, n, 2 * k);
       }
   }
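Note that the loadB loop steps n by 2: a single loadX4 call fills two consecutive n sub-tiles, while loadA's factor of two is folded into its (2 * m, 2 * k) arguments instead. As a worked count, numRepM = 2, numRepN = 4, numRepK = 2 issues two loadA and two loadB calls per k step and 2 * 4 * 2 = 16 callMma invocations overall.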