[BACKEND] Added support for mma layouts in reductions (#863)

Validated hackily by manually modifying the reduction .ttgir in my local
cache. There will be a follow-up PR adding some better testing
infrastructure to test out conversions and reductions on arbitrary
layouts.
Philippe Tillet
2022-11-10 09:58:07 -08:00
committed by GitHub
parent 57fd1864a7
commit 2aa538ec2e
6 changed files with 469 additions and 365 deletions

View File

@@ -25,6 +25,10 @@ namespace gpu {
unsigned getElemsPerThread(Type type);
SmallVector<unsigned> getThreadsPerWarp(Attribute layout);
SmallVector<unsigned> getWarpsPerCTA(Attribute layout);
SmallVector<unsigned> getSizePerThread(Attribute layout);
SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout);

View File

@@ -326,7 +326,8 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
);
let extraClassDeclaration = extraBaseClassDeclaration # [{
SmallVector<int64_t> paddedShape(ArrayRef<int64_t> shape) const;
template<class T>
SmallVector<T> paddedShape(ArrayRef<T> shape) const;
}];
}

View File

@@ -87,22 +87,22 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
auto srcTy = op.operand().getType().cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
auto srcLayout = srcTy.getEncoding();
auto srcShape = srcTy.getShape();
auto axis = op.axis();
bool fastReduce = axis == srcLayout.getOrder()[0];
bool fastReduce = axis == getOrder(srcLayout)[0];
SmallVector<unsigned> smemShape;
for (auto d : srcShape)
smemShape.push_back(d);
if (fastReduce) {
unsigned sizeInterWarps = srcLayout.getWarpsPerCTA()[axis];
unsigned sizeInterWarps = gpu::getWarpsPerCTA(srcLayout)[axis];
smemShape[axis] = sizeInterWarps;
} else {
unsigned threadsPerCTAAxis =
srcLayout.getThreadsPerWarp()[axis] * srcLayout.getWarpsPerCTA()[axis];
unsigned threadsPerCTAAxis = gpu::getThreadsPerWarp(srcLayout)[axis] *
gpu::getWarpsPerCTA(srcLayout)[axis];
smemShape[axis] = threadsPerCTAAxis;
}
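
A minimal Python sketch of the scratch-shape logic above (hypothetical helper name and example values; assumes order[0] is the most contiguous dimension, as the fast path requires):

def scratch_shape_for_reduce(src_shape, axis, order,
                             threads_per_warp, warps_per_cta):
    smem_shape = list(src_shape)
    if axis == order[0]:
        # fast reduce: only one partial per warp crosses shared memory
        smem_shape[axis] = warps_per_cta[axis]
    else:
        # otherwise every thread along the axis writes a partial
        smem_shape[axis] = threads_per_warp[axis] * warps_per_cta[axis]
    return smem_shape

# 128x64 tensor, axis=1, order=[1, 0], threadsPerWarp=[8, 4],
# warpsPerCTA=[2, 2] -> fast path, smem shape [128, 2]
print(scratch_shape_for_reduce([128, 64], 1, [1, 0], [8, 4], [2, 2]))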
@@ -161,16 +161,11 @@ private:
// TODO(Keren): Reduce with index is not supported yet.
auto value = op->getOperand(0);
if (auto tensorType = value.getType().dyn_cast<RankedTensorType>()) {
if (tensorType.getEncoding().isa<BlockedEncodingAttr>()) {
auto smemShape = getScratchConfigForReduce(reduceOp);
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(),
1, std::multiplies{});
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
std::multiplies{});
auto bytes = elems * tensorType.getElementTypeBitWidth() / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
} else {
assert(0 && "ReduceOp with input layout other than blocked layout is "
"not implemented yet");
}
}
} else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
auto srcTy = cvtLayout.src().getType().cast<RankedTensorType>();

View File

@@ -494,6 +494,10 @@ public:
rewriter.getIntegerAttr(rewriter.getIndexType(), value));
}
// -----------------------------------------------------------------------
// Utilities
// -----------------------------------------------------------------------
// Convert an \param index to a multi-dim coordinate given \param shape and
// \param order.
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
@@ -556,6 +560,10 @@ public:
return ret;
}
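
A Python sketch of the delinearize utility declared above (hypothetical stand-in; assumes order[0] is the fastest-varying dimension, matching how thread IDs are unpacked below):

def delinearize(index, shape, order):
    coords = [0] * len(shape)
    for d in order:  # peel off the fastest-varying dimension first
        coords[d] = index % shape[d]
        index //= shape[d]
    return coords

# thread 13 in an 8x4 arrangement with dim 1 fastest -> [3, 1]
print(delinearize(13, [8, 4], [1, 0]))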
// -----------------------------------------------------------------------
// Blocked layout indices
// -----------------------------------------------------------------------
// Get an index-base for each dimension for a \param blocked_layout.
SmallVector<Value>
emitBaseIndexForBlockedLayout(Location loc,
@@ -599,52 +607,6 @@ public:
return multiDimBase;
}
SmallVector<SmallVector<Value>> emitIndices(Location loc,
ConversionPatternRewriter &b,
const Attribute &layout,
ArrayRef<int64_t> shape) const {
if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
return emitIndicesForBlockedLayout(loc, b, blocked, shape);
} else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
return emitIndicesForSliceLayout(loc, b, slice, shape);
} else {
assert(0 && "emitIndices for layouts other than blocked & slice not "
"implemented yet");
return {};
}
}
SmallVector<SmallVector<Value>>
emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter,
const SliceEncodingAttr &sliceLayout,
ArrayRef<int64_t> shape) const {
auto parent = sliceLayout.getParent();
unsigned dim = sliceLayout.getDim();
size_t rank = shape.size();
if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
auto paddedIndices = emitIndicesForBlockedLayout(
loc, rewriter, blockedParent, sliceLayout.paddedShape(shape));
unsigned numIndices = paddedIndices.size();
SmallVector<SmallVector<Value>> resultIndices(numIndices);
for (unsigned i = 0; i < numIndices; ++i)
for (unsigned d = 0; d < rank + 1; ++d)
if (d != dim)
resultIndices[i].push_back(paddedIndices[i][d]);
return resultIndices;
} else if (auto sliceParent = parent.dyn_cast<SliceEncodingAttr>()) {
assert(0 && "emitIndicesForSliceLayout with parent of sliceLayout"
"is not implemented yet");
return {};
} else {
assert(0 && "emitIndicesForSliceLayout with parent other than blocked & "
"slice not implemented yet");
return {};
}
}
SmallVector<SmallVector<unsigned>>
emitOffsetForBlockedLayout(const BlockedEncodingAttr &blockedLayout,
ArrayRef<int64_t> shape) const {
@@ -696,23 +658,109 @@ public:
return reorderedOffset;
}
// -----------------------------------------------------------------------
// Mma layout indices
// -----------------------------------------------------------------------
SmallVector<Value>
emitBaseIndexForMmaLayoutV1(Location loc, ConversionPatternRewriter &rewriter,
const MmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape) const {
llvm_unreachable("emitIndicesForMmaLayoutV1 not implemented");
}
SmallVector<SmallVector<unsigned>>
emitOffsetForMmaLayoutV1(const MmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape) const {
llvm_unreachable("emitOffsetForMmaLayoutV1 not implemented");
}
SmallVector<Value>
emitBaseIndexForMmaLayoutV2(Location loc, ConversionPatternRewriter &rewriter,
const MmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape) const {
auto _warpsPerCTA = mmaLayout.getWarpsPerCTA();
assert(_warpsPerCTA.size() == 2);
SmallVector<Value> warpsPerCTA = {idx_val(_warpsPerCTA[0]),
idx_val(_warpsPerCTA[1])};
Value threadId = getThreadId(rewriter, loc);
Value warpSize = idx_val(32);
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
Value warpId0 = urem(warpId, warpsPerCTA[0]);
Value warpId1 = urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]);
Value offWarp0 = mul(warpId0, idx_val(16));
Value offWarp1 = mul(warpId1, idx_val(8));
SmallVector<Value> multiDimBase(2);
multiDimBase[0] = add(udiv(laneId, idx_val(4)), offWarp0);
multiDimBase[1] = add(mul(idx_val(2), urem(laneId, idx_val(4))), offWarp1);
return multiDimBase;
}
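
The base-index math above encodes the Ampere mma.m16n8 fragment layout: each warp owns a 16x8 tile, lane l starts at row l/4 and column 2*(l%4), and warps are offset by 16 rows / 8 columns. A runnable Python sketch of the same arithmetic (hypothetical helper; assumes 32-lane warps as in the code):

def mma_v2_base_index(thread_id, warps_per_cta):
    lane, warp = thread_id % 32, thread_id // 32
    warp0 = warp % warps_per_cta[0]
    warp1 = (warp // warps_per_cta[0]) % warps_per_cta[1]
    row = lane // 4 + warp0 * 16      # offWarp0 = warpId0 * 16
    col = 2 * (lane % 4) + warp1 * 8  # offWarp1 = warpId1 * 8
    return row, col

# lane 5 of warp 0 -> (1, 2): it owns elements (1,2) and (1,3) of the tile
print(mma_v2_base_index(5, [2, 2]))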
SmallVector<SmallVector<unsigned>>
emitOffsetForMmaLayoutV2(const MmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape) const {
SmallVector<SmallVector<unsigned>> ret;
for (unsigned i = 0; i < shape[0]; i += getShapePerCTA(mmaLayout)[0]) {
for (unsigned j = 0; j < shape[1]; j += getShapePerCTA(mmaLayout)[1]) {
ret.push_back({i, j});
ret.push_back({i, j + 1});
}
for (unsigned j = 0; j < shape[1]; j += getShapePerCTA(mmaLayout)[1]) {
ret.push_back({i + 8, j});
ret.push_back({i + 8, j + 1});
}
}
return ret;
}
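
Per 16x8 tile, each thread thus contributes four offsets: two adjacent columns in one row and the same pair eight rows down, repeated for every tile origin. A Python sketch of the enumeration (hypothetical value for shapePerCTA):

def mma_v2_offsets(shape, shape_per_cta):
    ret = []
    for i in range(0, shape[0], shape_per_cta[0]):
        for j in range(0, shape[1], shape_per_cta[1]):
            ret += [(i, j), (i, j + 1)]
        for j in range(0, shape[1], shape_per_cta[1]):
            ret += [(i + 8, j), (i + 8, j + 1)]
    return ret

# 16x16 tensor, shapePerCTA = [16, 8]:
# [(0,0), (0,1), (0,8), (0,9), (8,0), (8,1), (8,8), (8,9)]
print(mma_v2_offsets([16, 16], [16, 8]))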
// -----------------------------------------------------------------------
// Get offsets / indices for any layout
// -----------------------------------------------------------------------
SmallVector<Value> emitBaseIndexForLayout(Location loc,
ConversionPatternRewriter &rewriter,
const Attribute &layout,
ArrayRef<int64_t> shape) const {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>())
return emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.getVersion() == 1)
return emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, shape);
if (mmaLayout.getVersion() == 2)
return emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, shape);
}
llvm_unreachable("unsupported emitBaseIndexForLayout");
}
SmallVector<SmallVector<unsigned>>
emitOffsetForLayout(const Attribute &layout, ArrayRef<int64_t> shape) const {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>())
return emitOffsetForBlockedLayout(blockedLayout, shape);
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.getVersion() == 1)
return emitOffsetForMmaLayoutV1(mmaLayout, shape);
if (mmaLayout.getVersion() == 2)
return emitOffsetForMmaLayoutV2(mmaLayout, shape);
}
llvm_unreachable("unsupported emitOffsetForLayout");
}
// Emit the indices calculation within each ConversionPattern and return a
// [elemsPerThread x rank] index matrix.
// TODO: [goostavz] Double-check that the redundant indices calculations
// will be eliminated by subsequent MLIR/LLVM optimization. We might
// implement an indexCache if necessary.
SmallVector<SmallVector<Value>>
emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &rewriter,
const BlockedEncodingAttr &blockedLayout,
ArrayRef<int64_t> shape) const {
// TODO: [phil] redundant index computations do not appear to hurt
// performance much, but they could still significantly slow down
// computations.
SmallVector<SmallVector<Value>> emitIndicesForDistributedLayout(
Location loc, ConversionPatternRewriter &rewriter,
const Attribute &layout, ArrayRef<int64_t> shape) const {
// step 1, delinearize threadId to get the base index
auto multiDimBase =
emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, shape);
// step 2, get offset of each element
SmallVector<SmallVector<unsigned>> offset =
emitOffsetForBlockedLayout(blockedLayout, shape);
auto offset = emitOffsetForLayout(layout, shape);
// step 3, add offset to base, and reorder the sequence of indices to
// guarantee that elems in the same sizePerThread are adjacent in order
unsigned rank = shape.size();
@@ -726,6 +774,49 @@ public:
return multiDimIdx;
}
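
The three steps above compose as base + offset per element. A minimal Python analogue (pure stand-ins for the MLIR values):

def emit_indices(base, offsets):
    # every element's index is the thread's base plus its static offset
    return [[b + o for b, o in zip(base, off)] for off in offsets]

# base (1, 2) plus one mma-v2 tile's offsets -> the thread's four indices
print(emit_indices((1, 2), [(0, 0), (0, 1), (8, 0), (8, 1)]))
# [[1, 2], [1, 3], [9, 2], [9, 3]]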
SmallVector<SmallVector<Value>>
emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter,
const SliceEncodingAttr &sliceLayout,
ArrayRef<int64_t> shape) const {
auto parent = sliceLayout.getParent();
unsigned dim = sliceLayout.getDim();
size_t rank = shape.size();
auto paddedIndices =
emitIndices(loc, rewriter, parent, sliceLayout.paddedShape(shape));
unsigned numIndices = paddedIndices.size();
SmallVector<SmallVector<Value>> resultIndices(numIndices);
for (unsigned i = 0; i < numIndices; ++i)
for (unsigned d = 0; d < rank + 1; ++d)
if (d != dim)
resultIndices[i].push_back(paddedIndices[i][d]);
return resultIndices;
}
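
The slice path above is a pure projection: emit indices for the padded parent shape, then drop the sliced dimension. Sketched in Python (hypothetical inputs):

def slice_indices(parent_indices, dim):
    return [[x for d, x in enumerate(idx) if d != dim]
            for idx in parent_indices]

# dropping dim 1 of 2-D parent indices yields the 1-D slice indices
print(slice_indices([[0, 0], [0, 1], [1, 0], [1, 1]], 1))  # [[0], [0], [1], [1]]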
// -----------------------------------------------------------------------
// Emit indices
// -----------------------------------------------------------------------
SmallVector<SmallVector<Value>> emitIndices(Location loc,
ConversionPatternRewriter &b,
const Attribute &layout,
ArrayRef<int64_t> shape) const {
if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
return emitIndicesForDistributedLayout(loc, b, blocked, shape);
} else if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
return emitIndicesForDistributedLayout(loc, b, mma, shape);
} else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
return emitIndicesForSliceLayout(loc, b, slice, shape);
} else {
assert(0 && "emitIndices for layouts other than blocked & slice not "
"implemented yet");
return {};
}
}
// -----------------------------------------------------------------------
// Shared memory utilities
// -----------------------------------------------------------------------
template <typename T>
Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
T value) const {
@@ -954,8 +1045,8 @@ struct LoadOpConversion
// Determine the vectorization size
Type valueTy = op.getResult().getType();
Type valueElemTy = typeConverter->convertType(
getElementTypeOrSelf(valueTy));
Type valueElemTy =
typeConverter->convertType(getElementTypeOrSelf(valueTy));
unsigned vec = getVectorSize(ptr);
unsigned numElems = getElemsPerThread(ptr.getType());
if (llMask)
@@ -1086,7 +1177,8 @@ struct LoadOpConversion
: retTys[0];
// TODO: if (has_l2_evict_policy)
// auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
// auto asmDialectAttr =
// LLVM::AsmDialectAttr::get(rewriter.getContext(),
// LLVM::AsmDialect::AD_ATT);
Value ret = ptxBuilder.launch(rewriter, loc, retTy);
@@ -1149,8 +1241,8 @@ struct StoreOpConversion
MLIRContext *ctx = rewriter.getContext();
auto valueTy = value.getType();
Type valueElemTy = typeConverter->convertType(
getElementTypeOrSelf(valueTy));
Type valueElemTy =
typeConverter->convertType(getElementTypeOrSelf(valueTy));
unsigned vec = getVectorSize(ptr);
unsigned numElems = getElemsPerThread(ptr.getType());
@@ -1571,19 +1663,19 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
unsigned axis = adaptor.axis();
auto srcTy = op.operand().getType().cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
auto srcLayout = srcTy.getEncoding();
auto srcShape = srcTy.getShape();
auto srcRank = srcTy.getRank();
auto threadsPerWarp = srcLayout.getThreadsPerWarp();
auto warpsPerCTA = srcLayout.getWarpsPerCTA();
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcLayout);
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
smemBase = bitcast(smemBase, elemPtrTy);
auto order = srcLayout.getOrder();
auto order = getOrder(srcLayout);
unsigned sizeIntraWarps = threadsPerWarp[axis];
unsigned sizeInterWarps = warpsPerCTA[axis];
@@ -1592,7 +1684,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);
SmallVector<SmallVector<unsigned>> offset =
emitOffsetForBlockedLayout(srcLayout, srcShape);
emitOffsetForLayout(srcLayout, srcShape);
std::map<SmallVector<unsigned>, Value> accs;
std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
@@ -1655,7 +1747,8 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
// each thread needs to process:
// elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads
unsigned elems = product<unsigned>(smemShape);
unsigned numThreads = product<unsigned>(srcLayout.getWarpsPerCTA()) * 32;
unsigned numThreads =
product<unsigned>(triton::gpu::getWarpsPerCTA(srcLayout)) * 32;
unsigned elemsPerThread = std::max<unsigned>(elems / numThreads, 1);
Value readOffset = threadId;
for (unsigned round = 0; round < elemsPerThread; ++round) {
@@ -2104,9 +2197,10 @@ struct FpToFpOpConversion
using ConvertTritonGPUOpToLLVMPattern<
triton::FpToFpOp>::ConvertTritonGPUOpToLLVMPattern;
static SmallVector<Value> convertFp8x4ToFp16x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
static SmallVector<Value>
convertFp8x4ToFp16x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto ctx = rewriter.getContext();
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value fp8x4Vec = undef(fp8x4VecTy);
@@ -2117,8 +2211,7 @@ struct FpToFpOpConversion
fp8x4Vec = bitcast(fp8x4Vec, i32_ty);
PTXBuilder builder;
auto *ptxAsm =
"{ \n"
auto *ptxAsm = "{ \n"
".reg .b32 a<2>, b<2>; \n"
"prmt.b32 a0, 0, $2, 0x5040; \n"
"prmt.b32 a1, 0, $2, 0x7060; \n"
@@ -2141,21 +2234,20 @@ struct FpToFpOpConversion
struct_ty(SmallVector<Type>{fp16x2VecTy, fp16x2VecTy});
auto fp16x2x2Struct =
builder.launch(rewriter, loc, fp16x2x2StructTy, false);
auto fp16x2Vec0 = extract_val(fp16x2VecTy, fp16x2x2Struct,
rewriter.getI32ArrayAttr({0}));
auto fp16x2Vec1 = extract_val(fp16x2VecTy, fp16x2x2Struct,
rewriter.getI32ArrayAttr({1}));
return {
extract_element(f16_ty, fp16x2Vec0, i32_val(0)),
auto fp16x2Vec0 =
extract_val(fp16x2VecTy, fp16x2x2Struct, rewriter.getI32ArrayAttr({0}));
auto fp16x2Vec1 =
extract_val(fp16x2VecTy, fp16x2x2Struct, rewriter.getI32ArrayAttr({1}));
return {extract_element(f16_ty, fp16x2Vec0, i32_val(0)),
extract_element(f16_ty, fp16x2Vec0, i32_val(1)),
extract_element(f16_ty, fp16x2Vec1, i32_val(0)),
extract_element(f16_ty, fp16x2Vec1, i32_val(1))
};
extract_element(f16_ty, fp16x2Vec1, i32_val(1))};
}
static SmallVector<Value> convertFp16x4ToFp8x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
static SmallVector<Value>
convertFp16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto ctx = rewriter.getContext();
auto fp16x2VecTy = vec_ty(f16_ty, 2);
Value fp16x2Vec0 = undef(fp16x2VecTy);
@@ -2168,8 +2260,7 @@ struct FpToFpOpConversion
fp16x2Vec1 = bitcast(fp16x2Vec1, i32_ty);
PTXBuilder builder;
auto *ptxAsm =
"{ \n"
auto *ptxAsm = "{ \n"
".reg .b32 a<2>, b<2>; \n"
"shl.b32 a0, $1, 1; \n"
"shl.b32 a1, $2, 1; \n"
@@ -2190,17 +2281,16 @@ struct FpToFpOpConversion
auto fp8x4VecTy = vec_ty(i8_ty, 4);
auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
return {
extract_element(i8_ty, fp8x4Vec, i32_val(0)),
return {extract_element(i8_ty, fp8x4Vec, i32_val(0)),
extract_element(i8_ty, fp8x4Vec, i32_val(1)),
extract_element(i8_ty, fp8x4Vec, i32_val(2)),
extract_element(i8_ty, fp8x4Vec, i32_val(3))
};
extract_element(i8_ty, fp8x4Vec, i32_val(3))};
}
static SmallVector<Value> convertFp8x4ToBf16x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
static SmallVector<Value>
convertFp8x4ToBf16x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto ctx = rewriter.getContext();
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value fp8x4Vec = undef(fp8x4VecTy);
@@ -2211,8 +2301,7 @@ struct FpToFpOpConversion
fp8x4Vec = bitcast(fp8x4Vec, i32_ty);
PTXBuilder builder;
auto *ptxAsm =
"{ \n"
auto *ptxAsm = "{ \n"
".reg .b32 a<2>, sign<2>, nosign<2>, b<2>; \n"
"prmt.b32 a0, 0, $2, 0x5040; \n"
"prmt.b32 a1, 0, $2, 0x7060; \n"
@@ -2239,21 +2328,20 @@ struct FpToFpOpConversion
struct_ty(SmallVector<Type>{bf16x2VecTy, bf16x2VecTy});
auto bf16x2x2Struct =
builder.launch(rewriter, loc, bf16x2x2StructTy, false);
auto bf16x2Vec0 = extract_val(bf16x2VecTy, bf16x2x2Struct,
rewriter.getI32ArrayAttr({0}));
auto bf16x2Vec1 = extract_val(bf16x2VecTy, bf16x2x2Struct,
rewriter.getI32ArrayAttr({1}));
return {
extract_element(bf16_ty, bf16x2Vec0, i32_val(0)),
auto bf16x2Vec0 =
extract_val(bf16x2VecTy, bf16x2x2Struct, rewriter.getI32ArrayAttr({0}));
auto bf16x2Vec1 =
extract_val(bf16x2VecTy, bf16x2x2Struct, rewriter.getI32ArrayAttr({1}));
return {extract_element(bf16_ty, bf16x2Vec0, i32_val(0)),
extract_element(bf16_ty, bf16x2Vec0, i32_val(1)),
extract_element(bf16_ty, bf16x2Vec1, i32_val(0)),
extract_element(bf16_ty, bf16x2Vec1, i32_val(1))
};
extract_element(bf16_ty, bf16x2Vec1, i32_val(1))};
}
static SmallVector<Value> convertBf16x4ToFp8x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
static SmallVector<Value>
convertBf16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto ctx = rewriter.getContext();
auto bf16x2VecTy = vec_ty(bf16_ty, 2);
Value bf16x2Vec0 = undef(bf16x2VecTy);
@@ -2266,8 +2354,7 @@ struct FpToFpOpConversion
bf16x2Vec1 = bitcast(bf16x2Vec1, i32_ty);
PTXBuilder builder;
auto *ptxAsm =
"{ \n"
auto *ptxAsm = "{ \n"
".reg .u32 sign, sign<2>, nosign, nosign<2>; \n"
".reg .u32 fp8_min, fp8_max, rn_, zero; \n"
"mov.u32 fp8_min, 0x38003800; \n"
@@ -2312,29 +2399,27 @@ struct FpToFpOpConversion
auto fp8x4VecTy = vec_ty(i8_ty, 4);
auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
return {
extract_element(i8_ty, fp8x4Vec, i32_val(0)),
return {extract_element(i8_ty, fp8x4Vec, i32_val(0)),
extract_element(i8_ty, fp8x4Vec, i32_val(1)),
extract_element(i8_ty, fp8x4Vec, i32_val(2)),
extract_element(i8_ty, fp8x4Vec, i32_val(3))
};
extract_element(i8_ty, fp8x4Vec, i32_val(3))};
}
static SmallVector<Value> convertFp8x4ToFp32x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
static SmallVector<Value>
convertFp8x4ToFp32x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
return {
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[0]),
return {rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[0]),
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[1]),
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[2]),
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[3])
};
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[3])};
}
static SmallVector<Value> convertFp32x4ToFp8x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
static SmallVector<Value>
convertFp32x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
@@ -2342,21 +2427,21 @@ struct FpToFpOpConversion
return convertFp16x4ToFp8x4(loc, rewriter, c0, c1, c2, c3);
}
static SmallVector<Value> convertFp8x4ToFp64x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
static SmallVector<Value>
convertFp8x4ToFp64x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
return {
rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[0]),
return {rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[0]),
rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[1]),
rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[2]),
rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[3])
};
rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[3])};
}
static SmallVector<Value> convertFp64x4ToFp8x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
static SmallVector<Value>
convertFp64x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
@@ -2381,7 +2466,8 @@ struct FpToFpOpConversion
// Select convertor
std::function<SmallVector<Value>(Location, ConversionPatternRewriter &,
const Value &, const Value &,
const Value&, const Value&)> convertor;
const Value &, const Value &)>
convertor;
if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF16()) {
convertor = convertFp8x4ToFp16x4;
} else if (srcEltType.isF16() && dstEltType.isa<triton::Float8Type>()) {
@@ -2410,8 +2496,7 @@ struct FpToFpOpConversion
auto elements = getElementsFromStruct(loc, adaptor.from(), rewriter);
SmallVector<Value> resultVals;
for (size_t i = 0; i < elems; i += 4) {
auto converted = convertor(loc, rewriter,
elements[i], elements[i + 1],
auto converted = convertor(loc, rewriter, elements[i], elements[i + 1],
elements[i + 2], elements[i + 3]);
resultVals.append(converted);
}
@@ -2626,6 +2711,82 @@ public:
}
private:
SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
ConversionPatternRewriter &rewriter,
unsigned elemId, ArrayRef<int64_t> shape,
ArrayRef<unsigned> multiDimCTAInRepId,
ArrayRef<unsigned> shapePerCTA) const {
unsigned rank = shape.size();
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
auto multiDimOffsetFirstElem =
emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
SmallVector<Value> multiDimOffset(rank);
SmallVector<unsigned> multiDimElemId =
getMultiDimIndex<unsigned>(elemId, blockedLayout.getSizePerThread());
for (unsigned d = 0; d < rank; ++d) {
multiDimOffset[d] = add(multiDimOffsetFirstElem[d],
idx_val(multiDimCTAInRepId[d] * shapePerCTA[d] +
multiDimElemId[d]));
}
return multiDimOffset;
}
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
unsigned dim = sliceLayout.getDim();
auto multiDimOffsetParent =
getMultiDimOffset(sliceLayout.getParent(), loc, rewriter, elemId,
sliceLayout.paddedShape(shape),
sliceLayout.paddedShape(multiDimCTAInRepId),
sliceLayout.paddedShape(shapePerCTA));
SmallVector<Value> multiDimOffset(rank);
for (unsigned d = 0; d < rank + 1; ++d) {
if (d == dim)
continue;
unsigned slicedD = d < dim ? d : (d - 1);
multiDimOffset[slicedD] = multiDimOffsetParent[d];
}
return multiDimOffset;
}
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
SmallVector<Value> mmaColIdx(2);
SmallVector<Value> mmaRowIdx(2);
Value threadId = getThreadId(rewriter, loc);
Value warpSize = idx_val(32);
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
// auto multiDimWarpId =
// delinearize(rewriter, loc, warpId, mmaLayout.getWarpsPerCTA());
// TODO: double-check whether this is a documentation bug or a bug in
// DotConversion
SmallVector<Value> multiDimWarpId(2);
multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
Value four = idx_val(4);
Value mmaGrpId = udiv(laneId, four);
Value mmaGrpIdP8 = add(mmaGrpId, idx_val(8));
Value mmaThreadIdInGrp = urem(laneId, four);
Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, idx_val(2));
Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, idx_val(1));
Value colWarpOffset = mul(multiDimWarpId[0], idx_val(16));
mmaColIdx[0] = add(mmaGrpId, colWarpOffset);
mmaColIdx[1] = add(mmaGrpIdP8, colWarpOffset);
Value rowWarpOffset = mul(multiDimWarpId[1], idx_val(8));
mmaRowIdx[0] = add(mmaThreadIdInGrpM2, rowWarpOffset);
mmaRowIdx[1] = add(mmaThreadIdInGrpM2P1, rowWarpOffset);
assert(rank == 2);
assert(mmaLayout.getVersion() == 2 &&
"mmaLayout ver1 not implemented yet");
SmallVector<Value> multiDimOffset(rank);
multiDimOffset[0] = elemId < 2 ? mmaColIdx[0] : mmaColIdx[1];
multiDimOffset[1] = elemId % 2 == 0 ? mmaRowIdx[0] : mmaRowIdx[1];
multiDimOffset[0] = add(multiDimOffset[0],
idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
multiDimOffset[1] = add(multiDimOffset[1],
idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
return multiDimOffset;
}
llvm_unreachable("unexpected layout in getMultiDimOffset");
}
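
The elemId selection above maps each thread's four mma-v2 values to (row, row+8) x (col, col+1) around its base position. A Python sketch of just that selection (hypothetical base values):

def mma_v2_elem_offset(base_row, base_col, elem_id):
    row = base_row + (8 if elem_id >= 2 else 0)  # elems 2,3 sit 8 rows down
    col = base_col + (elem_id % 2)               # odd elems take col + 1
    return row, col

print([mma_v2_elem_offset(1, 2, e) for e in range(4)])
# [(1, 2), (1, 3), (9, 2), (9, 3)]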
// shared memory rd/st for blocked or mma layout with data padding
void processReplica(Location loc, ConversionPatternRewriter &rewriter,
bool stNotRd, RankedTensorType type,
@@ -2693,47 +2854,6 @@ void ConvertLayoutOpConversion::processReplica(
elemTy = IntegerType::get(elemTy.getContext(), 8);
auto llvmElemTy = getTypeConverter()->convertType(elemTy);
SmallVector<Value> multiDimOffsetFirstElem;
SmallVector<Value> mmaColIdx(2);
SmallVector<Value> mmaRowIdx(2);
if (blockedLayout) {
multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout(
loc, rewriter, blockedLayout, type.getShape());
} else if (sliceLayout) {
auto parent = sliceLayout.getParent();
if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
SmallVector<int64_t> paddedShape =
sliceLayout.paddedShape(type.getShape());
multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout(
loc, rewriter, blockedParent, paddedShape);
} else {
assert(0 && "SliceEncodingAttr with parent other than "
"BlockedEncodingAttr not implemented");
}
} else if (mmaLayout) {
Value threadId = getThreadId(rewriter, loc);
Value warpSize = idx_val(32);
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
// auto multiDimWarpId =
// delinearize(rewriter, loc, warpId, mmaLayout.getWarpsPerCTA());
// TODO: double-check whether this is a documentation bug or a bug in
// DotConversion
SmallVector<Value> multiDimWarpId(2);
multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
Value four = idx_val(4);
Value mmaGrpId = udiv(laneId, four);
Value mmaGrpIdP8 = add(mmaGrpId, idx_val(8));
Value mmaThreadIdInGrp = urem(laneId, four);
Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, idx_val(2));
Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, idx_val(1));
Value colWarpOffset = mul(multiDimWarpId[0], idx_val(16));
mmaColIdx[0] = add(mmaGrpId, colWarpOffset);
mmaColIdx[1] = add(mmaGrpIdP8, colWarpOffset);
Value rowWarpOffset = mul(multiDimWarpId[1], idx_val(8));
mmaRowIdx[0] = add(mmaThreadIdInGrpM2, rowWarpOffset);
mmaRowIdx[1] = add(mmaThreadIdInGrpM2P1, rowWarpOffset);
}
for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
auto multiDimCTAInRepId = getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep);
SmallVector<unsigned> multiDimCTAId(rank);
@@ -2747,48 +2867,9 @@ void ConvertLayoutOpConversion::processReplica(
// consider caching the index calculation result in case
// performance issues are observed.
for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
SmallVector<Value> multiDimOffset(rank);
if (blockedLayout) {
SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
elemId, blockedLayout.getSizePerThread());
for (unsigned d = 0; d < rank; ++d) {
multiDimOffset[d] =
add(multiDimOffsetFirstElem[d],
idx_val(multiDimCTAInRepId[d] * shapePerCTA[d] +
multiDimElemId[d]));
}
} else if (sliceLayout) {
unsigned dim = sliceLayout.getDim();
auto parent = sliceLayout.getParent();
if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
elemId, blockedParent.getSizePerThread());
for (unsigned d = 0; d < rank + 1; ++d) {
if (d == dim)
continue;
unsigned slicedD = d < dim ? d : (d - 1);
multiDimOffset[slicedD] =
add(multiDimOffsetFirstElem[d],
idx_val(multiDimCTAInRepId[slicedD] * shapePerCTA[slicedD] +
multiDimElemId[d]));
}
} else {
assert(0 && "SliceEncodingAttr with parent other than "
"BlockedEncodingAttr not implemented");
}
} else if (mmaLayout) {
assert(rank == 2);
assert(mmaLayout.getVersion() == 2 &&
"mmaLayout ver1 not implemented yet");
multiDimOffset[0] = elemId < 2 ? mmaColIdx[0] : mmaColIdx[1];
multiDimOffset[1] = elemId % 2 == 0 ? mmaRowIdx[0] : mmaRowIdx[1];
multiDimOffset[0] = add(
multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
multiDimOffset[1] = add(
multiDimOffset[1], idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
} else {
assert(0 && "unexpected layout in processReplica");
}
SmallVector<Value> multiDimOffset =
getMultiDimOffset(layout, loc, rewriter, elemId, type.getShape(),
multiDimCTAInRepId, shapePerCTA);
Value offset =
linearize(rewriter, loc, reorder<Value>(multiDimOffset, outOrd),
reorder<unsigned>(paddedRepShape, outOrd));
@@ -3637,10 +3718,10 @@ struct DotOpMmaV1ConversionHelper {
// Compute the offset of the matrix to load.
// Returns offsetAM, offsetAK, offsetBN, offsetBK.
// NOTE, the information M(from $a) and N(from $b) couldn't be retrieved at
// the same time in the usage in convert_layout[shared->dot_op], we leave the
// noexist info to be 0 and only use the desired argument from the composed
// result. In this way we want to retain the original code structure in
// convert_mma884 method for easier debugging.
// the same time in the usage in convert_layout[shared->dot_op], so we
// leave the nonexistent info as 0 and only use the desired argument from
// the composed result. In this way we retain the original code structure
// in the convert_mma884 method for easier debugging.
std::tuple<Value, Value, Value, Value>
computeOffsets(Value threadId, bool isARow, bool isBRow, ArrayRef<int> fpw,
ArrayRef<int> spw, ArrayRef<int> rep,
@@ -5231,8 +5312,8 @@ struct InsertSliceAsyncOpConversion
srcStrides.emplace_back(smemObj.strides[i]);
}
}
// Compute the offset based on the original dimensions of the shared memory
// object
// Compute the offset based on the original dimensions of the shared
// memory object
auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides);
auto dstPtrTy =
ptr_ty(getTypeConverter()->convertType(resTy.getElementType()), 3);
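
The dstOffset fold above is a plain stride dot product. A one-line Python sketch (hypothetical values):

def dot(offsets, strides):
    return sum(o * s for o, s in zip(offsets, strides))

# offset (2, 3) in a row-major 64-wide buffer -> linear offset 131
print(dot([2, 3], [64, 1]))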

View File

@@ -42,15 +42,11 @@ static Type getPointeeType(Type type) {
namespace gpu {
// TODO: Inheritation of layout attributes
unsigned getElemsPerThread(Type type) {
if (type.isIntOrIndexOrFloat() ||
type.isa<triton::Float8Type>() ||
type.isa<triton::PointerType>())
return 1;
auto tensorType = type.cast<RankedTensorType>();
auto layout = tensorType.getEncoding();
auto shape = tensorType.getShape();
// TODO: Inheritance of layout attributes
// so that all distributed layouts implement
// these utilities
unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return blockedLayout.getElemsPerThread(shape);
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
@@ -67,6 +63,43 @@ unsigned getElemsPerThread(Type type) {
}
}
unsigned getElemsPerThread(Type type) {
if (type.isIntOrIndexOrFloat() ||
type.isa<triton::Float8Type>() ||
type.isa<triton::PointerType>())
return 1;
auto tensorType = type.cast<RankedTensorType>();
return getElemsPerThread(tensorType.getEncoding(), tensorType.getShape());
}
SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return SmallVector<unsigned>(blockedLayout.getThreadsPerWarp().begin(),
blockedLayout.getThreadsPerWarp().end());
}
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.getVersion() == 1)
return SmallVector<unsigned>{4, 8};
if (mmaLayout.getVersion() == 2)
return SmallVector<unsigned>{8, 4};
}
assert(0 && "getThreadsPerWarp not implemented");
return {};
}
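
Both hard-coded warp shapes above tile one 32-lane warp: 4x8 for mma v1 and 8x4 for mma v2. A quick Python check:

threads_per_warp = {1: (4, 8), 2: (8, 4)}  # mma version -> warp shape
assert all(r * c == 32 for r, c in threads_per_warp.values())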
SmallVector<unsigned> getWarpsPerCTA(Attribute layout) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return SmallVector<unsigned>(blockedLayout.getWarpsPerCTA().begin(),
blockedLayout.getWarpsPerCTA().end());
}
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
return SmallVector<unsigned>(mmaLayout.getWarpsPerCTA().begin(),
mmaLayout.getWarpsPerCTA().end());
}
assert(0 && "getWarpsPerCTA not implemented");
return {};
}
SmallVector<unsigned> getSizePerThread(Attribute layout) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
@@ -129,17 +162,11 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
unsigned dim = sliceLayout.getDim();
auto parent = sliceLayout.getParent();
if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
for (unsigned d = 0, n = blockedParent.getOrder().size(); d < n; ++d) {
for (unsigned d = 0, n = getOrder(parent).size(); d < n; ++d) {
if (d == dim)
continue;
shape.push_back(blockedParent.getSizePerThread()[d] *
blockedParent.getThreadsPerWarp()[d] *
blockedParent.getWarpsPerCTA()[d]);
}
} else {
assert(0 && "SliceEncodingAttr with parent other than "
"BlockedEncodingAttr not implemented");
shape.push_back(getSizePerThread(parent)[d] *
getThreadsPerWarp(parent)[d] * getWarpsPerCTA(parent)[d]);
}
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.getVersion() == 2)
@@ -289,11 +316,11 @@ unsigned BlockedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
return product<unsigned>(elemsPerThread);
}
SmallVector<int64_t>
SliceEncodingAttr::paddedShape(ArrayRef<int64_t> shape) const {
template <class T>
SmallVector<T> SliceEncodingAttr::paddedShape(ArrayRef<T> shape) const {
size_t rank = shape.size();
unsigned dim = getDim();
SmallVector<int64_t> retShape(rank + 1);
SmallVector<T> retShape(rank + 1);
for (unsigned d = 0; d < rank + 1; ++d) {
if (d < dim)
retShape[d] = shape[d];
@@ -304,18 +331,15 @@ SliceEncodingAttr::paddedShape(ArrayRef<int64_t> shape) const {
}
return retShape;
}
template SmallVector<unsigned>
SliceEncodingAttr::paddedShape<unsigned>(ArrayRef<unsigned> shape) const;
template SmallVector<int64_t>
SliceEncodingAttr::paddedShape<int64_t>(ArrayRef<int64_t> shape) const;
unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
size_t rank = shape.size();
auto parent = getParent();
if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
assert(rank == blockedParent.getSizePerThread().size() - 1 &&
"unexpected rank in SliceEncodingAttr::getElemsPerThread");
return blockedParent.getElemsPerThread(paddedShape(shape));
} else {
assert(0 && "getElemsPerThread not implemented");
return 0;
}
return ::getElemsPerThread(parent, paddedShape(shape));
}
unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {

View File

@@ -117,8 +117,7 @@ def test_reduce2d(op, dtype, shape, axis):
z = torch.empty(reduced_shape, device=x.device, dtype=reduced_dtype)
kernel = patch_kernel(reduce2d_kernel, {'OP': op})
grid = (1,)
kernel[grid](x_ptr=x, z_ptr=z, axis=axis, block_m=shape[0], block_n=shape[1])
kernel[(1,)](x_ptr=x, z_ptr=z, axis=axis, block_m=shape[0], block_n=shape[1])
if op == 'sum':
golden_z = torch.sum(x, dim=axis, keepdim=False, dtype=reduced_dtype)