[BACKEND] Added support for mma layouts in reductions (#863)
Validated hackily by manually modifying the reduction .ttgir in my local cache. There will be a follow-up PR adding some better testing infrastructure to test out conversions and reductions on arbitrary layouts.
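The heart of the change: instead of hard-casting the reduction operand's encoding to BlockedEncodingAttr, layout properties are now queried through free functions that dispatch on the encoding attribute, so mma-encoded operands take the same code path. A condensed sketch of that pattern (the real implementations are in the dialect changes at the bottom of this diff):

    // Sketch only; see the actual getWarpsPerCTA in the diff below.
    SmallVector<unsigned> getWarpsPerCTA(Attribute layout) {
      if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>())
        return SmallVector<unsigned>(blocked.getWarpsPerCTA().begin(),
                                     blocked.getWarpsPerCTA().end());
      if (auto mma = layout.dyn_cast<MmaEncodingAttr>())
        return SmallVector<unsigned>(mma.getWarpsPerCTA().begin(),
                                     mma.getWarpsPerCTA().end());
      assert(0 && "getWarpsPerCTA not implemented");
      return {};
    }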
@@ -87,22 +87,22 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
 SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
   auto srcTy = op.operand().getType().cast<RankedTensorType>();
-  auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
+  auto srcLayout = srcTy.getEncoding();
   auto srcShape = srcTy.getShape();
   auto axis = op.axis();

-  bool fastReduce = axis == srcLayout.getOrder()[0];
+  bool fastReduce = axis == getOrder(srcLayout)[0];

   SmallVector<unsigned> smemShape;
   for (auto d : srcShape)
     smemShape.push_back(d);

   if (fastReduce) {
-    unsigned sizeInterWarps = srcLayout.getWarpsPerCTA()[axis];
+    unsigned sizeInterWarps = gpu::getWarpsPerCTA(srcLayout)[axis];
     smemShape[axis] = sizeInterWarps;
   } else {
-    unsigned threadsPerCTAAxis =
-        srcLayout.getThreadsPerWarp()[axis] * srcLayout.getWarpsPerCTA()[axis];
+    unsigned threadsPerCTAAxis = gpu::getThreadsPerWarp(srcLayout)[axis] *
+                                 gpu::getWarpsPerCTA(srcLayout)[axis];
     smemShape[axis] = threadsPerCTAAxis;
   }
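A worked example of the scratch-shape computation above, with assumed values (not taken from the patch): for srcShape = {128, 64}, axis = 0, and warpsPerCTA = {4, 2}, the fast path keeps one partial result per warp along the reduced axis, so the scratch shape shrinks from {128, 64} to {4, 64}:

    // Standalone illustration with assumed shapes; compiles on its own.
    #include <cstdio>
    int main() {
      unsigned smemShape[2] = {128, 64}; // srcShape, assumed
      unsigned warpsPerCTA[2] = {4, 2};  // assumed layout parameter
      unsigned axis = 0;
      smemShape[axis] = warpsPerCTA[axis]; // sizeInterWarps, as in the hunk above
      printf("{%u, %u}\n", smemShape[0], smemShape[1]); // prints {4, 64}
    }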
@@ -161,16 +161,11 @@ private:
       // TODO(Keren): Reduce with index is not supported yet.
       auto value = op->getOperand(0);
       if (auto tensorType = value.getType().dyn_cast<RankedTensorType>()) {
-        if (tensorType.getEncoding().isa<BlockedEncodingAttr>()) {
-          auto smemShape = getScratchConfigForReduce(reduceOp);
-          unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(),
-                                           1, std::multiplies{});
-          auto bytes = elems * tensorType.getElementTypeBitWidth() / 8;
-          allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
-        } else {
-          assert(0 && "ReduceOp with input layout other than blocked layout is "
-                      "not implemented yet");
-        }
+        auto smemShape = getScratchConfigForReduce(reduceOp);
+        unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
+                                         std::multiplies{});
+        auto bytes = elems * tensorType.getElementTypeBitWidth() / 8;
+        allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
       }
     } else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
       auto srcTy = cvtLayout.src().getType().cast<RankedTensorType>();
@@ -339,7 +339,7 @@ Value getStructFromElements(Location loc, ValueRange resultVals,
   }

   Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
-  for (const auto& v : llvm::enumerate(resultVals)) {
+  for (const auto &v : llvm::enumerate(resultVals)) {
     assert(v.value() && "can not insert null values");
     llvmStruct = insert_val(structType, llvmStruct, v.value(),
                             rewriter.getI64ArrayAttr(v.index()));
@@ -494,6 +494,10 @@ public:
         rewriter.getIntegerAttr(rewriter.getIndexType(), value));
   }

+  // -----------------------------------------------------------------------
+  // Utilities
+  // -----------------------------------------------------------------------
+
   // Convert an \param index to a multi-dim coordinate given \param shape and
   // \param order.
   SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
@@ -556,6 +560,10 @@ public:
     return ret;
   }

+  // -----------------------------------------------------------------------
+  // Blocked layout indices
+  // -----------------------------------------------------------------------
+
   // Get an index-base for each dimension for a \param blocked_layout.
   SmallVector<Value>
   emitBaseIndexForBlockedLayout(Location loc,
@@ -599,52 +607,6 @@ public:
     return multiDimBase;
   }

-  SmallVector<SmallVector<Value>> emitIndices(Location loc,
-                                              ConversionPatternRewriter &b,
-                                              const Attribute &layout,
-                                              ArrayRef<int64_t> shape) const {
-    if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
-      return emitIndicesForBlockedLayout(loc, b, blocked, shape);
-    } else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
-      return emitIndicesForSliceLayout(loc, b, slice, shape);
-    } else {
-      assert(0 && "emitIndices for layouts other than blocked & slice not "
-                  "implemented yet");
-      return {};
-    }
-  }
-
-  SmallVector<SmallVector<Value>>
-  emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter,
-                            const SliceEncodingAttr &sliceLayout,
-                            ArrayRef<int64_t> shape) const {
-    auto parent = sliceLayout.getParent();
-    unsigned dim = sliceLayout.getDim();
-    size_t rank = shape.size();
-    if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
-      auto paddedIndices = emitIndicesForBlockedLayout(
-          loc, rewriter, blockedParent, sliceLayout.paddedShape(shape));
-      unsigned numIndices = paddedIndices.size();
-      SmallVector<SmallVector<Value>> resultIndices(numIndices);
-      for (unsigned i = 0; i < numIndices; ++i)
-        for (unsigned d = 0; d < rank + 1; ++d)
-          if (d != dim)
-            resultIndices[i].push_back(paddedIndices[i][d]);
-
-      return resultIndices;
-
-    } else if (auto sliceParent = parent.dyn_cast<SliceEncodingAttr>()) {
-      assert(0 && "emitIndicesForSliceLayout with parent of sliceLayout"
-                  "is not implemented yet");
-      return {};
-
-    } else {
-      assert(0 && "emitIndicesForSliceLayout with parent other than blocked & "
-                  "slice not implemented yet");
-      return {};
-    }
-  }
-
   SmallVector<SmallVector<unsigned>>
   emitOffsetForBlockedLayout(const BlockedEncodingAttr &blockedLayout,
                              ArrayRef<int64_t> shape) const {
@@ -696,23 +658,109 @@ public:
     return reorderedOffset;
   }

+  // -----------------------------------------------------------------------
+  // Mma layout indices
+  // -----------------------------------------------------------------------
+
+  SmallVector<Value>
+  emitBaseIndexForMmaLayoutV1(Location loc, ConversionPatternRewriter &rewriter,
+                              const MmaEncodingAttr &mmaLayout,
+                              ArrayRef<int64_t> shape) const {
+    llvm_unreachable("emitIndicesForMmaLayoutV1 not implemented");
+  }
+
+  SmallVector<SmallVector<unsigned>>
+  emitOffsetForMmaLayoutV1(const MmaEncodingAttr &mmaLayout,
+                           ArrayRef<int64_t> shape) const {
+    llvm_unreachable("emitOffsetForMmaLayoutV1 not implemented");
+  }
+
+  SmallVector<Value>
+  emitBaseIndexForMmaLayoutV2(Location loc, ConversionPatternRewriter &rewriter,
+                              const MmaEncodingAttr &mmaLayout,
+                              ArrayRef<int64_t> shape) const {
+    auto _warpsPerCTA = mmaLayout.getWarpsPerCTA();
+    assert(_warpsPerCTA.size() == 2);
+    SmallVector<Value> warpsPerCTA = {idx_val(_warpsPerCTA[0]),
+                                      idx_val(_warpsPerCTA[1])};
+    Value threadId = getThreadId(rewriter, loc);
+    Value warpSize = idx_val(32);
+    Value laneId = urem(threadId, warpSize);
+    Value warpId = udiv(threadId, warpSize);
+    Value warpId0 = urem(warpId, warpsPerCTA[0]);
+    Value warpId1 = urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]);
+    Value offWarp0 = mul(warpId0, idx_val(16));
+    Value offWarp1 = mul(warpId1, idx_val(8));
+
+    SmallVector<Value> multiDimBase(2);
+    multiDimBase[0] = add(udiv(laneId, idx_val(4)), offWarp0);
+    multiDimBase[1] = add(mul(idx_val(2), urem(laneId, idx_val(4))), offWarp1);
+    return multiDimBase;
+  }
+
+  SmallVector<SmallVector<unsigned>>
+  emitOffsetForMmaLayoutV2(const MmaEncodingAttr &mmaLayout,
+                           ArrayRef<int64_t> shape) const {
+    SmallVector<SmallVector<unsigned>> ret;
+    for (unsigned i = 0; i < shape[0]; i += getShapePerCTA(mmaLayout)[0]) {
+      for (unsigned j = 0; j < shape[1]; j += getShapePerCTA(mmaLayout)[1]) {
+        ret.push_back({i, j});
+        ret.push_back({i, j + 1});
+      }
+      for (unsigned j = 0; j < shape[1]; j += getShapePerCTA(mmaLayout)[1]) {
+        ret.push_back({i + 8, j});
+        ret.push_back({i + 8, j + 1});
+      }
+    }
+    return ret;
+  }
+
+  // -----------------------------------------------------------------------
+  // Get offsets / indices for any layout
+  // -----------------------------------------------------------------------
+
+  SmallVector<Value> emitBaseIndexForLayout(Location loc,
+                                            ConversionPatternRewriter &rewriter,
+                                            const Attribute &layout,
+                                            ArrayRef<int64_t> shape) const {
+    if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>())
+      return emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
+    if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+      if (mmaLayout.getVersion() == 1)
+        return emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, shape);
+      if (mmaLayout.getVersion() == 2)
+        return emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, shape);
+    }
+    llvm_unreachable("unsupported emitBaseIndexForLayout");
+  }
+
+  SmallVector<SmallVector<unsigned>>
+  emitOffsetForLayout(const Attribute &layout, ArrayRef<int64_t> shape) const {
+    if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>())
+      return emitOffsetForBlockedLayout(blockedLayout, shape);
+    if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+      if (mmaLayout.getVersion() == 1)
+        return emitOffsetForMmaLayoutV1(mmaLayout, shape);
+      if (mmaLayout.getVersion() == 2)
+        return emitOffsetForMmaLayoutV2(mmaLayout, shape);
+    }
+    llvm_unreachable("unsupported emitOffsetForLayout");
+  }
+
   // Emit indices calculation within each ConversionPattern, and returns a
   // [elemsPerThread X rank] index matrix.
-  // TODO: [goostavz] Double confirm the redundant indices calculations will
-  //       be eliminated in the consequent MLIR/LLVM optimization. We might
-  //       implement an indexCache if necessary.
-  SmallVector<SmallVector<Value>>
-  emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &rewriter,
-                              const BlockedEncodingAttr &blockedLayout,
-                              ArrayRef<int64_t> shape) const {
+  // TODO: [phil] redundant indices commputation do not appear to hurt
+  // performance much, but they could still significantly slow down
+  // computations.
+  SmallVector<SmallVector<Value>> emitIndicesForDistributedLayout(
+      Location loc, ConversionPatternRewriter &rewriter,
+      const Attribute &layout, ArrayRef<int64_t> shape) const {
     // step 1, delinearize threadId to get the base index
-    auto multiDimBase =
-        emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
+    auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, shape);
     // step 2, get offset of each element
-    SmallVector<SmallVector<unsigned>> offset =
-        emitOffsetForBlockedLayout(blockedLayout, shape);
+    auto offset = emitOffsetForLayout(layout, shape);
     // step 3, add offset to base, and reorder the sequence of indices to
     // guarantee that elems in the same sizePerThread are adjacent in order
     unsigned rank = shape.size();
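For reference, emitBaseIndexForMmaLayoutV2 and emitOffsetForMmaLayoutV2 above together encode the Ampere mma.m16n8 accumulator placement: lane l of each warp starts at row l/4 and column 2*(l%4) of a 16x8 tile, and owns that element plus its neighbors at column +1 and row +8. A standalone illustration of the mapping (not part of the patch):

    #include <cstdio>
    int main() {
      for (unsigned lane = 0; lane < 32; ++lane) {
        unsigned row = lane / 4;       // mirrors udiv(laneId, idx_val(4)) above
        unsigned col = 2 * (lane % 4); // mirrors mul(idx_val(2), urem(laneId, idx_val(4)))
        // the four per-lane elements, in the order emitOffsetForMmaLayoutV2 emits them
        printf("lane %2u owns (%2u,%u) (%2u,%u) (%2u,%u) (%2u,%u)\n", lane,
               row, col, row, col + 1, row + 8, col, row + 8, col + 1);
      }
    }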
@@ -726,6 +774,49 @@ public:
     return multiDimIdx;
   }

+  SmallVector<SmallVector<Value>>
+  emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter,
+                            const SliceEncodingAttr &sliceLayout,
+                            ArrayRef<int64_t> shape) const {
+    auto parent = sliceLayout.getParent();
+    unsigned dim = sliceLayout.getDim();
+    size_t rank = shape.size();
+    auto paddedIndices =
+        emitIndices(loc, rewriter, parent, sliceLayout.paddedShape(shape));
+    unsigned numIndices = paddedIndices.size();
+    SmallVector<SmallVector<Value>> resultIndices(numIndices);
+    for (unsigned i = 0; i < numIndices; ++i)
+      for (unsigned d = 0; d < rank + 1; ++d)
+        if (d != dim)
+          resultIndices[i].push_back(paddedIndices[i][d]);
+
+    return resultIndices;
+  }
+
+  // -----------------------------------------------------------------------
+  // Emit indices
+  // -----------------------------------------------------------------------
+  SmallVector<SmallVector<Value>> emitIndices(Location loc,
+                                              ConversionPatternRewriter &b,
+                                              const Attribute &layout,
+                                              ArrayRef<int64_t> shape) const {
+    if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
+      return emitIndicesForDistributedLayout(loc, b, blocked, shape);
+    } else if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
+      return emitIndicesForDistributedLayout(loc, b, mma, shape);
+    } else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
+      return emitIndicesForSliceLayout(loc, b, slice, shape);
+    } else {
+      assert(0 && "emitIndices for layouts other than blocked & slice not "
+                  "implemented yet");
+      return {};
+    }
+  }
+
   // -----------------------------------------------------------------------
   // Shared memory utilities
   // -----------------------------------------------------------------------

   template <typename T>
   Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
                             T value) const {
@@ -954,8 +1045,8 @@ struct LoadOpConversion

     // Determine the vectorization size
     Type valueTy = op.getResult().getType();
-    Type valueElemTy = typeConverter->convertType(
-        getElementTypeOrSelf(valueTy));
+    Type valueElemTy =
+        typeConverter->convertType(getElementTypeOrSelf(valueTy));
     unsigned vec = getVectorSize(ptr);
     unsigned numElems = getElemsPerThread(ptr.getType());
     if (llMask)
@@ -1086,7 +1177,8 @@ struct LoadOpConversion
                       : retTys[0];

     // TODO: if (has_l2_evict_policy)
-    //   auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
+    //   auto asmDialectAttr =
+    //   LLVM::AsmDialectAttr::get(rewriter.getContext(),
     //                             LLVM::AsmDialect::AD_ATT);
     Value ret = ptxBuilder.launch(rewriter, loc, retTy);
@@ -1149,8 +1241,8 @@ struct StoreOpConversion
     MLIRContext *ctx = rewriter.getContext();

     auto valueTy = value.getType();
-    Type valueElemTy = typeConverter->convertType(
-        getElementTypeOrSelf(valueTy));
+    Type valueElemTy =
+        typeConverter->convertType(getElementTypeOrSelf(valueTy));

     unsigned vec = getVectorSize(ptr);
     unsigned numElems = getElemsPerThread(ptr.getType());
@@ -1571,19 +1663,19 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
   unsigned axis = adaptor.axis();

   auto srcTy = op.operand().getType().cast<RankedTensorType>();
-  auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
+  auto srcLayout = srcTy.getEncoding();
   auto srcShape = srcTy.getShape();
   auto srcRank = srcTy.getRank();

-  auto threadsPerWarp = srcLayout.getThreadsPerWarp();
-  auto warpsPerCTA = srcLayout.getWarpsPerCTA();
+  auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
+  auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcLayout);

   auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
   auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
   Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
   smemBase = bitcast(smemBase, elemPtrTy);

-  auto order = srcLayout.getOrder();
+  auto order = getOrder(srcLayout);
   unsigned sizeIntraWarps = threadsPerWarp[axis];
   unsigned sizeInterWarps = warpsPerCTA[axis];
@@ -1592,7 +1684,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
   auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);

   SmallVector<SmallVector<unsigned>> offset =
-      emitOffsetForBlockedLayout(srcLayout, srcShape);
+      emitOffsetForLayout(srcLayout, srcShape);

   std::map<SmallVector<unsigned>, Value> accs;
   std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
@@ -1655,7 +1747,8 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
   // each thread needs to process:
   //   elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads
   unsigned elems = product<unsigned>(smemShape);
-  unsigned numThreads = product<unsigned>(srcLayout.getWarpsPerCTA()) * 32;
+  unsigned numThreads =
+      product<unsigned>(triton::gpu::getWarpsPerCTA(srcLayout)) * 32;
   unsigned elemsPerThread = std::max<unsigned>(elems / numThreads, 1);
   Value readOffset = threadId;
   for (unsigned round = 0; round < elemsPerThread; ++round) {
@@ -2104,9 +2197,10 @@ struct FpToFpOpConversion
   using ConvertTritonGPUOpToLLVMPattern<
       triton::FpToFpOp>::ConvertTritonGPUOpToLLVMPattern;

-  static SmallVector<Value> convertFp8x4ToFp16x4(
-      Location loc, ConversionPatternRewriter &rewriter,
-      const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
+  static SmallVector<Value>
+  convertFp8x4ToFp16x4(Location loc, ConversionPatternRewriter &rewriter,
+                       const Value &v0, const Value &v1, const Value &v2,
+                       const Value &v3) {
     auto ctx = rewriter.getContext();
     auto fp8x4VecTy = vec_ty(i8_ty, 4);
     Value fp8x4Vec = undef(fp8x4VecTy);
@@ -2117,18 +2211,17 @@ struct FpToFpOpConversion
     fp8x4Vec = bitcast(fp8x4Vec, i32_ty);

     PTXBuilder builder;
-    auto *ptxAsm =
-        "{ \n"
-        ".reg .b32 a<2>, b<2>; \n"
-        "prmt.b32 a0, 0, $2, 0x5040; \n"
-        "prmt.b32 a1, 0, $2, 0x7060; \n"
-        "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n"
-        "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n"
-        "shr.b32 b0, b0, 1; \n"
-        "shr.b32 b1, b1, 1; \n"
-        "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
-        "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n"
-        "}";
+    auto *ptxAsm = "{ \n"
+                   ".reg .b32 a<2>, b<2>; \n"
+                   "prmt.b32 a0, 0, $2, 0x5040; \n"
+                   "prmt.b32 a1, 0, $2, 0x7060; \n"
+                   "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n"
+                   "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n"
+                   "shr.b32 b0, b0, 1; \n"
+                   "shr.b32 b1, b1, 1; \n"
+                   "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
+                   "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n"
+                   "}";
     auto &call = *builder.create(ptxAsm);

     auto *o0 = builder.newOperand("=r");
@@ -2141,21 +2234,20 @@ struct FpToFpOpConversion
         struct_ty(SmallVector<Type>{fp16x2VecTy, fp16x2VecTy});
     auto fp16x2x2Struct =
         builder.launch(rewriter, loc, fp16x2x2StructTy, false);
-    auto fp16x2Vec0 = extract_val(fp16x2VecTy, fp16x2x2Struct,
-                                  rewriter.getI32ArrayAttr({0}));
-    auto fp16x2Vec1 = extract_val(fp16x2VecTy, fp16x2x2Struct,
-                                  rewriter.getI32ArrayAttr({1}));
-    return {
-        extract_element(f16_ty, fp16x2Vec0, i32_val(0)),
-        extract_element(f16_ty, fp16x2Vec0, i32_val(1)),
-        extract_element(f16_ty, fp16x2Vec1, i32_val(0)),
-        extract_element(f16_ty, fp16x2Vec1, i32_val(1))
-    };
+    auto fp16x2Vec0 =
+        extract_val(fp16x2VecTy, fp16x2x2Struct, rewriter.getI32ArrayAttr({0}));
+    auto fp16x2Vec1 =
+        extract_val(fp16x2VecTy, fp16x2x2Struct, rewriter.getI32ArrayAttr({1}));
+    return {extract_element(f16_ty, fp16x2Vec0, i32_val(0)),
+            extract_element(f16_ty, fp16x2Vec0, i32_val(1)),
+            extract_element(f16_ty, fp16x2Vec1, i32_val(0)),
+            extract_element(f16_ty, fp16x2Vec1, i32_val(1))};
   }

-  static SmallVector<Value> convertFp16x4ToFp8x4(
-      Location loc, ConversionPatternRewriter &rewriter,
-      const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
+  static SmallVector<Value>
+  convertFp16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
+                       const Value &v0, const Value &v1, const Value &v2,
+                       const Value &v3) {
     auto ctx = rewriter.getContext();
     auto fp16x2VecTy = vec_ty(f16_ty, 2);
     Value fp16x2Vec0 = undef(fp16x2VecTy);
@@ -2168,19 +2260,18 @@ struct FpToFpOpConversion
     fp16x2Vec1 = bitcast(fp16x2Vec1, i32_ty);

     PTXBuilder builder;
-    auto *ptxAsm =
-        "{ \n"
-        ".reg .b32 a<2>, b<2>; \n"
-        "shl.b32 a0, $1, 1; \n"
-        "shl.b32 a1, $2, 1; \n"
-        "lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n"
-        "lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n"
-        "add.u32 a0, a0, 0x00800080; \n"
-        "add.u32 a1, a1, 0x00800080; \n"
-        "lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n"
-        "lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n"
-        "prmt.b32 $0, b0, b1, 0x7531; \n"
-        "}";
+    auto *ptxAsm = "{ \n"
+                   ".reg .b32 a<2>, b<2>; \n"
+                   "shl.b32 a0, $1, 1; \n"
+                   "shl.b32 a1, $2, 1; \n"
+                   "lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n"
+                   "lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n"
+                   "add.u32 a0, a0, 0x00800080; \n"
+                   "add.u32 a1, a1, 0x00800080; \n"
+                   "lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n"
+                   "lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n"
+                   "prmt.b32 $0, b0, b1, 0x7531; \n"
+                   "}";
     auto &call = *builder.create(ptxAsm);

     auto *o = builder.newOperand("=r");
@@ -2190,17 +2281,16 @@ struct FpToFpOpConversion

     auto fp8x4VecTy = vec_ty(i8_ty, 4);
     auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
-    return {
-        extract_element(i8_ty, fp8x4Vec, i32_val(0)),
-        extract_element(i8_ty, fp8x4Vec, i32_val(1)),
-        extract_element(i8_ty, fp8x4Vec, i32_val(2)),
-        extract_element(i8_ty, fp8x4Vec, i32_val(3))
-    };
+    return {extract_element(i8_ty, fp8x4Vec, i32_val(0)),
+            extract_element(i8_ty, fp8x4Vec, i32_val(1)),
+            extract_element(i8_ty, fp8x4Vec, i32_val(2)),
+            extract_element(i8_ty, fp8x4Vec, i32_val(3))};
   }

-  static SmallVector<Value> convertFp8x4ToBf16x4(
-      Location loc, ConversionPatternRewriter &rewriter,
-      const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
+  static SmallVector<Value>
+  convertFp8x4ToBf16x4(Location loc, ConversionPatternRewriter &rewriter,
+                       const Value &v0, const Value &v1, const Value &v2,
+                       const Value &v3) {
     auto ctx = rewriter.getContext();
     auto fp8x4VecTy = vec_ty(i8_ty, 4);
     Value fp8x4Vec = undef(fp8x4VecTy);
@@ -2211,22 +2301,21 @@ struct FpToFpOpConversion
     fp8x4Vec = bitcast(fp8x4Vec, i32_ty);

     PTXBuilder builder;
-    auto *ptxAsm =
-        "{ \n"
-        ".reg .b32 a<2>, sign<2>, nosign<2>, b<2>; \n"
-        "prmt.b32 a0, 0, $2, 0x5040; \n"
-        "prmt.b32 a1, 0, $2, 0x7060; \n"
-        "and.b32 sign0, a0, 0x80008000; \n"
-        "and.b32 sign1, a1, 0x80008000; \n"
-        "and.b32 nosign0, a0, 0x7fff7fff; \n"
-        "and.b32 nosign1, a1, 0x7fff7fff; \n"
-        "shr.b32 nosign0, nosign0, 4; \n"
-        "shr.b32 nosign1, nosign1, 4; \n"
-        "add.u32 nosign0, nosign0, 0x38003800; \n"
-        "add.u32 nosign1, nosign1, 0x38003800; \n"
-        "or.b32 $0, sign0, nosign0; \n"
-        "or.b32 $1, sign1, nosign1; \n"
-        "}";
+    auto *ptxAsm = "{ \n"
+                   ".reg .b32 a<2>, sign<2>, nosign<2>, b<2>; \n"
+                   "prmt.b32 a0, 0, $2, 0x5040; \n"
+                   "prmt.b32 a1, 0, $2, 0x7060; \n"
+                   "and.b32 sign0, a0, 0x80008000; \n"
+                   "and.b32 sign1, a1, 0x80008000; \n"
+                   "and.b32 nosign0, a0, 0x7fff7fff; \n"
+                   "and.b32 nosign1, a1, 0x7fff7fff; \n"
+                   "shr.b32 nosign0, nosign0, 4; \n"
+                   "shr.b32 nosign1, nosign1, 4; \n"
+                   "add.u32 nosign0, nosign0, 0x38003800; \n"
+                   "add.u32 nosign1, nosign1, 0x38003800; \n"
+                   "or.b32 $0, sign0, nosign0; \n"
+                   "or.b32 $1, sign1, nosign1; \n"
+                   "}";
     auto &call = *builder.create(ptxAsm);

     auto *o0 = builder.newOperand("=r");
@@ -2239,21 +2328,20 @@ struct FpToFpOpConversion
         struct_ty(SmallVector<Type>{bf16x2VecTy, bf16x2VecTy});
     auto bf16x2x2Struct =
         builder.launch(rewriter, loc, bf16x2x2StructTy, false);
-    auto bf16x2Vec0 = extract_val(bf16x2VecTy, bf16x2x2Struct,
-                                  rewriter.getI32ArrayAttr({0}));
-    auto bf16x2Vec1 = extract_val(bf16x2VecTy, bf16x2x2Struct,
-                                  rewriter.getI32ArrayAttr({1}));
-    return {
-        extract_element(bf16_ty, bf16x2Vec0, i32_val(0)),
-        extract_element(bf16_ty, bf16x2Vec0, i32_val(1)),
-        extract_element(bf16_ty, bf16x2Vec1, i32_val(0)),
-        extract_element(bf16_ty, bf16x2Vec1, i32_val(1))
-    };
+    auto bf16x2Vec0 =
+        extract_val(bf16x2VecTy, bf16x2x2Struct, rewriter.getI32ArrayAttr({0}));
+    auto bf16x2Vec1 =
+        extract_val(bf16x2VecTy, bf16x2x2Struct, rewriter.getI32ArrayAttr({1}));
+    return {extract_element(bf16_ty, bf16x2Vec0, i32_val(0)),
+            extract_element(bf16_ty, bf16x2Vec0, i32_val(1)),
+            extract_element(bf16_ty, bf16x2Vec1, i32_val(0)),
+            extract_element(bf16_ty, bf16x2Vec1, i32_val(1))};
   }

-  static SmallVector<Value> convertBf16x4ToFp8x4(
-      Location loc, ConversionPatternRewriter &rewriter,
-      const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
+  static SmallVector<Value>
+  convertBf16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
+                       const Value &v0, const Value &v1, const Value &v2,
+                       const Value &v3) {
     auto ctx = rewriter.getContext();
     auto bf16x2VecTy = vec_ty(bf16_ty, 2);
     Value bf16x2Vec0 = undef(bf16x2VecTy);
@@ -2266,43 +2354,42 @@ struct FpToFpOpConversion
     bf16x2Vec1 = bitcast(bf16x2Vec1, i32_ty);

     PTXBuilder builder;
-    auto *ptxAsm =
-        "{ \n"
-        ".reg .u32 sign, sign<2>, nosign, nosign<2>; \n"
-        ".reg .u32 fp8_min, fp8_max, rn_, zero; \n"
-        "mov.u32 fp8_min, 0x38003800; \n"
-        "mov.u32 fp8_max, 0x3ff03ff0; \n"
-        "mov.u32 rn_, 0x80008; \n"
-        "mov.u32 zero, 0; \n"
-        "and.b32 sign0, $1, 0x80008000; \n"
-        "and.b32 sign1, $2, 0x80008000; \n"
-        "prmt.b32 sign, sign0, sign1, 0x7531; \n"
-        "and.b32 nosign0, $1, 0x7fff7fff; \n"
-        "and.b32 nosign1, $2, 0x7fff7fff; \n"
-        ".reg .u32 nosign_0_<2>, nosign_1_<2>; \n"
-        "and.b32 nosign_0_0, nosign0, 0xffff0000; \n"
-        "max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n"
-        "min.u32 nosign_0_0, nosign_0_0, 0x3ff00000; \n"
-        "and.b32 nosign_0_1, nosign0, 0x0000ffff; \n"
-        "max.u32 nosign_0_1, nosign_0_1, 0x3800; \n"
-        "min.u32 nosign_0_1, nosign_0_1, 0x3ff0; \n"
-        "or.b32 nosign0, nosign_0_0, nosign_0_1; \n"
-        "and.b32 nosign_1_0, nosign1, 0xffff0000; \n"
-        "max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n"
-        "min.u32 nosign_1_0, nosign_1_0, 0x3ff00000; \n"
-        "and.b32 nosign_1_1, nosign1, 0x0000ffff; \n"
-        "max.u32 nosign_1_1, nosign_1_1, 0x3800; \n"
-        "min.u32 nosign_1_1, nosign_1_1, 0x3ff0; \n"
-        "or.b32 nosign1, nosign_1_0, nosign_1_1; \n"
-        "add.u32 nosign0, nosign0, rn_; \n"
-        "add.u32 nosign1, nosign1, rn_; \n"
-        "sub.u32 nosign0, nosign0, 0x38003800; \n"
-        "sub.u32 nosign1, nosign1, 0x38003800; \n"
-        "shr.u32 nosign0, nosign0, 4; \n"
-        "shr.u32 nosign1, nosign1, 4; \n"
-        "prmt.b32 nosign, nosign0, nosign1, 0x6420; \n"
-        "or.b32 $0, nosign, sign; \n"
-        "}";
+    auto *ptxAsm = "{ \n"
+                   ".reg .u32 sign, sign<2>, nosign, nosign<2>; \n"
+                   ".reg .u32 fp8_min, fp8_max, rn_, zero; \n"
+                   "mov.u32 fp8_min, 0x38003800; \n"
+                   "mov.u32 fp8_max, 0x3ff03ff0; \n"
+                   "mov.u32 rn_, 0x80008; \n"
+                   "mov.u32 zero, 0; \n"
+                   "and.b32 sign0, $1, 0x80008000; \n"
+                   "and.b32 sign1, $2, 0x80008000; \n"
+                   "prmt.b32 sign, sign0, sign1, 0x7531; \n"
+                   "and.b32 nosign0, $1, 0x7fff7fff; \n"
+                   "and.b32 nosign1, $2, 0x7fff7fff; \n"
+                   ".reg .u32 nosign_0_<2>, nosign_1_<2>; \n"
+                   "and.b32 nosign_0_0, nosign0, 0xffff0000; \n"
+                   "max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n"
+                   "min.u32 nosign_0_0, nosign_0_0, 0x3ff00000; \n"
+                   "and.b32 nosign_0_1, nosign0, 0x0000ffff; \n"
+                   "max.u32 nosign_0_1, nosign_0_1, 0x3800; \n"
+                   "min.u32 nosign_0_1, nosign_0_1, 0x3ff0; \n"
+                   "or.b32 nosign0, nosign_0_0, nosign_0_1; \n"
+                   "and.b32 nosign_1_0, nosign1, 0xffff0000; \n"
+                   "max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n"
+                   "min.u32 nosign_1_0, nosign_1_0, 0x3ff00000; \n"
+                   "and.b32 nosign_1_1, nosign1, 0x0000ffff; \n"
+                   "max.u32 nosign_1_1, nosign_1_1, 0x3800; \n"
+                   "min.u32 nosign_1_1, nosign_1_1, 0x3ff0; \n"
+                   "or.b32 nosign1, nosign_1_0, nosign_1_1; \n"
+                   "add.u32 nosign0, nosign0, rn_; \n"
+                   "add.u32 nosign1, nosign1, rn_; \n"
+                   "sub.u32 nosign0, nosign0, 0x38003800; \n"
+                   "sub.u32 nosign1, nosign1, 0x38003800; \n"
+                   "shr.u32 nosign0, nosign0, 4; \n"
+                   "shr.u32 nosign1, nosign1, 4; \n"
+                   "prmt.b32 nosign, nosign0, nosign1, 0x6420; \n"
+                   "or.b32 $0, nosign, sign; \n"
+                   "}";
     auto &call = *builder.create(ptxAsm);

     auto *o = builder.newOperand("=r");
@@ -2312,51 +2399,49 @@ struct FpToFpOpConversion

     auto fp8x4VecTy = vec_ty(i8_ty, 4);
     auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
-    return {
-        extract_element(i8_ty, fp8x4Vec, i32_val(0)),
-        extract_element(i8_ty, fp8x4Vec, i32_val(1)),
-        extract_element(i8_ty, fp8x4Vec, i32_val(2)),
-        extract_element(i8_ty, fp8x4Vec, i32_val(3))
-    };
+    return {extract_element(i8_ty, fp8x4Vec, i32_val(0)),
+            extract_element(i8_ty, fp8x4Vec, i32_val(1)),
+            extract_element(i8_ty, fp8x4Vec, i32_val(2)),
+            extract_element(i8_ty, fp8x4Vec, i32_val(3))};
   }

-  static SmallVector<Value> convertFp8x4ToFp32x4(
-      Location loc, ConversionPatternRewriter &rewriter,
-      const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
+  static SmallVector<Value>
+  convertFp8x4ToFp32x4(Location loc, ConversionPatternRewriter &rewriter,
+                       const Value &v0, const Value &v1, const Value &v2,
+                       const Value &v3) {
     auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
-    return {
-        rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[0]),
-        rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[1]),
-        rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[2]),
-        rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[3])
-    };
+    return {rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[0]),
+            rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[1]),
+            rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[2]),
+            rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[3])};
   }

-  static SmallVector<Value> convertFp32x4ToFp8x4(
-      Location loc, ConversionPatternRewriter &rewriter,
-      const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
-    auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
-    auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
-    auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
-    auto c3 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v3);
-    return convertFp16x4ToFp8x4(loc, rewriter, c0, c1, c2, c3);
+  static SmallVector<Value>
+  convertFp32x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
+                       const Value &v0, const Value &v1, const Value &v2,
+                       const Value &v3) {
+    auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
+    auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
+    auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
+    auto c3 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v3);
+    return convertFp16x4ToFp8x4(loc, rewriter, c0, c1, c2, c3);
   }

-  static SmallVector<Value> convertFp8x4ToFp64x4(
-      Location loc, ConversionPatternRewriter &rewriter,
-      const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
+  static SmallVector<Value>
+  convertFp8x4ToFp64x4(Location loc, ConversionPatternRewriter &rewriter,
+                       const Value &v0, const Value &v1, const Value &v2,
+                       const Value &v3) {
     auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
-    return {
-        rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[0]),
-        rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[1]),
-        rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[2]),
-        rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[3])
-    };
+    return {rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[0]),
+            rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[1]),
+            rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[2]),
+            rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[3])};
   }

-  static SmallVector<Value> convertFp64x4ToFp8x4(
-      Location loc, ConversionPatternRewriter &rewriter,
-      const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
+  static SmallVector<Value>
+  convertFp64x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
+                       const Value &v0, const Value &v1, const Value &v2,
+                       const Value &v3) {
     auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
     auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
     auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
@@ -2379,9 +2464,10 @@ struct FpToFpOpConversion
         this->getTypeConverter()->convertType(dstEltType);

     // Select convertor
-    std::function<SmallVector<Value>(Location, ConversionPatternRewriter&,
-                                     const Value&, const Value&,
-                                     const Value&, const Value&)> convertor;
+    std::function<SmallVector<Value>(Location, ConversionPatternRewriter &,
+                                     const Value &, const Value &,
+                                     const Value &, const Value &)>
+        convertor;
     if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF16()) {
       convertor = convertFp8x4ToFp16x4;
     } else if (srcEltType.isF16() && dstEltType.isa<triton::Float8Type>()) {
@@ -2410,8 +2496,7 @@ struct FpToFpOpConversion
     auto elements = getElementsFromStruct(loc, adaptor.from(), rewriter);
     SmallVector<Value> resultVals;
     for (size_t i = 0; i < elems; i += 4) {
-      auto converted = convertor(loc, rewriter,
-                                 elements[i], elements[i + 1],
+      auto converted = convertor(loc, rewriter, elements[i], elements[i + 1],
                                  elements[i + 2], elements[i + 3]);
       resultVals.append(converted);
     }
@@ -2626,6 +2711,82 @@ public:
   }

 private:
+  SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
+                                       ConversionPatternRewriter &rewriter,
+                                       unsigned elemId, ArrayRef<int64_t> shape,
+                                       ArrayRef<unsigned> multiDimCTAInRepId,
+                                       ArrayRef<unsigned> shapePerCTA) const {
+    unsigned rank = shape.size();
+    if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
+      auto multiDimOffsetFirstElem =
+          emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
+      SmallVector<Value> multiDimOffset(rank);
+      SmallVector<unsigned> multiDimElemId =
+          getMultiDimIndex<unsigned>(elemId, blockedLayout.getSizePerThread());
+      for (unsigned d = 0; d < rank; ++d) {
+        multiDimOffset[d] = add(multiDimOffsetFirstElem[d],
+                                idx_val(multiDimCTAInRepId[d] * shapePerCTA[d] +
+                                        multiDimElemId[d]));
+      }
+      return multiDimOffset;
+    }
+    if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
+      unsigned dim = sliceLayout.getDim();
+      auto multiDimOffsetParent =
+          getMultiDimOffset(sliceLayout.getParent(), loc, rewriter, elemId,
+                            sliceLayout.paddedShape(shape),
+                            sliceLayout.paddedShape(multiDimCTAInRepId),
+                            sliceLayout.paddedShape(shapePerCTA));
+      SmallVector<Value> multiDimOffset(rank);
+      for (unsigned d = 0; d < rank + 1; ++d) {
+        if (d == dim)
+          continue;
+        unsigned slicedD = d < dim ? d : (d - 1);
+        multiDimOffset[slicedD] = multiDimOffsetParent[d];
+      }
+      return multiDimOffset;
+    }
+    if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+      SmallVector<Value> mmaColIdx(2);
+      SmallVector<Value> mmaRowIdx(2);
+      Value threadId = getThreadId(rewriter, loc);
+      Value warpSize = idx_val(32);
+      Value laneId = urem(threadId, warpSize);
+      Value warpId = udiv(threadId, warpSize);
+      // auto multiDimWarpId =
+      //     delinearize(rewriter, loc, warpId, mmaLayout.getWarpsPerCTA());
+      // TODO: double confirm if its document bug or DotConversion's Bug
+      SmallVector<Value> multiDimWarpId(2);
+      multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
+      multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
+      Value four = idx_val(4);
+      Value mmaGrpId = udiv(laneId, four);
+      Value mmaGrpIdP8 = add(mmaGrpId, idx_val(8));
+      Value mmaThreadIdInGrp = urem(laneId, four);
+      Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, idx_val(2));
+      Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, idx_val(1));
+      Value colWarpOffset = mul(multiDimWarpId[0], idx_val(16));
+      mmaColIdx[0] = add(mmaGrpId, colWarpOffset);
+      mmaColIdx[1] = add(mmaGrpIdP8, colWarpOffset);
+      Value rowWarpOffset = mul(multiDimWarpId[1], idx_val(8));
+      mmaRowIdx[0] = add(mmaThreadIdInGrpM2, rowWarpOffset);
+      mmaRowIdx[1] = add(mmaThreadIdInGrpM2P1, rowWarpOffset);
+
+      assert(rank == 2);
+      assert(mmaLayout.getVersion() == 2 &&
+             "mmaLayout ver1 not implemented yet");
+      SmallVector<Value> multiDimOffset(rank);
+      multiDimOffset[0] = elemId < 2 ? mmaColIdx[0] : mmaColIdx[1];
+      multiDimOffset[1] = elemId % 2 == 0 ? mmaRowIdx[0] : mmaRowIdx[1];
+      multiDimOffset[0] = add(multiDimOffset[0],
+                              idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
+      multiDimOffset[1] = add(multiDimOffset[1],
+                              idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
+      return multiDimOffset;
+    }
+    llvm_unreachable("unexpected layout in getMultiDimOffset");
+  }
+
   // shared memory rd/st for blocked or mma layout with data padding
   void processReplica(Location loc, ConversionPatternRewriter &rewriter,
                       bool stNotRd, RankedTensorType type,
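One thing worth noting about the mma branch above: it reproduces the same per-lane placement as emitBaseIndexForMmaLayoutV2 and emitOffsetForMmaLayoutV2 earlier in this file (elemId < 2 picks the base group versus the +8 group along dim 0, and elemId % 2 picks between the two adjacent positions along dim 1), so direct index emission and the shared-memory round-trip in processReplica agree on where each mma element lives.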
@@ -2693,47 +2854,6 @@ void ConvertLayoutOpConversion::processReplica(
     elemTy = IntegerType::get(elemTy.getContext(), 8);
   auto llvmElemTy = getTypeConverter()->convertType(elemTy);

-  SmallVector<Value> multiDimOffsetFirstElem;
-  SmallVector<Value> mmaColIdx(2);
-  SmallVector<Value> mmaRowIdx(2);
-  if (blockedLayout) {
-    multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout(
-        loc, rewriter, blockedLayout, type.getShape());
-  } else if (sliceLayout) {
-    auto parent = sliceLayout.getParent();
-    if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
-      SmallVector<int64_t> paddedShape =
-          sliceLayout.paddedShape(type.getShape());
-      multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout(
-          loc, rewriter, blockedParent, paddedShape);
-    } else {
-      assert(0 && "SliceEncodingAttr with parent other than "
-                  "BlockedEncodingAttr not implemented");
-    }
-  } else if (mmaLayout) {
-    Value threadId = getThreadId(rewriter, loc);
-    Value warpSize = idx_val(32);
-    Value laneId = urem(threadId, warpSize);
-    Value warpId = udiv(threadId, warpSize);
-    // auto multiDimWarpId =
-    //     delinearize(rewriter, loc, warpId, mmaLayout.getWarpsPerCTA());
-    // TODO: double confirm if its document bug or DotConversion's Bug
-    SmallVector<Value> multiDimWarpId(2);
-    multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
-    multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
-    Value four = idx_val(4);
-    Value mmaGrpId = udiv(laneId, four);
-    Value mmaGrpIdP8 = add(mmaGrpId, idx_val(8));
-    Value mmaThreadIdInGrp = urem(laneId, four);
-    Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, idx_val(2));
-    Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, idx_val(1));
-    Value colWarpOffset = mul(multiDimWarpId[0], idx_val(16));
-    mmaColIdx[0] = add(mmaGrpId, colWarpOffset);
-    mmaColIdx[1] = add(mmaGrpIdP8, colWarpOffset);
-    Value rowWarpOffset = mul(multiDimWarpId[1], idx_val(8));
-    mmaRowIdx[0] = add(mmaThreadIdInGrpM2, rowWarpOffset);
-    mmaRowIdx[1] = add(mmaThreadIdInGrpM2P1, rowWarpOffset);
-  }
   for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
     auto multiDimCTAInRepId = getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep);
     SmallVector<unsigned> multiDimCTAId(rank);
@@ -2747,48 +2867,9 @@ void ConvertLayoutOpConversion::processReplica(
    //       consider of caching the index calculation result in case
    //       of performance issue observed.
    for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
-      SmallVector<Value> multiDimOffset(rank);
-      if (blockedLayout) {
-        SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
-            elemId, blockedLayout.getSizePerThread());
-        for (unsigned d = 0; d < rank; ++d) {
-          multiDimOffset[d] =
-              add(multiDimOffsetFirstElem[d],
-                  idx_val(multiDimCTAInRepId[d] * shapePerCTA[d] +
-                          multiDimElemId[d]));
-        }
-      } else if (sliceLayout) {
-        unsigned dim = sliceLayout.getDim();
-        auto parent = sliceLayout.getParent();
-        if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
-          SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
-              elemId, blockedParent.getSizePerThread());
-          for (unsigned d = 0; d < rank + 1; ++d) {
-            if (d == dim)
-              continue;
-            unsigned slicedD = d < dim ? d : (d - 1);
-            multiDimOffset[slicedD] =
-                add(multiDimOffsetFirstElem[d],
-                    idx_val(multiDimCTAInRepId[slicedD] * shapePerCTA[slicedD] +
-                            multiDimElemId[d]));
-          }
-        } else {
-          assert(0 && "SliceEncodingAttr with parent other than "
-                      "BlockedEncodingAttr not implemented");
-        }
-      } else if (mmaLayout) {
-        assert(rank == 2);
-        assert(mmaLayout.getVersion() == 2 &&
-               "mmaLayout ver1 not implemented yet");
-        multiDimOffset[0] = elemId < 2 ? mmaColIdx[0] : mmaColIdx[1];
-        multiDimOffset[1] = elemId % 2 == 0 ? mmaRowIdx[0] : mmaRowIdx[1];
-        multiDimOffset[0] = add(
-            multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
-        multiDimOffset[1] = add(
-            multiDimOffset[1], idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
-      } else {
-        assert(0 && "unexpected layout in processReplica");
-      }
+      SmallVector<Value> multiDimOffset =
+          getMultiDimOffset(layout, loc, rewriter, elemId, type.getShape(),
+                            multiDimCTAInRepId, shapePerCTA);
      Value offset =
          linearize(rewriter, loc, reorder<Value>(multiDimOffset, outOrd),
                    reorder<unsigned>(paddedRepShape, outOrd));
@@ -3637,10 +3718,10 @@ struct DotOpMmaV1ConversionHelper {
   // Compute the offset of the matrix to load.
   // Returns offsetAM, offsetAK, offsetBN, offsetBK.
   // NOTE, the information M(from $a) and N(from $b) couldn't be retrieved at
-  // the same time in the usage in convert_layout[shared->dot_op], we leave the
-  // noexist info to be 0 and only use the desired argument from the composed
-  // result. In this way we want to retain the original code structure in
-  // convert_mma884 method for easier debugging.
+  // the same time in the usage in convert_layout[shared->dot_op], we leave
+  // the noexist info to be 0 and only use the desired argument from the
+  // composed result. In this way we want to retain the original code
+  // structure in convert_mma884 method for easier debugging.
   std::tuple<Value, Value, Value, Value>
   computeOffsets(Value threadId, bool isARow, bool isBRow, ArrayRef<int> fpw,
                  ArrayRef<int> spw, ArrayRef<int> rep,
@@ -5231,8 +5312,8 @@ struct InsertSliceAsyncOpConversion
         srcStrides.emplace_back(smemObj.strides[i]);
       }
     }
-    // Compute the offset based on the original dimensions of the shared memory
-    // object
+    // Compute the offset based on the original dimensions of the shared
+    // memory object
     auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides);
     auto dstPtrTy =
         ptr_ty(getTypeConverter()->convertType(resTy.getElementType()), 3);
@@ -42,15 +42,11 @@ static Type getPointeeType(Type type) {

 namespace gpu {

-// TODO: Inheritation of layout attributes
-unsigned getElemsPerThread(Type type) {
-  if (type.isIntOrIndexOrFloat() ||
-      type.isa<triton::Float8Type>() ||
-      type.isa<triton::PointerType>())
-    return 1;
-  auto tensorType = type.cast<RankedTensorType>();
-  auto layout = tensorType.getEncoding();
-  auto shape = tensorType.getShape();
+// TODO: Inheritance of layout attributes
+//       so that all distributed layouts implement
+//       these utilities
+
+unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
   if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
     return blockedLayout.getElemsPerThread(shape);
   } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
@@ -67,6 +63,43 @@ unsigned getElemsPerThread(Type type) {
   }
 }

+unsigned getElemsPerThread(Type type) {
+  if (type.isIntOrIndexOrFloat() ||
+      type.isa<triton::Float8Type>() ||
+      type.isa<triton::PointerType>())
+    return 1;
+  auto tensorType = type.cast<RankedTensorType>();
+  return getElemsPerThread(tensorType.getEncoding(), tensorType.getShape());
+}
+
+SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
+  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
+    return SmallVector<unsigned>(blockedLayout.getThreadsPerWarp().begin(),
+                                 blockedLayout.getThreadsPerWarp().end());
+  }
+  if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+    if (mmaLayout.getVersion() == 1)
+      return SmallVector<unsigned>{4, 8};
+    if (mmaLayout.getVersion() == 2)
+      return SmallVector<unsigned>{8, 4};
+  }
+  assert(0 && "getThreadsPerWarp not implemented");
+  return {};
+}
+
+SmallVector<unsigned> getWarpsPerCTA(Attribute layout) {
+  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
+    return SmallVector<unsigned>(blockedLayout.getWarpsPerCTA().begin(),
+                                 blockedLayout.getWarpsPerCTA().end());
+  }
+  if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+    return SmallVector<unsigned>(mmaLayout.getWarpsPerCTA().begin(),
+                                 mmaLayout.getWarpsPerCTA().end());
+  }
+  assert(0 && "getWarpsPerCTA not implemented");
+  return {};
+}
+
 SmallVector<unsigned> getSizePerThread(Attribute layout) {
   if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
     return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
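A quick illustration of what the mma cases above mean in practice, with an assumed warp configuration (the warpsPerCTA values below are not from the patch): an mma v2 warp arranges its 32 lanes as an 8x4 grid, so the thread count along each axis of the CTA is that grid times warpsPerCTA:

    // Illustration only; warpsPerCTA values are assumed.
    #include <cstdio>
    int main() {
      unsigned threadsPerWarp[2] = {8, 4}; // mma v2, per getThreadsPerWarp above
      unsigned warpsPerCTA[2] = {2, 4};    // assumed configuration
      for (int d = 0; d < 2; ++d)
        printf("axis %d: %u threads per CTA\n", d,
               threadsPerWarp[d] * warpsPerCTA[d]);
    }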
@@ -129,17 +162,11 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
   } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
     unsigned dim = sliceLayout.getDim();
     auto parent = sliceLayout.getParent();
-    if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
-      for (unsigned d = 0, n = blockedParent.getOrder().size(); d < n; ++d) {
-        if (d == dim)
-          continue;
-        shape.push_back(blockedParent.getSizePerThread()[d] *
-                        blockedParent.getThreadsPerWarp()[d] *
-                        blockedParent.getWarpsPerCTA()[d]);
-      }
-    } else {
-      assert(0 && "SliceEncodingAttr with parent other than "
-                  "BlockedEncodingAttr not implemented");
+    for (unsigned d = 0, n = getOrder(parent).size(); d < n; ++d) {
+      if (d == dim)
+        continue;
+      shape.push_back(getSizePerThread(parent)[d] *
+                      getThreadsPerWarp(parent)[d] * getWarpsPerCTA(parent)[d]);
     }
   } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
     if (mmaLayout.getVersion() == 2)
@@ -289,11 +316,11 @@ unsigned BlockedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
   return product<unsigned>(elemsPerThread);
 }

-SmallVector<int64_t>
-SliceEncodingAttr::paddedShape(ArrayRef<int64_t> shape) const {
+template <class T>
+SmallVector<T> SliceEncodingAttr::paddedShape(ArrayRef<T> shape) const {
   size_t rank = shape.size();
   unsigned dim = getDim();
-  SmallVector<int64_t> retShape(rank + 1);
+  SmallVector<T> retShape(rank + 1);
   for (unsigned d = 0; d < rank + 1; ++d) {
     if (d < dim)
       retShape[d] = shape[d];
@@ -304,18 +331,15 @@ SliceEncodingAttr::paddedShape(ArrayRef<int64_t> shape) const {
   }
   return retShape;
 }
+template SmallVector<unsigned>
+SliceEncodingAttr::paddedShape<unsigned>(ArrayRef<unsigned> shape) const;
+template SmallVector<int64_t>
+SliceEncodingAttr::paddedShape<int64_t>(ArrayRef<int64_t> shape) const;

 unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
-  size_t rank = shape.size();
   auto parent = getParent();
-  if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
-    assert(rank == blockedParent.getSizePerThread().size() - 1 &&
-           "unexpected rank in SliceEncodingAttr::getElemsPerThread");
-    return blockedParent.getElemsPerThread(paddedShape(shape));
-  } else {
-    assert(0 && "getElemsPerThread not implemented");
-    return 0;
-  }
+  return ::getElemsPerThread(parent, paddedShape(shape));
 }

 unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
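The templated paddedShape above re-inserts a size-1 dimension at the sliced axis, e.g. {16, 32} with dim = 1 becomes {16, 1, 32}. A minimal standalone sketch of the same logic (example values assumed):

    #include <cstdio>
    #include <vector>
    // Insert a size-1 dimension at position `dim`, like SliceEncodingAttr::paddedShape.
    std::vector<long> paddedShape(const std::vector<long> &shape, unsigned dim) {
      std::vector<long> ret(shape.size() + 1);
      for (unsigned d = 0; d < ret.size(); ++d)
        ret[d] = d < dim ? shape[d] : (d == dim ? 1 : shape[d - 1]);
      return ret;
    }
    int main() {
      for (long v : paddedShape({16, 32}, 1))
        printf("%ld ", v); // prints: 16 1 32
    }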