@@ -494,6 +494,10 @@ public:
         rewriter.getIntegerAttr(rewriter.getIndexType(), value));
   }
 
+  // -----------------------------------------------------------------------
+  // Utilities
+  // -----------------------------------------------------------------------
+
   // Convert an \param index to a multi-dim coordinate given \param shape and
   // \param order.
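+  // e.g., assuming order[0] is the fastest-varying dimension (as in
+  // linearize below): with shape = [2, 4] and order = [1, 0], index 5
+  // delinearizes to the coordinate [1, 1].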
   SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
@@ -556,6 +560,10 @@ public:
     return ret;
   }
 
+  // -----------------------------------------------------------------------
+  // Blocked layout indices
+  // -----------------------------------------------------------------------
+
   // Get an index-base for each dimension for a \param blocked_layout.
   SmallVector<Value>
   emitBaseIndexForBlockedLayout(Location loc,
@@ -599,52 +607,6 @@ public:
     return multiDimBase;
   }
 
-  SmallVector<SmallVector<Value>> emitIndices(Location loc,
-                                              ConversionPatternRewriter &b,
-                                              const Attribute &layout,
-                                              ArrayRef<int64_t> shape) const {
-    if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
-      return emitIndicesForBlockedLayout(loc, b, blocked, shape);
-    } else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
-      return emitIndicesForSliceLayout(loc, b, slice, shape);
-    } else {
-      assert(0 && "emitIndices for layouts other than blocked & slice not "
-                  "implemented yet");
-      return {};
-    }
-  }
-
-  SmallVector<SmallVector<Value>>
-  emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter,
-                            const SliceEncodingAttr &sliceLayout,
-                            ArrayRef<int64_t> shape) const {
-    auto parent = sliceLayout.getParent();
-    unsigned dim = sliceLayout.getDim();
-    size_t rank = shape.size();
-    if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
-      auto paddedIndices = emitIndicesForBlockedLayout(
-          loc, rewriter, blockedParent, sliceLayout.paddedShape(shape));
-      unsigned numIndices = paddedIndices.size();
-      SmallVector<SmallVector<Value>> resultIndices(numIndices);
-      for (unsigned i = 0; i < numIndices; ++i)
-        for (unsigned d = 0; d < rank + 1; ++d)
-          if (d != dim)
-            resultIndices[i].push_back(paddedIndices[i][d]);
-
-      return resultIndices;
-
-    } else if (auto sliceParent = parent.dyn_cast<SliceEncodingAttr>()) {
-      assert(0 && "emitIndicesForSliceLayout with parent of sliceLayout"
-                  "is not implemented yet");
-      return {};
-
-    } else {
-      assert(0 && "emitIndicesForSliceLayout with parent other than blocked & "
-                  "slice not implemented yet");
-      return {};
-    }
-  }
-
   SmallVector<SmallVector<unsigned>>
   emitOffsetForBlockedLayout(const BlockedEncodingAttr &blockedLayout,
                              ArrayRef<int64_t> shape) const {
@@ -696,23 +658,109 @@ public:
     return reorderedOffset;
   }
 
+  // -----------------------------------------------------------------------
+  // Mma layout indices
+  // -----------------------------------------------------------------------
+
+  SmallVector<Value>
+  emitBaseIndexForMmaLayoutV1(Location loc, ConversionPatternRewriter &rewriter,
+                              const MmaEncodingAttr &mmaLayout,
+                              ArrayRef<int64_t> shape) const {
+    llvm_unreachable("emitBaseIndexForMmaLayoutV1 not implemented");
+  }
+
+  SmallVector<SmallVector<unsigned>>
+  emitOffsetForMmaLayoutV1(const MmaEncodingAttr &mmaLayout,
+                           ArrayRef<int64_t> shape) const {
+    llvm_unreachable("emitOffsetForMmaLayoutV1 not implemented");
+  }
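+
+  // mma v2 (Ampere mma.sync m16n8 fragments): each warp covers a 16x8 tile;
+  // within the warp, laneId / 4 selects the row and 2 * (laneId % 4) selects
+  // the first of a pair of adjacent columns. Warps are tiled 16 apart along
+  // dim 0 and 8 apart along dim 1.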
+  SmallVector<Value>
+  emitBaseIndexForMmaLayoutV2(Location loc, ConversionPatternRewriter &rewriter,
+                              const MmaEncodingAttr &mmaLayout,
+                              ArrayRef<int64_t> shape) const {
+    auto _warpsPerCTA = mmaLayout.getWarpsPerCTA();
+    assert(_warpsPerCTA.size() == 2);
+    SmallVector<Value> warpsPerCTA = {idx_val(_warpsPerCTA[0]),
+                                      idx_val(_warpsPerCTA[1])};
+    Value threadId = getThreadId(rewriter, loc);
+    Value warpSize = idx_val(32);
+    Value laneId = urem(threadId, warpSize);
+    Value warpId = udiv(threadId, warpSize);
+    Value warpId0 = urem(warpId, warpsPerCTA[0]);
+    Value warpId1 = urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]);
+    Value offWarp0 = mul(warpId0, idx_val(16));
+    Value offWarp1 = mul(warpId1, idx_val(8));
+
+    SmallVector<Value> multiDimBase(2);
+    multiDimBase[0] = add(udiv(laneId, idx_val(4)), offWarp0);
+    multiDimBase[1] = add(mul(idx_val(2), urem(laneId, idx_val(4))), offWarp1);
+    return multiDimBase;
+  }
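+
+  // Each thread owns four accumulator elements per 16x8 rep: rows r and
+  // r + 8, columns c and c + 1 relative to the base index above, hence the
+  // four offsets {i, j}, {i, j+1}, {i+8, j}, {i+8, j+1} emitted per rep below.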
+  SmallVector<SmallVector<unsigned>>
+  emitOffsetForMmaLayoutV2(const MmaEncodingAttr &mmaLayout,
+                           ArrayRef<int64_t> shape) const {
+    SmallVector<SmallVector<unsigned>> ret;
+    for (unsigned i = 0; i < shape[0]; i += getShapePerCTA(mmaLayout)[0]) {
+      for (unsigned j = 0; j < shape[1]; j += getShapePerCTA(mmaLayout)[1]) {
+        ret.push_back({i, j});
+        ret.push_back({i, j + 1});
+      }
+      for (unsigned j = 0; j < shape[1]; j += getShapePerCTA(mmaLayout)[1]) {
+        ret.push_back({i + 8, j});
+        ret.push_back({i + 8, j + 1});
+      }
+    }
+    return ret;
+  }
+
+  // -----------------------------------------------------------------------
+  // Get offsets / indices for any layout
+  // -----------------------------------------------------------------------
+
+  SmallVector<Value> emitBaseIndexForLayout(Location loc,
+                                            ConversionPatternRewriter &rewriter,
+                                            const Attribute &layout,
+                                            ArrayRef<int64_t> shape) const {
+    if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>())
+      return emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
+    if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+      if (mmaLayout.getVersion() == 1)
+        return emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, shape);
+      if (mmaLayout.getVersion() == 2)
+        return emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, shape);
+    }
+    llvm_unreachable("unsupported emitBaseIndexForLayout");
+  }
+
+  SmallVector<SmallVector<unsigned>>
+  emitOffsetForLayout(const Attribute &layout, ArrayRef<int64_t> shape) const {
+    if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>())
+      return emitOffsetForBlockedLayout(blockedLayout, shape);
+    if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+      if (mmaLayout.getVersion() == 1)
+        return emitOffsetForMmaLayoutV1(mmaLayout, shape);
+      if (mmaLayout.getVersion() == 2)
+        return emitOffsetForMmaLayoutV2(mmaLayout, shape);
+    }
+    llvm_unreachable("unsupported emitOffsetForLayout");
+  }
+
   // Emit indices calculation within each ConversionPattern, and returns a
   // [elemsPerThread X rank] index matrix.
-  // TODO: [goostavz] Double confirm the redundant indices calculations will
-  //       be eliminated in the consequent MLIR/LLVM optimization. We might
-  //       implement an indexCache if necessary.
-  SmallVector<SmallVector<Value>>
-  emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &rewriter,
-                              const BlockedEncodingAttr &blockedLayout,
-                              ArrayRef<int64_t> shape) const {
+
+  // TODO: [phil] redundant index computations do not appear to hurt
+  // performance much, but they could still significantly slow down
+  // computations.
+  SmallVector<SmallVector<Value>> emitIndicesForDistributedLayout(
+      Location loc, ConversionPatternRewriter &rewriter,
+      const Attribute &layout, ArrayRef<int64_t> shape) const {
     // step 1, delinearize threadId to get the base index
-    auto multiDimBase =
-        emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
+    auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, shape);
     // step 2, get offset of each element
-    SmallVector<SmallVector<unsigned>> offset =
-        emitOffsetForBlockedLayout(blockedLayout, shape);
+    auto offset = emitOffsetForLayout(layout, shape);
     // step 3, add offset to base, and reorder the sequence of indices to
     // guarantee that elems in the same sizePerThread are adjacent in order
     unsigned rank = shape.size();
@@ -726,6 +774,49 @@ public:
     return multiDimIdx;
   }
 
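+  // A slice layout drops one dimension (dim) from its parent layout: emit
+  // indices for the parent over the padded shape, then remove the sliced
+  // dimension from every index tuple.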
+  SmallVector<SmallVector<Value>>
+  emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter,
+                            const SliceEncodingAttr &sliceLayout,
+                            ArrayRef<int64_t> shape) const {
+    auto parent = sliceLayout.getParent();
+    unsigned dim = sliceLayout.getDim();
+    size_t rank = shape.size();
+    auto paddedIndices =
+        emitIndices(loc, rewriter, parent, sliceLayout.paddedShape(shape));
+    unsigned numIndices = paddedIndices.size();
+    SmallVector<SmallVector<Value>> resultIndices(numIndices);
+    for (unsigned i = 0; i < numIndices; ++i)
+      for (unsigned d = 0; d < rank + 1; ++d)
+        if (d != dim)
+          resultIndices[i].push_back(paddedIndices[i][d]);
+
+    return resultIndices;
+  }
+
+  // -----------------------------------------------------------------------
+  // Emit indices
+  // -----------------------------------------------------------------------
+
+  SmallVector<SmallVector<Value>> emitIndices(Location loc,
+                                              ConversionPatternRewriter &b,
+                                              const Attribute &layout,
+                                              ArrayRef<int64_t> shape) const {
+    if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
+      return emitIndicesForDistributedLayout(loc, b, blocked, shape);
+    } else if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
+      return emitIndicesForDistributedLayout(loc, b, mma, shape);
+    } else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
+      return emitIndicesForSliceLayout(loc, b, slice, shape);
+    } else {
+      assert(0 && "emitIndices for layouts other than blocked, mma & slice "
+                  "not implemented yet");
+      return {};
+    }
+  }
+
+  // -----------------------------------------------------------------------
+  // Shared memory utilities
+  // -----------------------------------------------------------------------
+
   template <typename T>
   Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
                             T value) const {
@@ -954,8 +1045,8 @@ struct LoadOpConversion
     // Determine the vectorization size
     Type valueTy = op.getResult().getType();
-    Type valueElemTy = typeConverter->convertType(
-        getElementTypeOrSelf(valueTy));
+    Type valueElemTy =
+        typeConverter->convertType(getElementTypeOrSelf(valueTy));
     unsigned vec = getVectorSize(ptr);
     unsigned numElems = getElemsPerThread(ptr.getType());
     if (llMask)
@@ -1086,7 +1177,8 @@ struct LoadOpConversion
                       : retTys[0];
 
     // TODO: if (has_l2_evict_policy)
-    //   auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
+    //   auto asmDialectAttr =
+    //       LLVM::AsmDialectAttr::get(rewriter.getContext(),
     //                                 LLVM::AsmDialect::AD_ATT);
     Value ret = ptxBuilder.launch(rewriter, loc, retTy);
@@ -1149,8 +1241,8 @@ struct StoreOpConversion
     MLIRContext *ctx = rewriter.getContext();
 
     auto valueTy = value.getType();
-    Type valueElemTy = typeConverter->convertType(
-        getElementTypeOrSelf(valueTy));
+    Type valueElemTy =
+        typeConverter->convertType(getElementTypeOrSelf(valueTy));
 
     unsigned vec = getVectorSize(ptr);
     unsigned numElems = getElemsPerThread(ptr.getType());
@@ -1571,19 +1663,19 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
   unsigned axis = adaptor.axis();
 
   auto srcTy = op.operand().getType().cast<RankedTensorType>();
-  auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
+  auto srcLayout = srcTy.getEncoding();
   auto srcShape = srcTy.getShape();
   auto srcRank = srcTy.getRank();
 
-  auto threadsPerWarp = srcLayout.getThreadsPerWarp();
-  auto warpsPerCTA = srcLayout.getWarpsPerCTA();
+  auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
+  auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcLayout);
 
   auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
   auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
   Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
   smemBase = bitcast(smemBase, elemPtrTy);
 
-  auto order = srcLayout.getOrder();
+  auto order = getOrder(srcLayout);
   unsigned sizeIntraWarps = threadsPerWarp[axis];
   unsigned sizeInterWarps = warpsPerCTA[axis];
@@ -1592,7 +1684,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
   auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);
 
   SmallVector<SmallVector<unsigned>> offset =
-      emitOffsetForBlockedLayout(srcLayout, srcShape);
+      emitOffsetForLayout(srcLayout, srcShape);
 
   std::map<SmallVector<unsigned>, Value> accs;
   std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
@@ -1655,7 +1747,8 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
   // each thread needs to process:
   //   elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads
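+  // e.g., with sizeInterWarps = 2, remaining dims of total size 128, and
+  // numThreads = 4 * 2 * 32 = 256: elemsPerThread = 2 * 128 / 256 = 1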
   unsigned elems = product<unsigned>(smemShape);
-  unsigned numThreads = product<unsigned>(srcLayout.getWarpsPerCTA()) * 32;
+  unsigned numThreads =
+      product<unsigned>(triton::gpu::getWarpsPerCTA(srcLayout)) * 32;
   unsigned elemsPerThread = std::max<unsigned>(elems / numThreads, 1);
   Value readOffset = threadId;
   for (unsigned round = 0; round < elemsPerThread; ++round) {
@@ -2104,9 +2197,10 @@ struct FpToFpOpConversion
   using ConvertTritonGPUOpToLLVMPattern<
       triton::FpToFpOp>::ConvertTritonGPUOpToLLVMPattern;
 
-  static SmallVector<Value> convertFp8x4ToFp16x4(
-    Location loc, ConversionPatternRewriter &rewriter,
-    const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
+  static SmallVector<Value>
+  convertFp8x4ToFp16x4(Location loc, ConversionPatternRewriter &rewriter,
+                       const Value &v0, const Value &v1, const Value &v2,
+                       const Value &v3) {
     auto ctx = rewriter.getContext();
     auto fp8x4VecTy = vec_ty(i8_ty, 4);
     Value fp8x4Vec = undef(fp8x4VecTy);
@@ -2117,8 +2211,7 @@ struct FpToFpOpConversion
     fp8x4Vec = bitcast(fp8x4Vec, i32_ty);
 
     PTXBuilder builder;
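+    // prmt.b32 with selectors 0x5040 / 0x7060 scatters the four fp8 bytes of
+    // $2 into the upper byte of each 16-bit lane of a0 / a1; the rest of the
+    // routine (outside this hunk) shifts sign/exponent/mantissa into fp16
+    // position.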
-    auto *ptxAsm =
-        "{                                      \n"
+    auto *ptxAsm = "{                                      \n"
                    ".reg .b32 a<2>, b<2>;                  \n"
                    "prmt.b32 a0, 0, $2, 0x5040;            \n"
                    "prmt.b32 a1, 0, $2, 0x7060;            \n"
@@ -2141,21 +2234,20 @@ struct FpToFpOpConversion
         struct_ty(SmallVector<Type>{fp16x2VecTy, fp16x2VecTy});
     auto fp16x2x2Struct =
         builder.launch(rewriter, loc, fp16x2x2StructTy, false);
-    auto fp16x2Vec0 = extract_val(fp16x2VecTy, fp16x2x2Struct,
-                                  rewriter.getI32ArrayAttr({0}));
-    auto fp16x2Vec1 = extract_val(fp16x2VecTy, fp16x2x2Struct,
-                                  rewriter.getI32ArrayAttr({1}));
-    return {
-      extract_element(f16_ty, fp16x2Vec0, i32_val(0)),
+    auto fp16x2Vec0 =
+        extract_val(fp16x2VecTy, fp16x2x2Struct, rewriter.getI32ArrayAttr({0}));
+    auto fp16x2Vec1 =
+        extract_val(fp16x2VecTy, fp16x2x2Struct, rewriter.getI32ArrayAttr({1}));
+    return {extract_element(f16_ty, fp16x2Vec0, i32_val(0)),
             extract_element(f16_ty, fp16x2Vec0, i32_val(1)),
             extract_element(f16_ty, fp16x2Vec1, i32_val(0)),
-            extract_element(f16_ty, fp16x2Vec1, i32_val(1))
-    };
+            extract_element(f16_ty, fp16x2Vec1, i32_val(1))};
   }
 
-  static SmallVector<Value> convertFp16x4ToFp8x4(
-    Location loc, ConversionPatternRewriter &rewriter,
-    const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
+  static SmallVector<Value>
+  convertFp16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
+                       const Value &v0, const Value &v1, const Value &v2,
+                       const Value &v3) {
     auto ctx = rewriter.getContext();
     auto fp16x2VecTy = vec_ty(f16_ty, 2);
     Value fp16x2Vec0 = undef(fp16x2VecTy);
@@ -2168,8 +2260,7 @@ struct FpToFpOpConversion
     fp16x2Vec1 = bitcast(fp16x2Vec1, i32_ty);
 
     PTXBuilder builder;
-    auto *ptxAsm =
-        "{                                      \n"
+    auto *ptxAsm = "{                                      \n"
                    ".reg .b32 a<2>, b<2>;                  \n"
                    "shl.b32 a0, $1, 1;                     \n"
                    "shl.b32 a1, $2, 1;                     \n"
@@ -2190,17 +2281,16 @@ struct FpToFpOpConversion
 
     auto fp8x4VecTy = vec_ty(i8_ty, 4);
     auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
-    return {
-      extract_element(i8_ty, fp8x4Vec, i32_val(0)),
+    return {extract_element(i8_ty, fp8x4Vec, i32_val(0)),
             extract_element(i8_ty, fp8x4Vec, i32_val(1)),
             extract_element(i8_ty, fp8x4Vec, i32_val(2)),
-            extract_element(i8_ty, fp8x4Vec, i32_val(3))
-    };
+            extract_element(i8_ty, fp8x4Vec, i32_val(3))};
   }
 
-  static SmallVector<Value> convertFp8x4ToBf16x4(
-    Location loc, ConversionPatternRewriter &rewriter,
-    const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
+  static SmallVector<Value>
+  convertFp8x4ToBf16x4(Location loc, ConversionPatternRewriter &rewriter,
+                       const Value &v0, const Value &v1, const Value &v2,
+                       const Value &v3) {
     auto ctx = rewriter.getContext();
     auto fp8x4VecTy = vec_ty(i8_ty, 4);
     Value fp8x4Vec = undef(fp8x4VecTy);
@@ -2211,8 +2301,7 @@ struct FpToFpOpConversion
     fp8x4Vec = bitcast(fp8x4Vec, i32_ty);
 
     PTXBuilder builder;
-    auto *ptxAsm =
-        "{                                         \n"
+    auto *ptxAsm = "{                                         \n"
                    ".reg .b32 a<2>, sign<2>, nosign<2>, b<2>; \n"
                    "prmt.b32 a0, 0, $2, 0x5040;               \n"
                    "prmt.b32 a1, 0, $2, 0x7060;               \n"
@@ -2239,21 +2328,20 @@ struct FpToFpOpConversion
         struct_ty(SmallVector<Type>{bf16x2VecTy, bf16x2VecTy});
     auto bf16x2x2Struct =
         builder.launch(rewriter, loc, bf16x2x2StructTy, false);
-    auto bf16x2Vec0 = extract_val(bf16x2VecTy, bf16x2x2Struct,
-                                  rewriter.getI32ArrayAttr({0}));
-    auto bf16x2Vec1 = extract_val(bf16x2VecTy, bf16x2x2Struct,
-                                  rewriter.getI32ArrayAttr({1}));
-    return {
-      extract_element(bf16_ty, bf16x2Vec0, i32_val(0)),
+    auto bf16x2Vec0 =
+        extract_val(bf16x2VecTy, bf16x2x2Struct, rewriter.getI32ArrayAttr({0}));
+    auto bf16x2Vec1 =
+        extract_val(bf16x2VecTy, bf16x2x2Struct, rewriter.getI32ArrayAttr({1}));
+    return {extract_element(bf16_ty, bf16x2Vec0, i32_val(0)),
             extract_element(bf16_ty, bf16x2Vec0, i32_val(1)),
             extract_element(bf16_ty, bf16x2Vec1, i32_val(0)),
-            extract_element(bf16_ty, bf16x2Vec1, i32_val(1))
-    };
+            extract_element(bf16_ty, bf16x2Vec1, i32_val(1))};
   }
 
-  static SmallVector<Value> convertBf16x4ToFp8x4(
-    Location loc, ConversionPatternRewriter &rewriter,
-    const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
+  static SmallVector<Value>
+  convertBf16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
+                       const Value &v0, const Value &v1, const Value &v2,
+                       const Value &v3) {
     auto ctx = rewriter.getContext();
     auto bf16x2VecTy = vec_ty(bf16_ty, 2);
     Value bf16x2Vec0 = undef(bf16x2VecTy);
@@ -2266,8 +2354,7 @@ struct FpToFpOpConversion
     bf16x2Vec1 = bitcast(bf16x2Vec1, i32_ty);
 
     PTXBuilder builder;
-    auto *ptxAsm =
-        "{                                           \n"
+    auto *ptxAsm = "{                                           \n"
                    ".reg .u32 sign, sign<2>, nosign, nosign<2>; \n"
                    ".reg .u32 fp8_min, fp8_max, rn_, zero;      \n"
                    "mov.u32 fp8_min, 0x38003800;                \n"
@@ -2312,29 +2399,27 @@ struct FpToFpOpConversion
 
     auto fp8x4VecTy = vec_ty(i8_ty, 4);
     auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
-    return {
-      extract_element(i8_ty, fp8x4Vec, i32_val(0)),
+    return {extract_element(i8_ty, fp8x4Vec, i32_val(0)),
            extract_element(i8_ty, fp8x4Vec, i32_val(1)),
            extract_element(i8_ty, fp8x4Vec, i32_val(2)),
-            extract_element(i8_ty, fp8x4Vec, i32_val(3))
-    };
+            extract_element(i8_ty, fp8x4Vec, i32_val(3))};
   }
 
-  static SmallVector<Value> convertFp8x4ToFp32x4(
-    Location loc, ConversionPatternRewriter &rewriter,
-    const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
+  static SmallVector<Value>
+  convertFp8x4ToFp32x4(Location loc, ConversionPatternRewriter &rewriter,
+                       const Value &v0, const Value &v1, const Value &v2,
+                       const Value &v3) {
     auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
-    return {
-      rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[0]),
+    return {rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[0]),
            rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[1]),
            rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[2]),
-            rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[3])
-    };
+            rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[3])};
   }
 
-  static SmallVector<Value> convertFp32x4ToFp8x4(
-    Location loc, ConversionPatternRewriter &rewriter,
-    const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
+  static SmallVector<Value>
+  convertFp32x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
+                       const Value &v0, const Value &v1, const Value &v2,
+                       const Value &v3) {
     auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
     auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
     auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
@@ -2342,21 +2427,21 @@ struct FpToFpOpConversion
     return convertFp16x4ToFp8x4(loc, rewriter, c0, c1, c2, c3);
   }
 
-  static SmallVector<Value> convertFp8x4ToFp64x4(
-    Location loc, ConversionPatternRewriter &rewriter,
-    const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
+  static SmallVector<Value>
+  convertFp8x4ToFp64x4(Location loc, ConversionPatternRewriter &rewriter,
+                       const Value &v0, const Value &v1, const Value &v2,
+                       const Value &v3) {
     auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
-    return {
-      rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[0]),
+    return {rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[0]),
            rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[1]),
            rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[2]),
-            rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[3])
-    };
+            rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[3])};
   }
 
-  static SmallVector<Value> convertFp64x4ToFp8x4(
-    Location loc, ConversionPatternRewriter &rewriter,
-    const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
+  static SmallVector<Value>
+  convertFp64x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
+                       const Value &v0, const Value &v1, const Value &v2,
+                       const Value &v3) {
     auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
     auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
     auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
@@ -2381,7 +2466,8 @@ struct FpToFpOpConversion
     // Select convertor
     std::function<SmallVector<Value>(Location, ConversionPatternRewriter &,
                                      const Value &, const Value &,
-                                     const Value&, const Value&)> convertor;
+                                     const Value &, const Value &)>
+        convertor;
     if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF16()) {
       convertor = convertFp8x4ToFp16x4;
     } else if (srcEltType.isF16() && dstEltType.isa<triton::Float8Type>()) {
@@ -2410,8 +2496,7 @@ struct FpToFpOpConversion
     auto elements = getElementsFromStruct(loc, adaptor.from(), rewriter);
     SmallVector<Value> resultVals;
     for (size_t i = 0; i < elems; i += 4) {
-      auto converted = convertor(loc, rewriter,
-                                 elements[i], elements[i + 1],
+      auto converted = convertor(loc, rewriter, elements[i], elements[i + 1],
                                  elements[i + 2], elements[i + 3]);
       resultVals.append(converted);
     }
@@ -2626,6 +2711,82 @@ public:
   }
 
 private:
+  SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
+                                       ConversionPatternRewriter &rewriter,
+                                       unsigned elemId, ArrayRef<int64_t> shape,
+                                       ArrayRef<unsigned> multiDimCTAInRepId,
+                                       ArrayRef<unsigned> shapePerCTA) const {
+    unsigned rank = shape.size();
+    if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
+      auto multiDimOffsetFirstElem =
+          emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
+      SmallVector<Value> multiDimOffset(rank);
+      SmallVector<unsigned> multiDimElemId =
+          getMultiDimIndex<unsigned>(elemId, blockedLayout.getSizePerThread());
+      for (unsigned d = 0; d < rank; ++d) {
+        multiDimOffset[d] = add(multiDimOffsetFirstElem[d],
+                                idx_val(multiDimCTAInRepId[d] * shapePerCTA[d] +
+                                        multiDimElemId[d]));
+      }
+      return multiDimOffset;
+    }
+    if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
+      unsigned dim = sliceLayout.getDim();
+      auto multiDimOffsetParent =
+          getMultiDimOffset(sliceLayout.getParent(), loc, rewriter, elemId,
+                            sliceLayout.paddedShape(shape),
+                            sliceLayout.paddedShape(multiDimCTAInRepId),
+                            sliceLayout.paddedShape(shapePerCTA));
+      SmallVector<Value> multiDimOffset(rank);
+      for (unsigned d = 0; d < rank + 1; ++d) {
+        if (d == dim)
+          continue;
+        unsigned slicedD = d < dim ? d : (d - 1);
+        multiDimOffset[slicedD] = multiDimOffsetParent[d];
+      }
+      return multiDimOffset;
+    }
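+    // Mirrors emitBaseIndexForMmaLayoutV2: within its warp tile a thread owns
+    // elements at dim-0 rows {laneId / 4, laneId / 4 + 8} and dim-1 columns
+    // {2 * (laneId % 4), 2 * (laneId % 4) + 1}; elemId picks among the four.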
+    if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+      SmallVector<Value> mmaColIdx(2);
+      SmallVector<Value> mmaRowIdx(2);
+      Value threadId = getThreadId(rewriter, loc);
+      Value warpSize = idx_val(32);
+      Value laneId = urem(threadId, warpSize);
+      Value warpId = udiv(threadId, warpSize);
+      // auto multiDimWarpId =
+      //     delinearize(rewriter, loc, warpId, mmaLayout.getWarpsPerCTA());
+      // TODO: double-check whether this is a documentation bug or a bug in
+      // DotConversion
+      SmallVector<Value> multiDimWarpId(2);
+      multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
+      multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
+      Value four = idx_val(4);
+      Value mmaGrpId = udiv(laneId, four);
+      Value mmaGrpIdP8 = add(mmaGrpId, idx_val(8));
+      Value mmaThreadIdInGrp = urem(laneId, four);
+      Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, idx_val(2));
+      Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, idx_val(1));
+      Value colWarpOffset = mul(multiDimWarpId[0], idx_val(16));
+      mmaColIdx[0] = add(mmaGrpId, colWarpOffset);
+      mmaColIdx[1] = add(mmaGrpIdP8, colWarpOffset);
+      Value rowWarpOffset = mul(multiDimWarpId[1], idx_val(8));
+      mmaRowIdx[0] = add(mmaThreadIdInGrpM2, rowWarpOffset);
+      mmaRowIdx[1] = add(mmaThreadIdInGrpM2P1, rowWarpOffset);
+
+      assert(rank == 2);
+      assert(mmaLayout.getVersion() == 2 &&
+             "mmaLayout ver1 not implemented yet");
+      SmallVector<Value> multiDimOffset(rank);
+      multiDimOffset[0] = elemId < 2 ? mmaColIdx[0] : mmaColIdx[1];
+      multiDimOffset[1] = elemId % 2 == 0 ? mmaRowIdx[0] : mmaRowIdx[1];
+      multiDimOffset[0] = add(multiDimOffset[0],
+                              idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
+      multiDimOffset[1] = add(multiDimOffset[1],
+                              idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
+      return multiDimOffset;
+    }
+    llvm_unreachable("unexpected layout in getMultiDimOffset");
+  }
+
   // shared memory rd/st for blocked or mma layout with data padding
   void processReplica(Location loc, ConversionPatternRewriter &rewriter,
                       bool stNotRd, RankedTensorType type,
@@ -2693,47 +2854,6 @@ void ConvertLayoutOpConversion::processReplica(
     elemTy = IntegerType::get(elemTy.getContext(), 8);
   auto llvmElemTy = getTypeConverter()->convertType(elemTy);
 
-  SmallVector<Value> multiDimOffsetFirstElem;
-  SmallVector<Value> mmaColIdx(2);
-  SmallVector<Value> mmaRowIdx(2);
-  if (blockedLayout) {
-    multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout(
-        loc, rewriter, blockedLayout, type.getShape());
-  } else if (sliceLayout) {
-    auto parent = sliceLayout.getParent();
-    if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
-      SmallVector<int64_t> paddedShape =
-          sliceLayout.paddedShape(type.getShape());
-      multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout(
-          loc, rewriter, blockedParent, paddedShape);
-    } else {
-      assert(0 && "SliceEncodingAttr with parent other than "
-                  "BlockedEncodingAttr not implemented");
-    }
-  } else if (mmaLayout) {
-    Value threadId = getThreadId(rewriter, loc);
-    Value warpSize = idx_val(32);
-    Value laneId = urem(threadId, warpSize);
-    Value warpId = udiv(threadId, warpSize);
-    // auto multiDimWarpId =
-    //     delinearize(rewriter, loc, warpId, mmaLayout.getWarpsPerCTA());
-    // TODO: double confirm if its document bug or DotConversion's Bug
-    SmallVector<Value> multiDimWarpId(2);
-    multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
-    multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
-    Value four = idx_val(4);
-    Value mmaGrpId = udiv(laneId, four);
-    Value mmaGrpIdP8 = add(mmaGrpId, idx_val(8));
-    Value mmaThreadIdInGrp = urem(laneId, four);
-    Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, idx_val(2));
-    Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, idx_val(1));
-    Value colWarpOffset = mul(multiDimWarpId[0], idx_val(16));
-    mmaColIdx[0] = add(mmaGrpId, colWarpOffset);
-    mmaColIdx[1] = add(mmaGrpIdP8, colWarpOffset);
-    Value rowWarpOffset = mul(multiDimWarpId[1], idx_val(8));
-    mmaRowIdx[0] = add(mmaThreadIdInGrpM2, rowWarpOffset);
-    mmaRowIdx[1] = add(mmaThreadIdInGrpM2P1, rowWarpOffset);
-  }
-
   for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
     auto multiDimCTAInRepId = getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep);
     SmallVector<unsigned> multiDimCTAId(rank);
@@ -2747,48 +2867,9 @@ void ConvertLayoutOpConversion::processReplica(
     // consider of caching the index calculation result in case
     // of performance issue observed.
     for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
-      SmallVector<Value> multiDimOffset(rank);
-      if (blockedLayout) {
-        SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
-            elemId, blockedLayout.getSizePerThread());
-        for (unsigned d = 0; d < rank; ++d) {
-          multiDimOffset[d] =
-              add(multiDimOffsetFirstElem[d],
-                  idx_val(multiDimCTAInRepId[d] * shapePerCTA[d] +
-                          multiDimElemId[d]));
-        }
-      } else if (sliceLayout) {
-        unsigned dim = sliceLayout.getDim();
-        auto parent = sliceLayout.getParent();
-        if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
-          SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
-              elemId, blockedParent.getSizePerThread());
-          for (unsigned d = 0; d < rank + 1; ++d) {
-            if (d == dim)
-              continue;
-            unsigned slicedD = d < dim ? d : (d - 1);
-            multiDimOffset[slicedD] =
-                add(multiDimOffsetFirstElem[d],
-                    idx_val(multiDimCTAInRepId[slicedD] * shapePerCTA[slicedD] +
-                            multiDimElemId[d]));
-          }
-        } else {
-          assert(0 && "SliceEncodingAttr with parent other than "
-                      "BlockedEncodingAttr not implemented");
-        }
-      } else if (mmaLayout) {
-        assert(rank == 2);
-        assert(mmaLayout.getVersion() == 2 &&
-               "mmaLayout ver1 not implemented yet");
-        multiDimOffset[0] = elemId < 2 ? mmaColIdx[0] : mmaColIdx[1];
-        multiDimOffset[1] = elemId % 2 == 0 ? mmaRowIdx[0] : mmaRowIdx[1];
-        multiDimOffset[0] = add(
-            multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
-        multiDimOffset[1] = add(
-            multiDimOffset[1], idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
-      } else {
-        assert(0 && "unexpected layout in processReplica");
-      }
+      SmallVector<Value> multiDimOffset =
+          getMultiDimOffset(layout, loc, rewriter, elemId, type.getShape(),
+                            multiDimCTAInRepId, shapePerCTA);
       Value offset =
           linearize(rewriter, loc, reorder<Value>(multiDimOffset, outOrd),
                     reorder<unsigned>(paddedRepShape, outOrd));
@@ -3637,10 +3718,10 @@ struct DotOpMmaV1ConversionHelper {
   // Compute the offset of the matrix to load.
   // Returns offsetAM, offsetAK, offsetBN, offsetBK.
   // NOTE, the information M(from $a) and N(from $b) couldn't be retrieved at
-  // the same time in the usage in convert_layout[shared->dot_op], we leave the
-  // noexist info to be 0 and only use the desired argument from the composed
-  // result. In this way we want to retain the original code structure in
-  // convert_mma884 method for easier debugging.
+  // the same time in the usage in convert_layout[shared->dot_op]; we leave
+  // the missing piece of information as 0 and only use the desired argument
+  // from the composed result. In this way we retain the original code
+  // structure of the convert_mma884 method for easier debugging.
   std::tuple<Value, Value, Value, Value>
   computeOffsets(Value threadId, bool isARow, bool isBRow, ArrayRef<int> fpw,
                  ArrayRef<int> spw, ArrayRef<int> rep,
@@ -5231,8 +5312,8 @@ struct InsertSliceAsyncOpConversion
       srcStrides.emplace_back(smemObj.strides[i]);
     }
   }
-  // Compute the offset based on the original dimensions of the shared memory
-  // object
+  // Compute the offset based on the original dimensions of the shared
+  // memory object
   auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides);
   auto dstPtrTy =
       ptr_ty(getTypeConverter()->convertType(resTy.getElementType()), 3);