[Triton-MLIR][BACKEND] add convert_layout[shared->dot_op] conversion to adapt DotOperand layout (#786)

This PR helps to

1. Adapt the existing DotOp conversion to the design of the new
DotOperand layout,
2. Make the DotOp conversion work with both the shared-layout input case
and the DotOperand-layout input case, in preparation for the upstream switch.
Author: Yan Chunwei
Date: 2022-10-24 11:40:13 +08:00
Committed by: GitHub
Parent: 3aa8296b06
Commit: 877844de4f
4 changed files with 241 additions and 103 deletions


@@ -52,7 +52,7 @@ different cuda threads in the programs, via shared memory. In other words,
for all indices i \in R^d, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}.
In order to avoid shared memory bank conflicts, elements may be swizzled
in memory. For example, a swizzled row-major layout could store its data
as follows:
A_{0, 0} A_{0, 1} A_{0, 2} A_{0, 3} ... [phase 0] \ per_phase = 2
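A minimal C++ sketch (not part of this diff) of how such a swizzled row-major offset can be derived from the vec / perPhase / maxPhase parameters of the shared encoding; the helper name and the exact XOR formula are illustrative assumptions and may not match the backend byte-for-byte:

// Hypothetical helper: maps a logical (row, col) element of a row-major tile
// to a swizzled linear offset, given the #triton_gpu.shared parameters.
int swizzledOffset(int row, int col, int numCols,
                   int vec, int perPhase, int maxPhase) {
  int phase = (row / perPhase) % maxPhase;            // perPhase rows share one phase
  int group = col / vec;                              // vec-wide group of elements
  int inGroup = col % vec;                            // position inside the group
  int swizzledCol = (group ^ phase) * vec + inGroup;  // XOR-swizzle the groups
  return row * numCols + swizzledCol;                 // row-major offset after swizzling
}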
@@ -215,9 +215,9 @@ def MmaEncodingAttr : DistributedEncoding<"MmaEncoding"> {
An encoding for tensors that have been produced by tensor cores.
It is characterized by two parameters:
- A 'version' which specifies the generation of the tensor cores
whose output is being partitioned: 1 for first-gen tensor cores (Volta),
and 2 for second-gen tensor cores (Turing/Ampere).
- A `blockTileSize` to indicate how data should be
partitioned between warps.
// -------------------------------- version = 1 --------------------------- //
@@ -229,7 +229,7 @@ https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
For example, the matrix L corresponding to blockTileSize=[32,16] is:
warp 0
--------------------------------/\-------------------------------
[ 0 0 2 2 0 0 2 2 4 4 6 6 4 4 6 6 ]
[ 1 1 3 3 1 1 3 3 5 5 7 7 5 5 7 7 ]
@@ -246,7 +246,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
[ 24 24 26 26 24 24 26 26 28 28 30 30 28 28 30 30]
[ 25 25 27 27 25 25 27 27 29 29 31 31 29 29 31 31]
warp 1 = warp0 + 32
--------------------------------/\-------------------------------
[ 32 32 34 34 32 32 34 34 36 36 38 38 36 36 38 38]
[ 33 33 35 35 33 33 35 35 37 37 39 39 37 37 39 39]
@@ -260,29 +260,29 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
For second-gen tensor cores, the implicit warpTileSize is [16, 8].
Information about this layout can be found in the official PTX documentation
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
(mma.16816 section, FP32 accumulator).
For example, the matrix L corresponding to blockTileSize=[32,16] is:
warp 0 warp 1
-----------------/\------------- ----------------/\-------------
[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35
[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39
[ .............................. ..............................
[ .............................. ..............................
[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63
warp 2 warp 3
----------------/\------------- ----------------/\-------------
[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99
[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103
[ .............................. ...............................
[ .............................. ...............................
[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127
}];
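To make the version = 2 matrices above easier to read back, here is an illustrative C++ sketch (not from the diff) of the lane -> (row, col) ownership for the m16n8k16 accumulator fragment, following the PTX layout the text references; the names are assumptions:

struct AccCoord { int row; int col; };

// Each lane of a warp owns four accumulator values c0..c3 of the 16x8 warp tile.
AccCoord accumulatorCoord(int lane, int i /* which of c0..c3 */) {
  int groupId = lane / 4;              // 8 groups of 4 lanes -> 8 rows
  int tidInGroup = lane % 4;           // 4 lanes share a row
  int row = groupId + (i / 2) * 8;     // c0,c1 -> rows 0..7; c2,c3 -> rows 8..15
  int col = tidInGroup * 2 + (i % 2);  // two adjacent columns per lane
  return {row, col};
}
// E.g. lane 0 owns (0,0) (0,1) (8,0) (8,1) and lane 1 owns (0,2) (0,3) (8,2) (8,3),
// matching the "0 0 1 1 2 2 3 3 ..." rows in the warp 0 block above.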
@@ -316,7 +316,7 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
This is useful for constructing the inverse layout of an expand_dims operation during some optimization passes.
}];
let parameters = (
ins
"unsigned":$dim,


@@ -29,6 +29,7 @@
using namespace mlir;
using namespace mlir::triton;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
@@ -1565,6 +1566,10 @@ public:
dstLayout.isa<SharedEncodingAttr>()) {
return lowerBlockedToShared(op, adaptor, rewriter);
}
if (srcLayout.isa<SharedEncodingAttr>() &&
dstLayout.isa<DotOperandEncodingAttr>()) {
return lowerSharedToDotOperand(op, adaptor, rewriter);
}
if ((!srcLayout.isa<BlockedEncodingAttr>() &&
!srcLayout.isa<MmaEncodingAttr>()) ||
(!dstLayout.isa<BlockedEncodingAttr>() &&
@@ -1572,6 +1577,7 @@ public:
// TODO: to be implemented
return failure();
}
return lowerDistributedToDistributed(op, adaptor, rewriter);
}
@@ -1609,6 +1615,11 @@ private:
LogicalResult lowerBlockedToShared(triton::gpu::ConvertLayoutOp op,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const;
// shared -> mma_operand
LogicalResult
lowerSharedToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const;
};
void ConvertLayoutOpConversion::processReplica(
@@ -1915,6 +1926,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
rewriter.replaceOp(op, smemBase);
return success();
}
/// ====================== dot codegen begin ==========================
// Data loader for mma.16816 instruction.
@@ -2383,16 +2395,16 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
private:
// Convert to mma.m16n8k16
LogicalResult convertMMA16816(triton::DotOp a, OpAdaptor adapter,
LogicalResult convertMMA16816(triton::DotOp a, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const;
/// Convert to mma.m8n8k4
LogicalResult convertMMA884(triton::DotOp op, OpAdaptor adapter,
LogicalResult convertMMA884(triton::DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
assert(false && "Not implemented yet.");
return failure();
}
LogicalResult convertFMADot(triton::DotOp op, OpAdaptor adapter,
LogicalResult convertFMADot(triton::DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
assert(false && "Not implemented yet.");
return failure();
@@ -2402,28 +2414,18 @@ private:
struct DotOpConversionHelper {
using TensorCoreType = DotOpConversion::TensorCoreType;
Value A, B, C, D;
MmaEncodingAttr mmaLayout;
RankedTensorType ATensorTy, BTensorTy, DTensorTy;
MLIRContext *ctx{};
explicit DotOpConversionHelper(DotOp dot)
: dot(dot), mmaType(getMmaType(dot)) {
A = dot.a();
B = dot.b();
C = dot.c();
D = dot.d();
ctx = dot->getContext();
mmaLayout = C.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<MmaEncodingAttr>();
explicit DotOpConversionHelper(MmaEncodingAttr mmaLayout)
: mmaLayout(mmaLayout) {
ctx = mmaLayout.getContext();
}
// Load SplatLike C which contains a constVal. It simply returns 4 fp32
// constVal.
SmallVector<Value> loadSplatLikeC(Value C, Location loc,
ConversionPatternRewriter &rewriter) {
ConversionPatternRewriter &rewriter) const {
assert(isSplatLike(C));
int numRes = getMmaInstrShape()[0] * getMmaInstrShape()[1] / 32;
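As a quick arithmetic check (assuming the default mma.16816 instruction shape {16, 8, 16}), numRes works out to 16 * 8 / 32 = 4, which is why the comment above promises four fp32 constVals per thread:

// Hypothetical compile-time check; 32 is the warp size.
static_assert(16 * 8 / 32 == 4,
              "each lane of a warp holds 4 fp32 accumulators for mma.16816");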
@@ -2451,6 +2453,11 @@ struct DotOpConversionHelper {
return {};
}
void deduceMmaType(DotOp op) const { mmaType = getMmaType(op); }
void deduceMmaType(Type operandTy) const {
mmaType = getTensorCoreTypeFromOperand(operandTy);
}
Type getShemPtrTy() const {
switch (mmaType) {
case TensorCoreType::FP32_FP16_FP16_FP32:
@@ -2554,6 +2561,22 @@ struct DotOpConversionHelper {
return mmaMatShape.at(mmaType);
}
// Deduce the TensorCoreType from either $a or $b's type. This method is not
// safe, but we cannot get the DotOp in some getMmaMatShape usage cases.
TensorCoreType getTensorCoreTypeFromOperand(Type operandTy) const {
auto tensorTy = operandTy.cast<RankedTensorType>();
auto elemTy = tensorTy.getElementType();
if (elemTy.isF16())
return TensorCoreType::FP32_FP16_FP16_FP32;
if (elemTy.isF32())
return TensorCoreType::FP32_TF32_TF32_FP32;
if (elemTy.isBF16())
return TensorCoreType::FP32_BF16_BF16_FP32;
if (elemTy.isInteger(8))
return TensorCoreType::INT32_INT8_INT8_INT32;
return TensorCoreType::NOT_APPLICABLE;
}
int getVec() const {
assert(mmaType != TensorCoreType::NOT_APPLICABLE &&
"Unknown mma type found.");
@@ -2593,7 +2616,7 @@ struct DotOpConversionHelper {
}
private:
TensorCoreType mmaType;
mutable TensorCoreType mmaType{TensorCoreType::NOT_APPLICABLE};
// Used on nvidia GPUs mma layout .version == 2
// Refer to
@@ -2655,9 +2678,6 @@ private:
{TensorCoreType::INT32_INT4_INT4_INT32, 32},
{TensorCoreType::INT32_INT8_INT8_INT32, 16},
};
private:
DotOp dot;
};
// This class helps to adapt the existing DotOpConversion to the latest
@@ -2666,21 +2686,12 @@ private:
// 1. loading the specific operand matrix (for $a, $b, $c) from smem
// 2. passing the loaded values and performing the mma codegen
struct MMA16816ConversionHelper {
Value A, B, C, D;
RankedTensorType aTensorTy, bTensorTy, dTensorTy;
ArrayRef<int64_t> aShape, bShape, dShape;
MmaEncodingAttr mmaLayout;
ArrayRef<unsigned int> wpt;
int mmaInstrM{-1}, mmaInstrN{-1}, mmaInstrK{-1};
int matShapeM{-1}, matShapeN{-1}, matShapeK{-1};
int numRepM{-1}, numRepN{-1}, numRepK{-1};
Value thread, lane, warp, warpMN, warpN, warpM;
size_t aElemBytes{}, bElemBytes{};
DotOpConversionHelper helper;
triton::DotOp op;
DotOpAdaptor adapter;
ConversionPatternRewriter &rewriter;
TypeConverter *typeConverter;
Location loc;
@@ -2688,64 +2699,75 @@ struct MMA16816ConversionHelper {
using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;
MMA16816ConversionHelper(triton::DotOp op, Value thread, DotOpAdaptor adapter,
MMA16816ConversionHelper(MmaEncodingAttr mmaLayout, Value thread,
ConversionPatternRewriter &rewriter,
TypeConverter *typeConverter, Location loc)
: helper(op), op(op), adapter(adapter), rewriter(rewriter),
typeConverter(typeConverter), loc(loc), ctx(op.getContext()),
: mmaLayout(mmaLayout), helper(mmaLayout), rewriter(rewriter),
typeConverter(typeConverter), loc(loc), ctx(mmaLayout.getContext()),
thread(thread) {
A = op.a();
B = op.b();
C = op.c();
D = op.getResult();
aTensorTy = A.getType().cast<RankedTensorType>();
bTensorTy = B.getType().cast<RankedTensorType>();
dTensorTy = D.getType().cast<RankedTensorType>();
aShape = aTensorTy.getShape();
bShape = bTensorTy.getShape();
dShape = dTensorTy.getShape();
mmaLayout = dTensorTy.getEncoding().cast<MmaEncodingAttr>();
wpt = mmaLayout.getWarpsPerCTA();
auto mmaInstrShape = helper.getMmaInstrShape();
mmaInstrM = mmaInstrShape[0];
mmaInstrN = mmaInstrShape[1];
mmaInstrK = mmaInstrShape[2];
auto matShape = helper.getMmaMatShape();
matShapeM = matShape[0];
matShapeN = matShape[1];
matShapeK = matShape[2];
int NK = aShape[1];
// shape / shape_per_cta
numRepM = std::max<int>(dShape[0] / (wpt[0] * mmaInstrM), 1);
numRepN = std::max<int>(dShape[1] / (wpt[1] * mmaInstrN), 1);
numRepK = std::max<int>(NK / mmaInstrK, 1);
Value _32 = i32_val(32);
lane = urem(thread, _32);
warp = udiv(thread, _32);
warpMN = udiv(warp, i32_val(wpt[0]));
warpM = urem(warp, i32_val(wpt[0]));
warpN = urem(warpMN, i32_val(wpt[1]));
}
aElemBytes = aTensorTy.getElementTypeBitWidth() / 8;
bElemBytes = bTensorTy.getElementTypeBitWidth() / 8;
// Get the mmaInstrShape from either $a or $b.
std::tuple<int, int, int> getMmaInstrShape(Type operand) const {
helper.deduceMmaType(operand);
auto mmaInstrShape = helper.getMmaInstrShape();
int mmaInstrM = mmaInstrShape[0];
int mmaInstrN = mmaInstrShape[1];
int mmaInstrK = mmaInstrShape[2];
return std::make_tuple(mmaInstrM, mmaInstrN, mmaInstrK);
}
std::tuple<int, int, int> getMmaMatShape(Type operand) const {
helper.deduceMmaType(operand);
auto matShape = helper.getMmaMatShape();
int matShapeM = matShape[0];
int matShapeN = matShape[1];
int matShapeK = matShape[2];
return std::make_tuple(matShapeM, matShapeN, matShapeK);
}
// \param operand is either $a or $b's type.
inline int getNumRepM(Type operand, int M) const {
auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(operand);
return std::max<int>(M / (wpt[0] * mmaInstrM), 1);
}
// \param operand is either $a or $b's type.
inline int getNumRepN(Type operand, int N) const {
auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(operand);
return std::max<int>(N / (wpt[1] * mmaInstrN), 1);
}
// \param operand is either $a or $b's type.
inline int getNumRepK(Type operand, int K) const {
auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(operand);
return std::max<int>(K / mmaInstrK, 1);
}
// Loading $a from smem to registers, returns a LLVM::Struct.
Value loadA() {
Value loadA(Value tensor, Value llTensor) const {
auto aTensorTy = tensor.getType().cast<RankedTensorType>();
auto shape = aTensorTy.getShape();
ValueTable ha;
std::function<void(int, int)> loadFn;
auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(aTensorTy);
auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(aTensorTy);
int numRepM = getNumRepM(aTensorTy, shape[0]);
int numRepK = getNumRepK(aTensorTy, shape[1]);
if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
// load from smem
loadFn = getLoadMatrixFn(
A, adapter.a() /*llTensor*/, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
tensor, llTensor, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShape*/,
{matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
} else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
@@ -2770,10 +2792,17 @@ struct MMA16816ConversionHelper {
}
// Loading $b from smem to registers, returns a LLVM::Struct.
Value loadB() {
Value loadB(Value tensor, Value llTensor) {
ValueTable hb;
auto tensorTy = tensor.getType().cast<RankedTensorType>();
auto shape = tensorTy.getShape();
auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(tensorTy);
auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(tensorTy);
int numRepK = getNumRepK(tensorTy, shape[0]);
int numRepN = getNumRepN(tensorTy, shape[1]);
auto loadFn = getLoadMatrixFn(
B, adapter.b() /*llTensor*/, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
tensor, llTensor, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/,
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
@@ -2789,24 +2818,47 @@ struct MMA16816ConversionHelper {
// Loading $c from smem(?) to registers, returns a Value.
// NOTE Only SplatLike tensor is supported now.
Value loadC() {
Value loadC(Value tensor) const {
// Currently, we only support a SplatLike C. For the other cases, e.g., C in
// shared layout or blocked layout, we will support them by expanding
// convert_layout.
auto hc = helper.loadSplatLikeC(C, loc, rewriter);
auto hc = helper.loadSplatLikeC(tensor, loc, rewriter);
assert(hc.size() == 4UL && "Only splat-like C is supported now");
return hc[0];
}
// Conduct the Dot conversion.
// Input the \param a, \param b, \param c, all of them are result of loading.
LogicalResult convertDot(Value a, Value b, Value c) {
ValueTable ha = getValuesFromDotOperandLayoutStruct(a, numRepM, numRepK);
// \param a, \param b, \param c and \param d are DotOp operands.
// \param loadedA, \param loadedB, \param loadedC, all of them are result of
// loading.
LogicalResult convertDot(Value a, Value b, Value c, Value d, Value loadedA,
Value loadedB, Value loadedC, DotOp op,
DotOpAdaptor adaptor) const {
helper.deduceMmaType(op);
auto aTensorTy = a.getType().cast<RankedTensorType>();
auto bTensorTy = b.getType().cast<RankedTensorType>();
auto cTensorTy = c.getType().cast<RankedTensorType>();
auto dTensorTy = d.getType().cast<RankedTensorType>();
auto aShape = aTensorTy.getShape();
auto dShape = dTensorTy.getShape();
int NK = aShape[1];
// shape / shape_per_cta
auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(aTensorTy);
auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(aTensorTy);
int numRepM = getNumRepM(aTensorTy, dShape[0]);
int numRepN = getNumRepN(aTensorTy, dShape[1]);
int numRepK = getNumRepK(aTensorTy, aShape[1]);
ValueTable ha =
getValuesFromDotOperandLayoutStruct(loadedA, numRepM, numRepK);
ValueTable hb = getValuesFromDotOperandLayoutStruct(
b, std::max(numRepN / 2, 1), numRepK);
loadedB, std::max(numRepN / 2, 1), numRepK);
const int fcSize = 4 * numRepM * numRepN;
SmallVector<Value> fc(fcSize, c);
SmallVector<Value> fc(fcSize, loadedC);
auto callMma = [&](unsigned m, unsigned n, unsigned k) {
unsigned colsPerThread = numRepN * 2;
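For a concrete feel of these repetition counts, here is a worked C++ sketch (not from the diff) using the shapes from the lit test added at the end of this PR, assuming warpsPerCTA = [2, 2] and the m16n8k16 instruction shape:

// a: 128x32, b: 32x256, d: 128x256, wpt = {2, 2}, mmaInstr = {16, 8, 16}
constexpr int M = 128, N = 256, K = 32;
constexpr int wptM = 2, wptN = 2;
constexpr int instrM = 16, instrN = 8, instrK = 16;

constexpr int numRepM = M / (wptM * instrM); // 128 / 32 = 4
constexpr int numRepN = N / (wptN * instrN); // 256 / 16 = 16
constexpr int numRepK = K / instrK;          //  32 / 16 = 2

// convertDot keeps 4 accumulators per (m, n) repetition per thread:
constexpr int fcSize = 4 * numRepM * numRepN; // 4 * 4 * 16 = 256
static_assert(fcSize == M * N / (32 * wptM * wptN),
              "one lane's share of the 128x256 accumulator tile");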
@@ -2855,10 +2907,10 @@ struct MMA16816ConversionHelper {
private:
std::function<void(int, int)>
getLoadMatrixFn(Value tensor, Value llTensor, int wpt, int kOrder,
ArrayRef<int> instrShape, ArrayRef<int> matShape,
Value warpId, ValueTable &vals) {
getLoadMatrixFn(Value tensor, Value llTensor, MmaEncodingAttr mmaLayout,
int wpt, int kOrder, ArrayRef<int> instrShape,
ArrayRef<int> matShape, Value warpId,
ValueTable &vals) const {
auto tensorTy = tensor.getType().cast<RankedTensorType>();
// We assume that the input operand of Dot comes from shared layout.
// TODO(Superjomn) Consider other layouts if needed later.
@@ -2928,7 +2980,7 @@ private:
// i \in [0, n0) and j \in [0, n1)
// There should be \param n0 * \param n1 elements in the output Struct.
Value composeValuesToDotOperandLayoutStruct(const ValueTable &vals, int n0,
int n1) {
int n1) const {
std::vector<Value> elems;
for (unsigned m = 0; m < n0; ++m)
for (unsigned k = 0; k < n1; ++k) {
@@ -2940,7 +2992,7 @@ private:
assert(!elems.empty());
Type fp16Ty = aTensorTy.getElementType();
Type fp16Ty = type::f16Ty(ctx);
Type fp16x2Ty = vec_ty(fp16Ty, 2);
Type structTy = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(elems.size(), fp16x2Ty));
@@ -2948,7 +3000,8 @@ private:
return result;
}
ValueTable getValuesFromDotOperandLayoutStruct(Value value, int n0, int n1) {
ValueTable getValuesFromDotOperandLayoutStruct(Value value, int n0,
int n1) const {
auto elems = ConvertTritonGPUOpToLLVMPatternBase::getElementsFromStruct(
loc, value, rewriter);
@@ -2966,18 +3019,79 @@ private:
}
};
LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
Value src = op.src();
Value dst = op.result();
auto srcTensorTy = src.getType().cast<RankedTensorType>();
auto dstTensorTy = dst.getType().cast<RankedTensorType>();
auto sharedLayout = srcTensorTy.getEncoding().cast<SharedEncodingAttr>();
auto dotOperandLayout =
dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
MmaEncodingAttr mmaLayout =
dotOperandLayout.getParent().dyn_cast_or_null<MmaEncodingAttr>();
assert(mmaLayout);
MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
rewriter, getTypeConverter(), op.getLoc());
Value res;
if (dotOperandLayout.getOpIdx() == 0) {
// operand $a
res = mmaHelper.loadA(src, adaptor.src());
} else if (dotOperandLayout.getOpIdx() == 1) {
// operand $b
res = mmaHelper.loadB(src, adaptor.src());
} else if (dotOperandLayout.getOpIdx() == 2) {
// operand $c
res = mmaHelper.loadC(src);
}
rewriter.replaceOp(op, res);
return success();
}
LogicalResult
DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
MMA16816ConversionHelper mmaHelper(op, getThreadId(rewriter, loc), adapter,
auto mmaLayout = op.getResult()
.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<MmaEncodingAttr>();
MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
rewriter, getTypeConverter(), loc);
auto A = mmaHelper.loadA();
auto B = mmaHelper.loadB();
auto C = mmaHelper.loadC();
Value A = op.a();
Value B = op.b();
Value C = op.c();
auto ATensorTy = A.getType().cast<RankedTensorType>();
auto BTensorTy = B.getType().cast<RankedTensorType>();
return mmaHelper.convertDot(A, B, C);
Value loadedA, loadedB, loadedC;
// We support two kinds of operand layouts: 1. both $a, $b are dot_operand
// layout, 2. both of them are shared layout.
if (ATensorTy.getEncoding().isa<DotOperandEncodingAttr>()) {
assert(BTensorTy.getEncoding().isa<DotOperandEncodingAttr>() &&
"Both $a and %b should be DotOperand layout.");
loadedA = adaptor.a();
loadedB = adaptor.b();
} else {
loadedA = mmaHelper.loadA(op.a(), adaptor.a());
loadedB = mmaHelper.loadB(op.b(), adaptor.b());
}
// TODO[Superjomn]: Process C as a mma layout.
// Currently, C is simply treated as a Splat Op, and its data layout does not
// matter.
loadedC = mmaHelper.loadC(op.c());
return mmaHelper.convertDot(A, B, C, op.d(), loadedA, loadedB, loadedC, op,
adaptor);
}
/// ====================== mma codegen end ============================


@@ -610,4 +610,4 @@ LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) {
// TODO: fill this.
return success();
}


@@ -714,3 +714,27 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
return
}
}
// -----
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
%a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
// CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16
%a_mat = triton_gpu.convert_layout %a : (tensor<128x32xf16, #shared>) -> tensor<128x32xf16, #dot_operand_a>
%b_mat = triton_gpu.convert_layout %b : (tensor<32x256xf16, #shared>) -> tensor<32x256xf16, #dot_operand_b>
%28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #mma>
%38 = triton_gpu.convert_layout %28 : (tensor<128x256xf32, #mma>) -> tensor<128x256xf32, #blocked>
%30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>, #blocked>
%36 = tt.broadcast %30 : (tensor<128x1x!tt.ptr<f32>, #blocked>) -> tensor<128x256x!tt.ptr<f32>, #blocked>
tt.store %36, %38 : tensor<128x256xf32, #blocked>
return
}
}