[Triton-MLIR][BACKEND] add convert_layout[shared->dot_op] conversion to adapt DotOperand layout (#786)
This PR 1. adapts the existing DotOp conversion to the design of the new DotOperand layout, and 2. makes the DotOp conversion work with both shared-layout and dot_operand-layout inputs, in preparation for the upstream switch.
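For orientation, goal 2 boils down to a dispatch on the operand encoding inside the DotOp lowering. A simplified sketch, condensed from the convertMMA16816 hunk near the end of this diff (not the verbatim code):

    // Case 1: convert_layout(shared -> dot_op) already ran, so the adaptor
    // carries $a/$b as dot_operand-layout structs that can be reused directly.
    // Case 2: $a/$b are still in shared layout, so they are loaded from smem here.
    if (ATensorTy.getEncoding().isa<DotOperandEncodingAttr>()) {
      loadedA = adaptor.a();
      loadedB = adaptor.b();
    } else {
      loadedA = mmaHelper.loadA(op.a(), adaptor.a());
      loadedB = mmaHelper.loadB(op.b(), adaptor.b());
    }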
@@ -29,6 +29,7 @@
 using namespace mlir;
 using namespace mlir::triton;
 using ::mlir::triton::gpu::BlockedEncodingAttr;
+using ::mlir::triton::gpu::DotOperandEncodingAttr;
 using ::mlir::triton::gpu::getElemsPerThread;
 using ::mlir::triton::gpu::getOrder;
 using ::mlir::triton::gpu::getShapePerCTA;
@@ -1565,6 +1566,10 @@ public:
         dstLayout.isa<SharedEncodingAttr>()) {
       return lowerBlockedToShared(op, adaptor, rewriter);
     }
+    if (srcLayout.isa<SharedEncodingAttr>() &&
+        dstLayout.isa<DotOperandEncodingAttr>()) {
+      return lowerSharedToDotOperand(op, adaptor, rewriter);
+    }
     if ((!srcLayout.isa<BlockedEncodingAttr>() &&
          !srcLayout.isa<MmaEncodingAttr>()) ||
         (!dstLayout.isa<BlockedEncodingAttr>() &&
@@ -1572,6 +1577,7 @@ public:
       // TODO: to be implemented
       return failure();
     }
 
     return lowerDistributedToDistributed(op, adaptor, rewriter);
   }
 
@@ -1609,6 +1615,11 @@ private:
   LogicalResult lowerBlockedToShared(triton::gpu::ConvertLayoutOp op,
                                      OpAdaptor adaptor,
                                      ConversionPatternRewriter &rewriter) const;
+
+  // shared -> mma_operand
+  LogicalResult
+  lowerSharedToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
+                          ConversionPatternRewriter &rewriter) const;
 };
 
 void ConvertLayoutOpConversion::processReplica(
@@ -1915,6 +1926,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
   rewriter.replaceOp(op, smemBase);
   return success();
 }
 
 /// ====================== dot codegen begin ==========================
 
 // Data loader for mma.16816 instruction.
@@ -2383,16 +2395,16 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
 
 private:
   // Convert to mma.m16n8k16
-  LogicalResult convertMMA16816(triton::DotOp a, OpAdaptor adapter,
+  LogicalResult convertMMA16816(triton::DotOp a, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const;
   /// Convert to mma.m8n8k4
-  LogicalResult convertMMA884(triton::DotOp op, OpAdaptor adapter,
+  LogicalResult convertMMA884(triton::DotOp op, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
     assert(false && "Not implemented yet.");
     return failure();
   }
 
-  LogicalResult convertFMADot(triton::DotOp op, OpAdaptor adapter,
+  LogicalResult convertFMADot(triton::DotOp op, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
     assert(false && "Not implemented yet.");
     return failure();
@@ -2402,28 +2414,18 @@ private:
 struct DotOpConversionHelper {
   using TensorCoreType = DotOpConversion::TensorCoreType;
 
-  Value A, B, C, D;
   MmaEncodingAttr mmaLayout;
-  RankedTensorType ATensorTy, BTensorTy, DTensorTy;
   MLIRContext *ctx{};
 
-  explicit DotOpConversionHelper(DotOp dot)
-      : dot(dot), mmaType(getMmaType(dot)) {
-    A = dot.a();
-    B = dot.b();
-    C = dot.c();
-    D = dot.d();
-    ctx = dot->getContext();
-    mmaLayout = C.getType()
-                    .cast<RankedTensorType>()
-                    .getEncoding()
-                    .cast<MmaEncodingAttr>();
+  explicit DotOpConversionHelper(MmaEncodingAttr mmaLayout)
+      : mmaLayout(mmaLayout) {
+    ctx = mmaLayout.getContext();
   }
 
   // Load SplatLike C which contains a constVal. It simply returns 4 fp32
   // constVal.
   SmallVector<Value> loadSplatLikeC(Value C, Location loc,
-                                    ConversionPatternRewriter &rewriter) {
+                                    ConversionPatternRewriter &rewriter) const {
     assert(isSplatLike(C));
 
     int numRes = getMmaInstrShape()[0] * getMmaInstrShape()[1] / 32;
@@ -2451,6 +2453,11 @@ struct DotOpConversionHelper {
     return {};
   }
 
+  void deduceMmaType(DotOp op) const { mmaType = getMmaType(op); }
+  void deduceMmaType(Type operandTy) const {
+    mmaType = getTensorCoreTypeFromOperand(operandTy);
+  }
+
   Type getShemPtrTy() const {
     switch (mmaType) {
     case TensorCoreType::FP32_FP16_FP16_FP32:
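With the constructor decoupled from DotOp (previous hunk) and these deduceMmaType overloads, the helper can be driven from a layout plus an operand type alone. A hedged usage sketch, reusing names from the surrounding hunks:

    DotOpConversionHelper helper(mmaLayout);     // no DotOp needed anymore
    helper.deduceMmaType(aTensorTy);             // TensorCoreType from $a's element type
    auto instrShape = helper.getMmaInstrShape(); // e.g. {16, 8, 16} for fp16 inputs
    auto matShape = helper.getMmaMatShape();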
@@ -2554,6 +2561,22 @@ struct DotOpConversionHelper {
     return mmaMatShape.at(mmaType);
   }
 
+  // Deduce the TensorCoreType from either $a or $b's type. This method is not
+  // safe, but we cannot get the DotOp in some getMmaMatShape usage cases.
+  TensorCoreType getTensorCoreTypeFromOperand(Type operandTy) const {
+    auto tensorTy = operandTy.cast<RankedTensorType>();
+    auto elemTy = tensorTy.getElementType();
+    if (elemTy.isF16())
+      return TensorCoreType::FP32_FP16_FP16_FP32;
+    if (elemTy.isF32())
+      return TensorCoreType::FP32_TF32_TF32_FP32;
+    if (elemTy.isBF16())
+      return TensorCoreType::FP32_BF16_BF16_FP32;
+    if (elemTy.isInteger(8))
+      return TensorCoreType::INT32_INT8_INT8_INT32;
+    return TensorCoreType::NOT_APPLICABLE;
+  }
+
   int getVec() const {
     assert(mmaType != TensorCoreType::NOT_APPLICABLE &&
            "Unknown mma type found.");
@@ -2593,7 +2616,7 @@ struct DotOpConversionHelper {
   }
 
 private:
-  TensorCoreType mmaType;
+  mutable TensorCoreType mmaType{TensorCoreType::NOT_APPLICABLE};
 
   // Used on nvidia GPUs mma layout .version == 2
   // Refer to
@@ -2655,9 +2678,6 @@ private:
     {TensorCoreType::INT32_INT4_INT4_INT32, 32},
     {TensorCoreType::INT32_INT8_INT8_INT32, 16},
   };
-
-private:
-  DotOp dot;
 };
 
 // This class helps to adapt the existing DotOpConversion to the latest
@@ -2666,21 +2686,12 @@ private:
 // 1. loading the specific operand matrix(for $a, $b, $c) from smem
 // 2. passing the loaded value and perform the mma codegen
 struct MMA16816ConversionHelper {
-  Value A, B, C, D;
-  RankedTensorType aTensorTy, bTensorTy, dTensorTy;
-  ArrayRef<int64_t> aShape, bShape, dShape;
   MmaEncodingAttr mmaLayout;
   ArrayRef<unsigned int> wpt;
 
-  int mmaInstrM{-1}, mmaInstrN{-1}, mmaInstrK{-1};
-  int matShapeM{-1}, matShapeN{-1}, matShapeK{-1};
-  int numRepM{-1}, numRepN{-1}, numRepK{-1};
   Value thread, lane, warp, warpMN, warpN, warpM;
-  size_t aElemBytes{}, bElemBytes{};
 
   DotOpConversionHelper helper;
-  triton::DotOp op;
-  DotOpAdaptor adapter;
   ConversionPatternRewriter &rewriter;
   TypeConverter *typeConverter;
   Location loc;
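The slimmed-down member list reflects the new call protocol: tensors are passed in explicitly instead of being cached at construction time. A condensed sketch of how the two call sites later in this diff drive the helper (threadId and llSrc are shorthand for getThreadId(rewriter, loc) and adaptor.src(), not actual variable names; the $c path via opIdx == 2 is omitted):

    MMA16816ConversionHelper mmaHelper(mmaLayout, threadId, rewriter,
                                       typeConverter, loc);
    // convert_layout(shared -> dot_op): load only the operand named by opIdx.
    Value res = dotOperandLayout.getOpIdx() == 0 ? mmaHelper.loadA(src, llSrc)
                                                 : mmaHelper.loadB(src, llSrc);
    // DotOp lowering: pass the loaded structs on to the mma.m16n8k16 codegen.
    mmaHelper.convertDot(A, B, C, op.d(), loadedA, loadedB, loadedC, op, adaptor);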
@@ -2688,64 +2699,75 @@ struct MMA16816ConversionHelper {
 
   using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;
 
-  MMA16816ConversionHelper(triton::DotOp op, Value thread, DotOpAdaptor adapter,
+  MMA16816ConversionHelper(MmaEncodingAttr mmaLayout, Value thread,
                            ConversionPatternRewriter &rewriter,
                            TypeConverter *typeConverter, Location loc)
-      : helper(op), op(op), adapter(adapter), rewriter(rewriter),
-        typeConverter(typeConverter), loc(loc), ctx(op.getContext()),
+      : mmaLayout(mmaLayout), helper(mmaLayout), rewriter(rewriter),
+        typeConverter(typeConverter), loc(loc), ctx(mmaLayout.getContext()),
         thread(thread) {
-    A = op.a();
-    B = op.b();
-    C = op.c();
-    D = op.getResult();
-
-    aTensorTy = A.getType().cast<RankedTensorType>();
-    bTensorTy = B.getType().cast<RankedTensorType>();
-    dTensorTy = D.getType().cast<RankedTensorType>();
-
-    aShape = aTensorTy.getShape();
-    bShape = bTensorTy.getShape();
-    dShape = dTensorTy.getShape();
-
-    mmaLayout = dTensorTy.getEncoding().cast<MmaEncodingAttr>();
-
     wpt = mmaLayout.getWarpsPerCTA();
 
-    auto mmaInstrShape = helper.getMmaInstrShape();
-    mmaInstrM = mmaInstrShape[0];
-    mmaInstrN = mmaInstrShape[1];
-    mmaInstrK = mmaInstrShape[2];
-
-    auto matShape = helper.getMmaMatShape();
-    matShapeM = matShape[0];
-    matShapeN = matShape[1];
-    matShapeK = matShape[2];
-
-    int NK = aShape[1];
-    // shape / shape_per_cta
-    numRepM = std::max<int>(dShape[0] / (wpt[0] * mmaInstrM), 1);
-    numRepN = std::max<int>(dShape[1] / (wpt[1] * mmaInstrN), 1);
-    numRepK = std::max<int>(NK / mmaInstrK, 1);
-
     Value _32 = i32_val(32);
     lane = urem(thread, _32);
     warp = udiv(thread, _32);
     warpMN = udiv(warp, i32_val(wpt[0]));
     warpM = urem(warp, i32_val(wpt[0]));
     warpN = urem(warpMN, i32_val(wpt[1]));
-
-    aElemBytes = aTensorTy.getElementTypeBitWidth() / 8;
-    bElemBytes = bTensorTy.getElementTypeBitWidth() / 8;
   }
 
+  // Get the mmaInstrShape from either $a or $b.
+  std::tuple<int, int, int> getMmaInstrShape(Type operand) const {
+    helper.deduceMmaType(operand);
+    auto mmaInstrShape = helper.getMmaInstrShape();
+    int mmaInstrM = mmaInstrShape[0];
+    int mmaInstrN = mmaInstrShape[1];
+    int mmaInstrK = mmaInstrShape[2];
+    return std::make_tuple(mmaInstrM, mmaInstrN, mmaInstrK);
+  }
+
+  std::tuple<int, int, int> getMmaMatShape(Type operand) const {
+    helper.deduceMmaType(operand);
+    auto matShape = helper.getMmaMatShape();
+    int matShapeM = matShape[0];
+    int matShapeN = matShape[1];
+    int matShapeK = matShape[2];
+    return std::make_tuple(matShapeM, matShapeN, matShapeK);
+  }
+
+  // \param operand is either $a or $b's type.
+  inline int getNumRepM(Type operand, int M) const {
+    auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(operand);
+    return std::max<int>(M / (wpt[0] * mmaInstrM), 1);
+  }
+
+  // \param operand is either $a or $b's type.
+  inline int getNumRepN(Type operand, int N) const {
+    auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(operand);
+    return std::max<int>(N / (wpt[1] * mmaInstrN), 1);
+  }
+
+  // \param operand is either $a or $b's type.
+  inline int getNumRepK(Type operand, int K) const {
+    auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(operand);
+    return std::max<int>(K / mmaInstrK, 1);
+  }
+
   // Loading $a from smem to registers, returns a LLVM::Struct.
-  Value loadA() {
+  Value loadA(Value tensor, Value llTensor) const {
+    auto aTensorTy = tensor.getType().cast<RankedTensorType>();
+    auto shape = aTensorTy.getShape();
+
     ValueTable ha;
     std::function<void(int, int)> loadFn;
+    auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(aTensorTy);
+    auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(aTensorTy);
+    int numRepM = getNumRepM(aTensorTy, shape[0]);
+    int numRepK = getNumRepK(aTensorTy, shape[1]);
 
     if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
       // load from smem
       loadFn = getLoadMatrixFn(
-          A, adapter.a() /*llTensor*/, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
+          tensor, llTensor, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
           1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShape*/,
           {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
     } else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
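To make the repetition math concrete, here is what these getters compute for the LIT test added at the end of this diff (fp16 operands, so the m16n8k16 instruction shape {16, 8, 16} is assumed):

    // A: 128x32, B: 32x256, D: 128x256, warpsPerCTA (wpt) = {2, 2}
    int numRepM = std::max<int>(128 / (2 * 16), 1); // 4 repetitions along M
    int numRepN = std::max<int>(256 / (2 * 8), 1);  // 16 repetitions along N
    int numRepK = std::max<int>(32 / 16, 1);        // 2 repetitions along K
    // convertDot later sizes the accumulator as 4 * numRepM * numRepN = 256 values.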
@@ -2770,10 +2792,17 @@ struct MMA16816ConversionHelper {
   }
 
   // Loading $b from smem to registers, returns a LLVM::Struct.
-  Value loadB() {
+  Value loadB(Value tensor, Value llTensor) {
     ValueTable hb;
+    auto tensorTy = tensor.getType().cast<RankedTensorType>();
+    auto shape = tensorTy.getShape();
+    auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(tensorTy);
+    auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(tensorTy);
+    int numRepK = getNumRepK(tensorTy, shape[0]);
+    int numRepN = getNumRepN(tensorTy, shape[1]);
+
     auto loadFn = getLoadMatrixFn(
-        B, adapter.b() /*llTensor*/, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
+        tensor, llTensor, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
         0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/,
         {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
 
@@ -2789,24 +2818,47 @@ struct MMA16816ConversionHelper {
 
   // Loading $c from smem(?) to registers, returns a Value.
   // NOTE Only SplatLike tensor is supported now.
-  Value loadC() {
+  Value loadC(Value tensor) const {
     // Currently, we only support a SplatLike C. For the other cases, e.g., C in
     // shared layout or blocked layout, we will support them by expanding
     // convert_layout.
-    auto hc = helper.loadSplatLikeC(C, loc, rewriter);
+    auto hc = helper.loadSplatLikeC(tensor, loc, rewriter);
     assert(hc.size() == 4UL && "Only splat-like C is supported now");
     return hc[0];
   }
 
   // Conduct the Dot conversion.
-  // Input the \param a, \param b, \param c, all of them are result of loading.
-  LogicalResult convertDot(Value a, Value b, Value c) {
-    ValueTable ha = getValuesFromDotOperandLayoutStruct(a, numRepM, numRepK);
+  // \param a, \param b, \param c and \param d are DotOp operands.
+  // \param loadedA, \param loadedB, \param loadedC, all of them are result of
+  // loading.
+  LogicalResult convertDot(Value a, Value b, Value c, Value d, Value loadedA,
+                           Value loadedB, Value loadedC, DotOp op,
+                           DotOpAdaptor adaptor) const {
+    helper.deduceMmaType(op);
+
+    auto aTensorTy = a.getType().cast<RankedTensorType>();
+    auto bTensorTy = b.getType().cast<RankedTensorType>();
+    auto cTensorTy = c.getType().cast<RankedTensorType>();
+    auto dTensorTy = d.getType().cast<RankedTensorType>();
+
+    auto aShape = aTensorTy.getShape();
+    auto dShape = dTensorTy.getShape();
+
+    int NK = aShape[1];
+    // shape / shape_per_cta
+    auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(aTensorTy);
+    auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(aTensorTy);
+    int numRepM = getNumRepM(aTensorTy, dShape[0]);
+    int numRepN = getNumRepN(aTensorTy, dShape[1]);
+    int numRepK = getNumRepK(aTensorTy, aShape[1]);
+
+    ValueTable ha =
+        getValuesFromDotOperandLayoutStruct(loadedA, numRepM, numRepK);
     ValueTable hb = getValuesFromDotOperandLayoutStruct(
-        b, std::max(numRepN / 2, 1), numRepK);
+        loadedB, std::max(numRepN / 2, 1), numRepK);
 
     const int fcSize = 4 * numRepM * numRepN;
-    SmallVector<Value> fc(fcSize, c);
+    SmallVector<Value> fc(fcSize, loadedC);
 
     auto callMma = [&](unsigned m, unsigned n, unsigned k) {
       unsigned colsPerThread = numRepN * 2;
@@ -2855,10 +2907,10 @@ struct MMA16816ConversionHelper {
 
 private:
   std::function<void(int, int)>
-  getLoadMatrixFn(Value tensor, Value llTensor, int wpt, int kOrder,
-                  ArrayRef<int> instrShape, ArrayRef<int> matShape,
-                  Value warpId, ValueTable &vals) {
-
+  getLoadMatrixFn(Value tensor, Value llTensor, MmaEncodingAttr mmaLayout,
+                  int wpt, int kOrder, ArrayRef<int> instrShape,
+                  ArrayRef<int> matShape, Value warpId,
+                  ValueTable &vals) const {
     auto tensorTy = tensor.getType().cast<RankedTensorType>();
     // We assume that the input operand of Dot should be from shared layout.
     // TODO(Superjomn) Consider other layouts if needed later.
@@ -2928,7 +2980,7 @@ private:
   // i \in [0, n0) and j \in [0, n1)
   // There should be \param n0 * \param n1 elements in the output Struct.
   Value composeValuesToDotOperandLayoutStruct(const ValueTable &vals, int n0,
-                                              int n1) {
+                                              int n1) const {
     std::vector<Value> elems;
     for (unsigned m = 0; m < n0; ++m)
       for (unsigned k = 0; k < n1; ++k) {
@@ -2940,7 +2992,7 @@ private:
 
     assert(!elems.empty());
 
-    Type fp16Ty = aTensorTy.getElementType();
+    Type fp16Ty = type::f16Ty(ctx);
     Type fp16x2Ty = vec_ty(fp16Ty, 2);
     Type structTy = LLVM::LLVMStructType::getLiteral(
         ctx, SmallVector<Type>(elems.size(), fp16x2Ty));
@@ -2948,7 +3000,8 @@ private:
     return result;
   }
 
-  ValueTable getValuesFromDotOperandLayoutStruct(Value value, int n0, int n1) {
+  ValueTable getValuesFromDotOperandLayoutStruct(Value value, int n0,
+                                                 int n1) const {
     auto elems = ConvertTritonGPUOpToLLVMPatternBase::getElementsFromStruct(
         loc, value, rewriter);
 
@@ -2966,18 +3019,79 @@ private:
   }
 };
 
-LogicalResult
-DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
+LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
+    triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   auto loc = op.getLoc();
-  MMA16816ConversionHelper mmaHelper(op, getThreadId(rewriter, loc), adapter,
+  Value src = op.src();
+  Value dst = op.result();
+  auto srcTensorTy = src.getType().cast<RankedTensorType>();
+  auto dstTensorTy = dst.getType().cast<RankedTensorType>();
+
+  auto sharedLayout = srcTensorTy.getEncoding().cast<SharedEncodingAttr>();
+  auto dotOperandLayout =
+      dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
+  MmaEncodingAttr mmaLayout =
+      dotOperandLayout.getParent().dyn_cast_or_null<MmaEncodingAttr>();
+  assert(mmaLayout);
+
+  MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
                                      rewriter, getTypeConverter(), op.getLoc());
+
+  Value res;
+  if (dotOperandLayout.getOpIdx() == 0) {
+    // operand $a
+    res = mmaHelper.loadA(src, adaptor.src());
+  } else if (dotOperandLayout.getOpIdx() == 1) {
+    // operand $b
+    res = mmaHelper.loadB(src, adaptor.src());
+  } else if (dotOperandLayout.getOpIdx() == 2) {
+    // operand $c
+    res = mmaHelper.loadC(src);
+  }
+
+  rewriter.replaceOp(op, res);
+  return success();
+}
+
+LogicalResult
+DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter) const {
+  auto loc = op.getLoc();
+  auto mmaLayout = op.getResult()
+                       .getType()
+                       .cast<RankedTensorType>()
+                       .getEncoding()
+                       .cast<MmaEncodingAttr>();
+  MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
+                                     rewriter, getTypeConverter(), loc);
 
-  auto A = mmaHelper.loadA();
-  auto B = mmaHelper.loadB();
-  auto C = mmaHelper.loadC();
+  Value A = op.a();
+  Value B = op.b();
+  Value C = op.c();
+  auto ATensorTy = A.getType().cast<RankedTensorType>();
+  auto BTensorTy = B.getType().cast<RankedTensorType>();
 
-  return mmaHelper.convertDot(A, B, C);
+  Value loadedA, loadedB, loadedC;
+  // We support two kinds of operand layouts: 1. both $a, $b are dot_operand
+  // layout, 2. both of them are shared layout.
+  if (ATensorTy.getEncoding().isa<DotOperandEncodingAttr>()) {
+    assert(BTensorTy.getEncoding().isa<DotOperandEncodingAttr>() &&
+           "Both $a and $b should be DotOperand layout.");
+    loadedA = adaptor.a();
+    loadedB = adaptor.b();
+  } else {
+    loadedA = mmaHelper.loadA(op.a(), adaptor.a());
+    loadedB = mmaHelper.loadB(op.b(), adaptor.b());
+  }
+
+  // TODO[Superjomn]: Process C as a mma layout.
+  // Currently, C is simply treated as a Splat Op, and the data layout does not
+  // matter.
+  loadedC = mmaHelper.loadC(op.c());
+
+  return mmaHelper.convertDot(A, B, C, op.d(), loadedA, loadedB, loadedC, op,
+                              adaptor);
 }
 
 /// ====================== mma codegen end ============================
@@ -714,3 +714,27 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
     return
   }
 }
+
+// -----
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
+#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
+    %a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
+    // CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16
+    %a_mat = triton_gpu.convert_layout %a : (tensor<128x32xf16, #shared>) -> tensor<128x32xf16, #dot_operand_a>
+    %b_mat = triton_gpu.convert_layout %b : (tensor<32x256xf16, #shared>) -> tensor<32x256xf16, #dot_operand_b>
+
+    %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #mma>
+    %38 = triton_gpu.convert_layout %28 : (tensor<128x256xf32, #mma>) -> tensor<128x256xf32, #blocked>
+
+    %30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>, #blocked>
+    %36 = tt.broadcast %30 : (tensor<128x1x!tt.ptr<f32>, #blocked>) -> tensor<128x256x!tt.ptr<f32>, #blocked>
+    tt.store %36, %38 : tensor<128x256xf32, #blocked>
+    return
+  }
+}