[Triton-MLIR][BACKEND] add convert_layout[shared->dot_op] conversion to adapt DotOperand layout (#786)
This PR helps to 1. adapt the existing DotOp conversion to the design of the new DotOperand layout, and 2. make the DotOp conversion work with both the shared-layout and the dot_operand-layout input cases, in preparation for the upcoming upstream switch.
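Concretely, tt.dot can now take its $a/$b operands either directly in shared layout, in which case the conversion itself issues the loads from shared memory as before, or already converted to the new dot_operand layout by a preceding triton_gpu.convert_layout; the lit test added at the end of this diff exercises the new shared -> dot_op path.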
@@ -52,7 +52,7 @@ different cuda threads in the programs, via shared memory. In other words,
 for all indices i \in R^d, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}.
 
 In order to avoid shared memory bank conflicts, elements may be swizzled
 in memory. For example, a swizzled row-major layout could store its data
 as follows:
 
 A_{0, 0}  A_{0, 1}  A_{0, 2}  A_{0, 3} ...  [phase 0] \ per_phase = 2
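As a rough illustration of the swizzling scheme described above, here is a minimal standalone sketch. It assumes the usual per_phase/max_phase/vec parameterization of the shared encoding and XOR-based swizzling of vec-wide column groups; the helper name is illustrative, not the backend's actual implementation.

    // Minimal sketch: XOR-based swizzling of a row-major layout.
    // Assumption: phase = (row / per_phase) % max_phase, applied to
    // vec-wide column groups, as in Triton's shared encoding.
    #include <cstdio>

    int swizzledCol(int row, int col, int perPhase, int maxPhase, int vec) {
      int phase = (row / perPhase) % maxPhase; // rows sharing one phase
      int group = col / vec;                   // vec-wide column group
      return (group ^ phase) * vec + col % vec;
    }

    int main() {
      // With per_phase = 2, rows {0, 1} keep their columns (phase 0),
      // rows {2, 3} exchange neighboring groups (phase 1), and so on.
      for (int row = 0; row < 4; ++row) {
        for (int col = 0; col < 8; ++col)
          std::printf("%d ", swizzledCol(row, col, 2, 4, 2));
        std::printf("\n");
      }
      return 0;
    }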
@@ -215,9 +215,9 @@ def MmaEncodingAttr : DistributedEncoding<"MmaEncoding"> {
 An encoding for tensors that have been produced by tensor cores.
 It is characterized by two parameters:
 - A 'version' which specifies the generation the tensor cores
 whose output is being partitioned: 1 for first-gen tensor cores (Volta),
 and 2 for second-gen tensor cores (Turing/Ampere).
 - A `blockTileSize` to indicate how data should be
 partitioned between warps.
 
 // -------------------------------- version = 1 --------------------------- //
@@ -229,7 +229,7 @@ https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
 
 For example, the matrix L corresponding to blockTileSize=[32,16] is:
 
 warp 0
 --------------------------------/\-------------------------------
 [ 0  0  2  2  0  0  2  2  4  4  6  6  4  4  6  6 ]
 [ 1  1  3  3  1  1  3  3  5  5  7  7  5  5  7  7 ]
@@ -246,7 +246,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
 [ 24 24 26 26 24 24 26 26 28 28 30 30 28 28 30 30]
 [ 25 25 27 27 25 25 27 27 29 29 31 31 29 29 31 31]
 
 warp 1 = warp0 + 32
 --------------------------------/\-------------------------------
 [ 32 32 34 34 32 32 34 34 36 36 38 38 36 36 38 38]
 [ 33 33 35 35 33 33 35 35 37 37 39 39 37 37 39 39]
@@ -260,29 +260,29 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
 For second-gen tensor cores, the implicit warpTileSize is [16, 8].
 Information about this layout can be found in the official PTX documentation
 https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
 (mma.16816 section, FP32 accumulator).
 
 For example, the matrix L corresponding to blockTileSize=[32,16] is:
 warp 0                              warp 1
 -----------------/\-------------   ----------------/\-------------
 [ 0   0   1   1   2   2   3   3    32  32  33  33  34  34  35  35
 [ 4   4   5   5   6   6   7   7    36  36  37  37  38  38  39  39
 [ ..............................   ..............................
 [ 28  28  29  29  30  30  31  31   60  60  61  61  62  62  63  63
 [ 0   0   1   1   2   2   3   3    32  32  33  33  34  34  35  35
 [ 4   4   5   5   6   6   7   7    36  36  37  37  38  38  39  39
 [ ..............................   ..............................
 [ 28  28  29  29  30  30  31  31   60  60  61  61  62  62  63  63
 
 warp 3                              warp 4
 ----------------/\-------------    ----------------/\-------------
 [ 64  64  65  65  66  66  67  67   96  96  97  97  98  98  99  99
 [ 68  68  69  69  70  70  71  71   100 100 101 101 102 102 103 103
 [ ..............................   ...............................
 [ 92  92  93  93  94  94  95  95   124 124 125 125 126 126 127 127
 [ 64  64  65  65  66  66  67  67   96  96  97  97  98  98  99  99
 [ 68  68  69  69  70  70  71  71   100 100 101 101 102 102 103 103
 [ ..............................   ...............................
 [ 92  92  93  93  94  94  95  95   124 124 125 125 126 126 127 127
 
 }];
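The version = 2 mapping can be reproduced mechanically. The sketch below is derived from the printed matrix itself rather than from the backend code: within one 16x8 warp tile, each row of eight columns is covered by four threads holding two consecutive columns each, and rows 8..15 reuse the threads of rows 0..7.

    // Sketch: regenerate the version = 2 matrix L for one 16x8 warp tile.
    #include <cstdio>

    int owningThread(int i, int j, int warpId) {
      // Four threads per row, two consecutive columns per thread;
      // the lower half of the tile (rows 8..15) repeats rows 0..7.
      return 32 * warpId + 4 * (i % 8) + j / 2;
    }

    int main() {
      for (int i = 0; i < 16; ++i) {
        for (int j = 0; j < 8; ++j)
          std::printf("%3d ", owningThread(i, j, /*warpId=*/0));
        std::printf("\n");
      }
      return 0;
    }

Passing warpId = 1 reproduces the 32..63 column block shown for warp 1 above.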
@@ -316,7 +316,7 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
 This is useful for constructing the inverse layout of an expand_dims operation during some optimization passes.
 
 }];
 
 let parameters = (
 ins
 "unsigned":$dim,
@@ -29,6 +29,7 @@
 using namespace mlir;
 using namespace mlir::triton;
 using ::mlir::triton::gpu::BlockedEncodingAttr;
+using ::mlir::triton::gpu::DotOperandEncodingAttr;
 using ::mlir::triton::gpu::getElemsPerThread;
 using ::mlir::triton::gpu::getOrder;
 using ::mlir::triton::gpu::getShapePerCTA;
@@ -1565,6 +1566,10 @@ public:
         dstLayout.isa<SharedEncodingAttr>()) {
       return lowerBlockedToShared(op, adaptor, rewriter);
     }
+    if (srcLayout.isa<SharedEncodingAttr>() &&
+        dstLayout.isa<DotOperandEncodingAttr>()) {
+      return lowerSharedToDotOperand(op, adaptor, rewriter);
+    }
     if ((!srcLayout.isa<BlockedEncodingAttr>() &&
          !srcLayout.isa<MmaEncodingAttr>()) ||
         (!dstLayout.isa<BlockedEncodingAttr>() &&
@@ -1572,6 +1577,7 @@ public:
       // TODO: to be implemented
       return failure();
     }
+
     return lowerDistributedToDistributed(op, adaptor, rewriter);
   }
 
@@ -1609,6 +1615,11 @@ private:
   LogicalResult lowerBlockedToShared(triton::gpu::ConvertLayoutOp op,
                                      OpAdaptor adaptor,
                                      ConversionPatternRewriter &rewriter) const;
+
+  // shared -> mma_operand
+  LogicalResult
+  lowerSharedToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
+                          ConversionPatternRewriter &rewriter) const;
 };
 
 void ConvertLayoutOpConversion::processReplica(
@@ -1915,6 +1926,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
   rewriter.replaceOp(op, smemBase);
   return success();
 }
+
 /// ====================== dot codegen begin ==========================
 
 // Data loader for mma.16816 instruction.
@@ -2383,16 +2395,16 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
 
 private:
   // Convert to mma.m16n8k16
-  LogicalResult convertMMA16816(triton::DotOp a, OpAdaptor adapter,
+  LogicalResult convertMMA16816(triton::DotOp a, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const;
   /// Convert to mma.m8n8k4
-  LogicalResult convertMMA884(triton::DotOp op, OpAdaptor adapter,
+  LogicalResult convertMMA884(triton::DotOp op, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
     assert(false && "Not implemented yet.");
     return failure();
   }
 
-  LogicalResult convertFMADot(triton::DotOp op, OpAdaptor adapter,
+  LogicalResult convertFMADot(triton::DotOp op, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
     assert(false && "Not implemented yet.");
     return failure();
@@ -2402,28 +2414,18 @@ private:
 struct DotOpConversionHelper {
   using TensorCoreType = DotOpConversion::TensorCoreType;
 
-  Value A, B, C, D;
   MmaEncodingAttr mmaLayout;
-  RankedTensorType ATensorTy, BTensorTy, DTensorTy;
   MLIRContext *ctx{};
 
-  explicit DotOpConversionHelper(DotOp dot)
-      : dot(dot), mmaType(getMmaType(dot)) {
-    A = dot.a();
-    B = dot.b();
-    C = dot.c();
-    D = dot.d();
-    ctx = dot->getContext();
-    mmaLayout = C.getType()
-                    .cast<RankedTensorType>()
-                    .getEncoding()
-                    .cast<MmaEncodingAttr>();
+  explicit DotOpConversionHelper(MmaEncodingAttr mmaLayout)
+      : mmaLayout(mmaLayout) {
+    ctx = mmaLayout.getContext();
   }
 
   // Load SplatLike C which contains a constVal. It simply returns 4 fp32
   // constVal.
   SmallVector<Value> loadSplatLikeC(Value C, Location loc,
-                                    ConversionPatternRewriter &rewriter) {
+                                    ConversionPatternRewriter &rewriter) const {
     assert(isSplatLike(C));
 
     int numRes = getMmaInstrShape()[0] * getMmaInstrShape()[1] / 32;
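For the mma.16816 case referenced above, the instruction shape is [16, 8, 16], so numRes = 16 * 8 / 32 = 4: each of the 32 threads in a warp holds four accumulator values, which is why the comment says loadSplatLikeC simply returns 4 fp32 constVals.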
@@ -2451,6 +2453,11 @@ struct DotOpConversionHelper {
     return {};
   }
 
+  void deduceMmaType(DotOp op) const { mmaType = getMmaType(op); }
+  void deduceMmaType(Type operandTy) const {
+    mmaType = getTensorCoreTypeFromOperand(operandTy);
+  }
+
   Type getShemPtrTy() const {
     switch (mmaType) {
     case TensorCoreType::FP32_FP16_FP16_FP32:
@@ -2554,6 +2561,22 @@ struct DotOpConversionHelper {
     return mmaMatShape.at(mmaType);
   }
 
+  // Deduce the TensorCoreType from either $a or $b's type. This method is not
+  // safe, but we cannot get the DotOp in some getmaMatShape usage case.
+  TensorCoreType getTensorCoreTypeFromOperand(Type operandTy) const {
+    auto tensorTy = operandTy.cast<RankedTensorType>();
+    auto elemTy = tensorTy.getElementType();
+    if (elemTy.isF16())
+      return TensorCoreType::FP32_FP16_FP16_FP32;
+    if (elemTy.isF32())
+      return TensorCoreType::FP32_TF32_TF32_FP32;
+    if (elemTy.isBF16())
+      return TensorCoreType::FP32_BF16_BF16_FP32;
+    if (elemTy.isInteger(8))
+      return TensorCoreType::INT32_INT8_INT8_INT32;
+    return TensorCoreType::NOT_APPLICABLE;
+  }
+
   int getVec() const {
     assert(mmaType != TensorCoreType::NOT_APPLICABLE &&
            "Unknown mma type found.");
@@ -2593,7 +2616,7 @@ struct DotOpConversionHelper {
   }
 
 private:
-  TensorCoreType mmaType;
+  mutable TensorCoreType mmaType{TensorCoreType::NOT_APPLICABLE};
 
   // Used on nvidia GPUs mma layout .version == 2
   // Refer to
@@ -2655,9 +2678,6 @@ private:
       {TensorCoreType::INT32_INT4_INT4_INT32, 32},
       {TensorCoreType::INT32_INT8_INT8_INT32, 16},
   };
-
-private:
-  DotOp dot;
 };
 
 // This class helps to adapt the existing DotOpConversion to the latest
@@ -2666,21 +2686,12 @@ private:
 // 1. loading the specific operand matrix(for $a, $b, $c) from smem
 // 2. passing the loaded value and perform the mma codegen
 struct MMA16816ConversionHelper {
-  Value A, B, C, D;
-  RankedTensorType aTensorTy, bTensorTy, dTensorTy;
-  ArrayRef<int64_t> aShape, bShape, dShape;
   MmaEncodingAttr mmaLayout;
   ArrayRef<unsigned int> wpt;
 
-  int mmaInstrM{-1}, mmaInstrN{-1}, mmaInstrK{-1};
-  int matShapeM{-1}, matShapeN{-1}, matShapeK{-1};
-  int numRepM{-1}, numRepN{-1}, numRepK{-1};
   Value thread, lane, warp, warpMN, warpN, warpM;
-  size_t aElemBytes{}, bElemBytes{};
 
   DotOpConversionHelper helper;
-  triton::DotOp op;
-  DotOpAdaptor adapter;
   ConversionPatternRewriter &rewriter;
   TypeConverter *typeConverter;
   Location loc;
@@ -2688,64 +2699,75 @@ struct MMA16816ConversionHelper {
 
   using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;
 
-  MMA16816ConversionHelper(triton::DotOp op, Value thread, DotOpAdaptor adapter,
+  MMA16816ConversionHelper(MmaEncodingAttr mmaLayout, Value thread,
                            ConversionPatternRewriter &rewriter,
                            TypeConverter *typeConverter, Location loc)
-      : helper(op), op(op), adapter(adapter), rewriter(rewriter),
-        typeConverter(typeConverter), loc(loc), ctx(op.getContext()),
+      : mmaLayout(mmaLayout), helper(mmaLayout), rewriter(rewriter),
+        typeConverter(typeConverter), loc(loc), ctx(mmaLayout.getContext()),
         thread(thread) {
-    A = op.a();
-    B = op.b();
-    C = op.c();
-    D = op.getResult();
-
-    aTensorTy = A.getType().cast<RankedTensorType>();
-    bTensorTy = B.getType().cast<RankedTensorType>();
-    dTensorTy = D.getType().cast<RankedTensorType>();
-
-    aShape = aTensorTy.getShape();
-    bShape = bTensorTy.getShape();
-    dShape = dTensorTy.getShape();
-
-    mmaLayout = dTensorTy.getEncoding().cast<MmaEncodingAttr>();
-
     wpt = mmaLayout.getWarpsPerCTA();
 
-    auto mmaInstrShape = helper.getMmaInstrShape();
-    mmaInstrM = mmaInstrShape[0];
-    mmaInstrN = mmaInstrShape[1];
-    mmaInstrK = mmaInstrShape[2];
-
-    auto matShape = helper.getMmaMatShape();
-    matShapeM = matShape[0];
-    matShapeN = matShape[1];
-    matShapeK = matShape[2];
-
-    int NK = aShape[1];
-    // shape / shape_per_cta
-    numRepM = std::max<int>(dShape[0] / (wpt[0] * mmaInstrM), 1);
-    numRepN = std::max<int>(dShape[1] / (wpt[1] * mmaInstrN), 1);
-    numRepK = std::max<int>(NK / mmaInstrK, 1);
-
     Value _32 = i32_val(32);
     lane = urem(thread, _32);
     warp = udiv(thread, _32);
     warpMN = udiv(warp, i32_val(wpt[0]));
     warpM = urem(warp, i32_val(wpt[0]));
     warpN = urem(warpMN, i32_val(wpt[1]));
+  }
 
-    aElemBytes = aTensorTy.getElementTypeBitWidth() / 8;
-    bElemBytes = bTensorTy.getElementTypeBitWidth() / 8;
+  // Get the mmaInstrShape from either $a or $b.
+  std::tuple<int, int, int> getMmaInstrShape(Type operand) const {
+    helper.deduceMmaType(operand);
+    auto mmaInstrShape = helper.getMmaInstrShape();
+    int mmaInstrM = mmaInstrShape[0];
+    int mmaInstrN = mmaInstrShape[1];
+    int mmaInstrK = mmaInstrShape[2];
+    return std::make_tuple(mmaInstrM, mmaInstrN, mmaInstrK);
+  }
+
+  std::tuple<int, int, int> getMmaMatShape(Type operand) const {
+    helper.deduceMmaType(operand);
+    auto matShape = helper.getMmaMatShape();
+    int matShapeM = matShape[0];
+    int matShapeN = matShape[1];
+    int matShapeK = matShape[2];
+    return std::make_tuple(matShapeM, matShapeN, matShapeK);
+  }
+
+  // \param operand is either $a or $b's type.
+  inline int getNumRepM(Type operand, int M) const {
+    auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(operand);
+    return std::max<int>(M / (wpt[0] * mmaInstrM), 1);
+  }
+
+  // \param operand is either $a or $b's type.
+  inline int getNumRepN(Type operand, int N) const {
+    auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(operand);
+    return std::max<int>(N / (wpt[1] * mmaInstrN), 1);
+  }
+
+  // \param operand is either $a or $b's type.
+  inline int getNumRepK(Type operand, int K) const {
+    auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(operand);
+    return std::max<int>(K / mmaInstrK, 1);
   }
 
   // Loading $a from smem to registers, returns a LLVM::Struct.
-  Value loadA() {
+  Value loadA(Value tensor, Value llTensor) const {
+    auto aTensorTy = tensor.getType().cast<RankedTensorType>();
+    auto shape = aTensorTy.getShape();
+
     ValueTable ha;
     std::function<void(int, int)> loadFn;
+    auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(aTensorTy);
+    auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(aTensorTy);
+    int numRepM = getNumRepM(aTensorTy, shape[0]);
+    int numRepK = getNumRepK(aTensorTy, shape[1]);
+
     if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
       // load from smem
       loadFn = getLoadMatrixFn(
-          A, adapter.a() /*llTensor*/, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
+          tensor, llTensor, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
           1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/,
           {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
     } else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
@@ -2770,10 +2792,17 @@ struct MMA16816ConversionHelper {
   }
 
   // Loading $b from smem to registers, returns a LLVM::Struct.
-  Value loadB() {
+  Value loadB(Value tensor, Value llTensor) {
     ValueTable hb;
+    auto tensorTy = tensor.getType().cast<RankedTensorType>();
+    auto shape = tensorTy.getShape();
+    auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(tensorTy);
+    auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(tensorTy);
+    int numRepK = getNumRepK(tensorTy, shape[0]);
+    int numRepN = getNumRepN(tensorTy, shape[1]);
+
     auto loadFn = getLoadMatrixFn(
-        B, adapter.b() /*llTensor*/, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
+        tensor, llTensor, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
         0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/,
         {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
 
@@ -2789,24 +2818,47 @@ struct MMA16816ConversionHelper {
 
   // Loading $c from smem(?) to registers, returns a Value.
   // NOTE Only SplatLike tensor is supported now.
-  Value loadC() {
+  Value loadC(Value tensor) const {
     // Currently, we only support a SplatLike C. For the other cases, e.g., C in
     // shared layout or blocked layout, we will support them by expanding
     // convert_layout.
-    auto hc = helper.loadSplatLikeC(C, loc, rewriter);
+    auto hc = helper.loadSplatLikeC(tensor, loc, rewriter);
     assert(hc.size() == 4UL && "Only splat-like C is supported now");
     return hc[0];
   }
 
   // Conduct the Dot conversion.
-  // Input the \param a, \param b, \param c, all of them are result of loading.
-  LogicalResult convertDot(Value a, Value b, Value c) {
-    ValueTable ha = getValuesFromDotOperandLayoutStruct(a, numRepM, numRepK);
+  // \param a, \param b, \param c and \param d are DotOp operands.
+  // \param loadedA, \param loadedB, \param loadedC, all of them are result of
+  // loading.
+  LogicalResult convertDot(Value a, Value b, Value c, Value d, Value loadedA,
+                           Value loadedB, Value loadedC, DotOp op,
+                           DotOpAdaptor adaptor) const {
+    helper.deduceMmaType(op);
+
+    auto aTensorTy = a.getType().cast<RankedTensorType>();
+    auto bTensorTy = b.getType().cast<RankedTensorType>();
+    auto cTensorTy = c.getType().cast<RankedTensorType>();
+    auto dTensorTy = d.getType().cast<RankedTensorType>();
+
+    auto aShape = aTensorTy.getShape();
+    auto dShape = dTensorTy.getShape();
+
+    int NK = aShape[1];
+    // shape / shape_per_cta
+    auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(aTensorTy);
+    auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(aTensorTy);
+    int numRepM = getNumRepM(aTensorTy, dShape[0]);
+    int numRepN = getNumRepN(aTensorTy, dShape[1]);
+    int numRepK = getNumRepK(aTensorTy, aShape[1]);
+
+    ValueTable ha =
+        getValuesFromDotOperandLayoutStruct(loadedA, numRepM, numRepK);
     ValueTable hb = getValuesFromDotOperandLayoutStruct(
-        b, std::max(numRepN / 2, 1), numRepK);
+        loadedB, std::max(numRepN / 2, 1), numRepK);
 
     const int fcSize = 4 * numRepM * numRepN;
-    SmallVector<Value> fc(fcSize, c);
+    SmallVector<Value> fc(fcSize, loadedC);
 
     auto callMma = [&](unsigned m, unsigned n, unsigned k) {
       unsigned colsPerThread = numRepN * 2;
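To make the repetition counts concrete with the shapes from the lit test at the end of this diff (a 128x32 by 32x256 fp16 dot, warpsPerCTA = [2, 2], mma.16816 instruction shape [16, 8, 16]): numRepM = max(128 / (2 * 16), 1) = 4, numRepN = max(256 / (2 * 8), 1) = 16, and numRepK = max(32 / 16, 1) = 2.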
@@ -2855,10 +2907,10 @@ struct MMA16816ConversionHelper {
 
 private:
   std::function<void(int, int)>
-  getLoadMatrixFn(Value tensor, Value llTensor, int wpt, int kOrder,
-                  ArrayRef<int> instrShape, ArrayRef<int> matShape,
-                  Value warpId, ValueTable &vals) {
+  getLoadMatrixFn(Value tensor, Value llTensor, MmaEncodingAttr mmaLayout,
+                  int wpt, int kOrder, ArrayRef<int> instrShape,
+                  ArrayRef<int> matShape, Value warpId,
+                  ValueTable &vals) const {
     auto tensorTy = tensor.getType().cast<RankedTensorType>();
     // We assumes that the input operand of Dot should be from shared layout.
     // TODO(Superjomn) Consider other layouts if needed later.
@@ -2928,7 +2980,7 @@ private:
   // i \in [0, n0) and j \in [0, n1)
   // There should be \param n0 * \param n1 elements in the output Struct.
   Value composeValuesToDotOperandLayoutStruct(const ValueTable &vals, int n0,
-                                              int n1) {
+                                              int n1) const {
     std::vector<Value> elems;
     for (unsigned m = 0; m < n0; ++m)
       for (unsigned k = 0; k < n1; ++k) {
@@ -2940,7 +2992,7 @@ private:
 
     assert(!elems.empty());
 
-    Type fp16Ty = aTensorTy.getElementType();
+    Type fp16Ty = type::f16Ty(ctx);
     Type fp16x2Ty = vec_ty(fp16Ty, 2);
     Type structTy = LLVM::LLVMStructType::getLiteral(
         ctx, SmallVector<Type>(elems.size(), fp16x2Ty));
@@ -2948,7 +3000,8 @@ private:
     return result;
   }
 
-  ValueTable getValuesFromDotOperandLayoutStruct(Value value, int n0, int n1) {
+  ValueTable getValuesFromDotOperandLayoutStruct(Value value, int n0,
+                                                 int n1) const {
    auto elems = ConvertTritonGPUOpToLLVMPatternBase::getElementsFromStruct(
        loc, value, rewriter);
 
@@ -2966,18 +3019,79 @@ private:
   }
 };
 
+LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
+    triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto loc = op.getLoc();
+  Value src = op.src();
+  Value dst = op.result();
+  auto srcTensorTy = src.getType().cast<RankedTensorType>();
+  auto dstTensorTy = dst.getType().cast<RankedTensorType>();
+
+  auto sharedLayout = srcTensorTy.getEncoding().cast<SharedEncodingAttr>();
+  auto dotOperandLayout =
+      dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
+  MmaEncodingAttr mmaLayout =
+      dotOperandLayout.getParent().dyn_cast_or_null<MmaEncodingAttr>();
+  assert(mmaLayout);
+
+  MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
+                                     rewriter, getTypeConverter(), op.getLoc());
+
+  Value res;
+  if (dotOperandLayout.getOpIdx() == 0) {
+    // operand $a
+    res = mmaHelper.loadA(src, adaptor.src());
+  } else if (dotOperandLayout.getOpIdx() == 1) {
+    // operand $b
+    res = mmaHelper.loadB(src, adaptor.src());
+  } else if (dotOperandLayout.getOpIdx() == 2) {
+    // operand $c
+    res = mmaHelper.loadC(src);
+  }
+
+  rewriter.replaceOp(op, res);
+  return success();
+}
+
 LogicalResult
-DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
+DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
   auto loc = op.getLoc();
-  MMA16816ConversionHelper mmaHelper(op, getThreadId(rewriter, loc), adapter,
+  auto mmaLayout = op.getResult()
+                       .getType()
+                       .cast<RankedTensorType>()
+                       .getEncoding()
+                       .cast<MmaEncodingAttr>();
+  MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
                                      rewriter, getTypeConverter(), loc);
 
-  auto A = mmaHelper.loadA();
-  auto B = mmaHelper.loadB();
-  auto C = mmaHelper.loadC();
+  Value A = op.a();
+  Value B = op.b();
+  Value C = op.c();
+  auto ATensorTy = A.getType().cast<RankedTensorType>();
+  auto BTensorTy = B.getType().cast<RankedTensorType>();
 
-  return mmaHelper.convertDot(A, B, C);
+  Value loadedA, loadedB, loadedC;
+  // We support two kinds of operand layouts: 1. both $a, $b are dot_operand
+  // layout, 2. both of them are shared layout.
+  if (ATensorTy.getEncoding().isa<DotOperandEncodingAttr>()) {
+    assert(BTensorTy.getEncoding().isa<DotOperandEncodingAttr>() &&
+           "Both $a and %b should be DotOperand layout.");
+    loadedA = adaptor.a();
+    loadedB = adaptor.b();
+  } else {
+    loadedA = mmaHelper.loadA(op.a(), adaptor.a());
+    loadedB = mmaHelper.loadB(op.b(), adaptor.b());
+  }
+
+  // TODO[Superjomn]: Process C as a mma layout.
+  // Currently, C is simply treated as a Splat Op, and the data layout is not
+  // mattered.
+  loadedC = mmaHelper.loadC(op.c());
+
+  return mmaHelper.convertDot(A, B, C, op.d(), loadedA, loadedB, loadedC, op,
+                              adaptor);
 }
 
 /// ====================== mma codegen end ============================
@@ -610,4 +610,4 @@ LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
                                                          NamedAttribute attr) {
   // TODO: fill this.
   return success();
 }
@@ -714,3 +714,27 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
     return
   }
 }
+
+// -----
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
+#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
+    %a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
+    // CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16
+    %a_mat = triton_gpu.convert_layout %a : (tensor<128x32xf16, #shared>) -> tensor<128x32xf16, #dot_operand_a>
+    %b_mat = triton_gpu.convert_layout %b : (tensor<32x256xf16, #shared>) -> tensor<32x256xf16, #dot_operand_b>
+
+    %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #mma>
+    %38 = triton_gpu.convert_layout %28 : (tensor<128x256xf32, #mma>) -> tensor<128x256xf32, #blocked>
+
+    %30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>, #blocked>
+    %36 = tt.broadcast %30 : (tensor<128x1x!tt.ptr<f32>, #blocked>) -> tensor<128x256x!tt.ptr<f32>, #blocked>
+    tt.store %36, %38 : tensor<128x256xf32, #blocked>
+    return
+  }
+}