|
|
|
@@ -29,6 +29,7 @@
|
|
|
|
|
using namespace mlir;
|
|
|
|
|
using namespace mlir::triton;
|
|
|
|
|
using ::mlir::triton::gpu::BlockedEncodingAttr;
|
|
|
|
|
using ::mlir::triton::gpu::DotOperandEncodingAttr;
|
|
|
|
|
using ::mlir::triton::gpu::getElemsPerThread;
|
|
|
|
|
using ::mlir::triton::gpu::getOrder;
|
|
|
|
|
using ::mlir::triton::gpu::getShapePerCTA;
|
|
|
|
@@ -1565,6 +1566,10 @@ public:
|
|
|
|
|
dstLayout.isa<SharedEncodingAttr>()) {
|
|
|
|
|
return lowerBlockedToShared(op, adaptor, rewriter);
|
|
|
|
|
}
|
|
|
|
|
if (srcLayout.isa<SharedEncodingAttr>() &&
|
|
|
|
|
dstLayout.isa<DotOperandEncodingAttr>()) {
|
|
|
|
|
return lowerSharedToDotOperand(op, adaptor, rewriter);
|
|
|
|
|
}
|
|
|
|
|
if ((!srcLayout.isa<BlockedEncodingAttr>() &&
|
|
|
|
|
!srcLayout.isa<MmaEncodingAttr>()) ||
|
|
|
|
|
(!dstLayout.isa<BlockedEncodingAttr>() &&
|
|
|
|
@@ -1572,6 +1577,7 @@ public:
|
|
|
|
|
// TODO: to be implemented
|
|
|
|
|
return failure();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return lowerDistributedToDistributed(op, adaptor, rewriter);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@@ -1609,6 +1615,11 @@ private:
|
|
|
|
|
LogicalResult lowerBlockedToShared(triton::gpu::ConvertLayoutOp op,
|
|
|
|
|
OpAdaptor adaptor,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const;
|
|
|
|
|
|
|
|
|
|
// shared -> mma_operand
|
|
|
|
|
LogicalResult
|
|
|
|
|
lowerSharedToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
void ConvertLayoutOpConversion::processReplica(
|
|
|
|
@@ -1915,6 +1926,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
|
|
|
|
|
rewriter.replaceOp(op, smemBase);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// ====================== dot codegen begin ==========================
|
|
|
|
|
|
|
|
|
|
// Data loader for mma.16816 instruction.
|
|
|
|
@@ -2383,16 +2395,16 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
// Convert to mma.m16n8k16
|
|
|
|
|
LogicalResult convertMMA16816(triton::DotOp a, OpAdaptor adapter,
|
|
|
|
|
LogicalResult convertMMA16816(triton::DotOp a, OpAdaptor adaptor,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const;
|
|
|
|
|
/// Convert to mma.m8n8k4
|
|
|
|
|
LogicalResult convertMMA884(triton::DotOp op, OpAdaptor adapter,
|
|
|
|
|
LogicalResult convertMMA884(triton::DotOp op, OpAdaptor adaptor,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
|
assert(false && "Not implemented yet.");
|
|
|
|
|
return failure();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
LogicalResult convertFMADot(triton::DotOp op, OpAdaptor adapter,
|
|
|
|
|
LogicalResult convertFMADot(triton::DotOp op, OpAdaptor adaptor,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
|
assert(false && "Not implemented yet.");
|
|
|
|
|
return failure();
|
|
|
|
@@ -2402,28 +2414,18 @@ private:
|
|
|
|
|
struct DotOpConversionHelper {
|
|
|
|
|
using TensorCoreType = DotOpConversion::TensorCoreType;
|
|
|
|
|
|
|
|
|
|
Value A, B, C, D;
|
|
|
|
|
MmaEncodingAttr mmaLayout;
|
|
|
|
|
RankedTensorType ATensorTy, BTensorTy, DTensorTy;
|
|
|
|
|
MLIRContext *ctx{};
|
|
|
|
|
|
|
|
|
|
explicit DotOpConversionHelper(DotOp dot)
|
|
|
|
|
: dot(dot), mmaType(getMmaType(dot)) {
|
|
|
|
|
A = dot.a();
|
|
|
|
|
B = dot.b();
|
|
|
|
|
C = dot.c();
|
|
|
|
|
D = dot.d();
|
|
|
|
|
ctx = dot->getContext();
|
|
|
|
|
mmaLayout = C.getType()
|
|
|
|
|
.cast<RankedTensorType>()
|
|
|
|
|
.getEncoding()
|
|
|
|
|
.cast<MmaEncodingAttr>();
|
|
|
|
|
explicit DotOpConversionHelper(MmaEncodingAttr mmaLayout)
|
|
|
|
|
: mmaLayout(mmaLayout) {
|
|
|
|
|
ctx = mmaLayout.getContext();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Load SplatLike C which contains a constVal. It simply returns 4 fp32
|
|
|
|
|
// constVal.
|
|
|
|
|
SmallVector<Value> loadSplatLikeC(Value C, Location loc,
|
|
|
|
|
ConversionPatternRewriter &rewriter) {
|
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
|
assert(isSplatLike(C));
|
|
|
|
|
|
|
|
|
|
int numRes = getMmaInstrShape()[0] * getMmaInstrShape()[1] / 32;
|
|
|
|
@@ -2451,6 +2453,11 @@ struct DotOpConversionHelper {
|
|
|
|
|
return {};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Caches the tensor-core type for the given dot op by delegating to
// getMmaType(op). Declared const: it only writes the cached `mmaType` member.
void deduceMmaType(DotOp op) const {
  mmaType = getMmaType(op);
}
|
|
|
|
|
// Caches the tensor-core type deduced from a single operand type, using
// getTensorCoreTypeFromOperand.
void deduceMmaType(Type operandTy) const {
  const TensorCoreType deduced = getTensorCoreTypeFromOperand(operandTy);
  mmaType = deduced;
}
|
|
|
|
|
|
|
|
|
|
Type getShemPtrTy() const {
|
|
|
|
|
switch (mmaType) {
|
|
|
|
|
case TensorCoreType::FP32_FP16_FP16_FP32:
|
|
|
|
@@ -2554,6 +2561,22 @@ struct DotOpConversionHelper {
|
|
|
|
|
return mmaMatShape.at(mmaType);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Deduce the TensorCoreType from either $a or $b's type. This method is not
|
|
|
|
|
// safe, but we cannot get the DotOp in some getmaMatShape usage case.
|
|
|
|
|
TensorCoreType getTensorCoreTypeFromOperand(Type operandTy) const {
|
|
|
|
|
auto tensorTy = operandTy.cast<RankedTensorType>();
|
|
|
|
|
auto elemTy = tensorTy.getElementType();
|
|
|
|
|
if (elemTy.isF16())
|
|
|
|
|
return TensorCoreType::FP32_FP16_FP16_FP32;
|
|
|
|
|
if (elemTy.isF32())
|
|
|
|
|
return TensorCoreType::FP32_TF32_TF32_FP32;
|
|
|
|
|
if (elemTy.isBF16())
|
|
|
|
|
return TensorCoreType::FP32_BF16_BF16_FP32;
|
|
|
|
|
if (elemTy.isInteger(8))
|
|
|
|
|
return TensorCoreType::INT32_INT8_INT8_INT32;
|
|
|
|
|
return TensorCoreType::NOT_APPLICABLE;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int getVec() const {
|
|
|
|
|
assert(mmaType != TensorCoreType::NOT_APPLICABLE &&
|
|
|
|
|
"Unknown mma type found.");
|
|
|
|
@@ -2593,7 +2616,7 @@ struct DotOpConversionHelper {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
TensorCoreType mmaType;
|
|
|
|
|
mutable TensorCoreType mmaType{TensorCoreType::NOT_APPLICABLE};
|
|
|
|
|
|
|
|
|
|
// Used on nvidia GPUs mma layout .version == 2
|
|
|
|
|
// Refer to
|
|
|
|
@@ -2655,9 +2678,6 @@ private:
|
|
|
|
|
{TensorCoreType::INT32_INT4_INT4_INT32, 32},
|
|
|
|
|
{TensorCoreType::INT32_INT8_INT8_INT32, 16},
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
DotOp dot;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// This class helps to adapt the existing DotOpConversion to the latest
|
|
|
|
@@ -2666,21 +2686,12 @@ private:
|
|
|
|
|
// 1. loading the specific operand matrix(for $a, $b, $c) from smem
|
|
|
|
|
// 2. passing the loaded value and perform the mma codegen
|
|
|
|
|
struct MMA16816ConversionHelper {
|
|
|
|
|
Value A, B, C, D;
|
|
|
|
|
RankedTensorType aTensorTy, bTensorTy, dTensorTy;
|
|
|
|
|
ArrayRef<int64_t> aShape, bShape, dShape;
|
|
|
|
|
MmaEncodingAttr mmaLayout;
|
|
|
|
|
ArrayRef<unsigned int> wpt;
|
|
|
|
|
|
|
|
|
|
int mmaInstrM{-1}, mmaInstrN{-1}, mmaInstrK{-1};
|
|
|
|
|
int matShapeM{-1}, matShapeN{-1}, matShapeK{-1};
|
|
|
|
|
int numRepM{-1}, numRepN{-1}, numRepK{-1};
|
|
|
|
|
Value thread, lane, warp, warpMN, warpN, warpM;
|
|
|
|
|
size_t aElemBytes{}, bElemBytes{};
|
|
|
|
|
|
|
|
|
|
DotOpConversionHelper helper;
|
|
|
|
|
triton::DotOp op;
|
|
|
|
|
DotOpAdaptor adapter;
|
|
|
|
|
ConversionPatternRewriter &rewriter;
|
|
|
|
|
TypeConverter *typeConverter;
|
|
|
|
|
Location loc;
|
|
|
|
@@ -2688,64 +2699,75 @@ struct MMA16816ConversionHelper {
|
|
|
|
|
|
|
|
|
|
using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;
|
|
|
|
|
|
|
|
|
|
MMA16816ConversionHelper(triton::DotOp op, Value thread, DotOpAdaptor adapter,
|
|
|
|
|
MMA16816ConversionHelper(MmaEncodingAttr mmaLayout, Value thread,
|
|
|
|
|
ConversionPatternRewriter &rewriter,
|
|
|
|
|
TypeConverter *typeConverter, Location loc)
|
|
|
|
|
: helper(op), op(op), adapter(adapter), rewriter(rewriter),
|
|
|
|
|
typeConverter(typeConverter), loc(loc), ctx(op.getContext()),
|
|
|
|
|
: mmaLayout(mmaLayout), helper(mmaLayout), rewriter(rewriter),
|
|
|
|
|
typeConverter(typeConverter), loc(loc), ctx(mmaLayout.getContext()),
|
|
|
|
|
thread(thread) {
|
|
|
|
|
A = op.a();
|
|
|
|
|
B = op.b();
|
|
|
|
|
C = op.c();
|
|
|
|
|
D = op.getResult();
|
|
|
|
|
|
|
|
|
|
aTensorTy = A.getType().cast<RankedTensorType>();
|
|
|
|
|
bTensorTy = B.getType().cast<RankedTensorType>();
|
|
|
|
|
dTensorTy = D.getType().cast<RankedTensorType>();
|
|
|
|
|
|
|
|
|
|
aShape = aTensorTy.getShape();
|
|
|
|
|
bShape = bTensorTy.getShape();
|
|
|
|
|
dShape = dTensorTy.getShape();
|
|
|
|
|
|
|
|
|
|
mmaLayout = dTensorTy.getEncoding().cast<MmaEncodingAttr>();
|
|
|
|
|
|
|
|
|
|
wpt = mmaLayout.getWarpsPerCTA();
|
|
|
|
|
|
|
|
|
|
auto mmaInstrShape = helper.getMmaInstrShape();
|
|
|
|
|
mmaInstrM = mmaInstrShape[0];
|
|
|
|
|
mmaInstrN = mmaInstrShape[1];
|
|
|
|
|
mmaInstrK = mmaInstrShape[2];
|
|
|
|
|
|
|
|
|
|
auto matShape = helper.getMmaMatShape();
|
|
|
|
|
matShapeM = matShape[0];
|
|
|
|
|
matShapeN = matShape[1];
|
|
|
|
|
matShapeK = matShape[2];
|
|
|
|
|
|
|
|
|
|
int NK = aShape[1];
|
|
|
|
|
// shape / shape_per_cta
|
|
|
|
|
numRepM = std::max<int>(dShape[0] / (wpt[0] * mmaInstrM), 1);
|
|
|
|
|
numRepN = std::max<int>(dShape[1] / (wpt[1] * mmaInstrN), 1);
|
|
|
|
|
numRepK = std::max<int>(NK / mmaInstrK, 1);
|
|
|
|
|
|
|
|
|
|
Value _32 = i32_val(32);
|
|
|
|
|
lane = urem(thread, _32);
|
|
|
|
|
warp = udiv(thread, _32);
|
|
|
|
|
warpMN = udiv(warp, i32_val(wpt[0]));
|
|
|
|
|
warpM = urem(warp, i32_val(wpt[0]));
|
|
|
|
|
warpN = urem(warpMN, i32_val(wpt[1]));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
aElemBytes = aTensorTy.getElementTypeBitWidth() / 8;
|
|
|
|
|
bElemBytes = bTensorTy.getElementTypeBitWidth() / 8;
|
|
|
|
|
// Get the mmaInstrShape from either $a or $b.
|
|
|
|
|
std::tuple<int, int, int> getMmaInstrShape(Type operand) const {
|
|
|
|
|
helper.deduceMmaType(operand);
|
|
|
|
|
auto mmaInstrShape = helper.getMmaInstrShape();
|
|
|
|
|
int mmaInstrM = mmaInstrShape[0];
|
|
|
|
|
int mmaInstrN = mmaInstrShape[1];
|
|
|
|
|
int mmaInstrK = mmaInstrShape[2];
|
|
|
|
|
return std::make_tuple(mmaInstrM, mmaInstrN, mmaInstrK);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Returns the mma matrix shape as (M, N, K), deduced from the type of either
// $a or $b. Re-deduces the helper's cached mma type from `operand` before
// querying it.
std::tuple<int, int, int> getMmaMatShape(Type operand) const {
  helper.deduceMmaType(operand);
  auto shape = helper.getMmaMatShape();
  const int shapeM = shape[0];
  const int shapeN = shape[1];
  const int shapeK = shape[2];
  return std::tuple<int, int, int>(shapeM, shapeN, shapeK);
}
|
|
|
|
|
|
|
|
|
|
// \param operand is either $a or $b's type.
|
|
|
|
|
inline int getNumRepM(Type operand, int M) const {
|
|
|
|
|
auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(operand);
|
|
|
|
|
return std::max<int>(M / (wpt[0] * mmaInstrM), 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// \param operand is either $a or $b's type.
|
|
|
|
|
inline int getNumRepN(Type operand, int N) const {
|
|
|
|
|
auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(operand);
|
|
|
|
|
return std::max<int>(N / (wpt[1] * mmaInstrN), 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// \param operand is either $a or $b's type.
|
|
|
|
|
inline int getNumRepK(Type operand, int K) const {
|
|
|
|
|
auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(operand);
|
|
|
|
|
return std::max<int>(K / mmaInstrK, 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Loading $a from smem to registers, returns a LLVM::Struct.
|
|
|
|
|
Value loadA() {
|
|
|
|
|
Value loadA(Value tensor, Value llTensor) const {
|
|
|
|
|
auto aTensorTy = tensor.getType().cast<RankedTensorType>();
|
|
|
|
|
auto shape = aTensorTy.getShape();
|
|
|
|
|
|
|
|
|
|
ValueTable ha;
|
|
|
|
|
std::function<void(int, int)> loadFn;
|
|
|
|
|
auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(aTensorTy);
|
|
|
|
|
auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(aTensorTy);
|
|
|
|
|
int numRepM = getNumRepM(aTensorTy, shape[0]);
|
|
|
|
|
int numRepK = getNumRepK(aTensorTy, shape[1]);
|
|
|
|
|
|
|
|
|
|
if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
|
|
|
|
|
// load from smem
|
|
|
|
|
loadFn = getLoadMatrixFn(
|
|
|
|
|
A, adapter.a() /*llTensor*/, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
|
|
|
|
|
tensor, llTensor, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
|
|
|
|
|
1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/,
|
|
|
|
|
{matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
|
|
|
|
|
} else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
|
|
|
|
@@ -2770,10 +2792,17 @@ struct MMA16816ConversionHelper {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Loading $b from smem to registers, returns a LLVM::Struct.
|
|
|
|
|
Value loadB() {
|
|
|
|
|
Value loadB(Value tensor, Value llTensor) {
|
|
|
|
|
ValueTable hb;
|
|
|
|
|
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
|
|
|
|
auto shape = tensorTy.getShape();
|
|
|
|
|
auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(tensorTy);
|
|
|
|
|
auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(tensorTy);
|
|
|
|
|
int numRepK = getNumRepK(tensorTy, shape[0]);
|
|
|
|
|
int numRepN = getNumRepN(tensorTy, shape[1]);
|
|
|
|
|
|
|
|
|
|
auto loadFn = getLoadMatrixFn(
|
|
|
|
|
B, adapter.b() /*llTensor*/, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
|
|
|
|
|
tensor, llTensor, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
|
|
|
|
|
0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/,
|
|
|
|
|
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
|
|
|
|
|
|
|
|
|
@@ -2789,24 +2818,47 @@ struct MMA16816ConversionHelper {
|
|
|
|
|
|
|
|
|
|
// Loading $c from smem(?) to registers, returns a Value.
|
|
|
|
|
// NOTE Only SplatLike tensor is supported now.
|
|
|
|
|
Value loadC() {
|
|
|
|
|
Value loadC(Value tensor) const {
|
|
|
|
|
// Currently, we only support a SplatLike C. For the other cases, e.g., C in
|
|
|
|
|
// shared layout or blocked layout, we will support them by expanding
|
|
|
|
|
// convert_layout.
|
|
|
|
|
auto hc = helper.loadSplatLikeC(C, loc, rewriter);
|
|
|
|
|
auto hc = helper.loadSplatLikeC(tensor, loc, rewriter);
|
|
|
|
|
assert(hc.size() == 4UL && "Only splat-like C is supported now");
|
|
|
|
|
return hc[0];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Conduct the Dot conversion.
|
|
|
|
|
// Input the \param a, \param b, \param c, all of them are result of loading.
|
|
|
|
|
LogicalResult convertDot(Value a, Value b, Value c) {
|
|
|
|
|
ValueTable ha = getValuesFromDotOperandLayoutStruct(a, numRepM, numRepK);
|
|
|
|
|
// \param a, \param b, \param c and \param d are DotOp operands.
|
|
|
|
|
// \param loadedA, \param loadedB, \param loadedC, all of them are result of
|
|
|
|
|
// loading.
|
|
|
|
|
LogicalResult convertDot(Value a, Value b, Value c, Value d, Value loadedA,
|
|
|
|
|
Value loadedB, Value loadedC, DotOp op,
|
|
|
|
|
DotOpAdaptor adaptor) const {
|
|
|
|
|
helper.deduceMmaType(op);
|
|
|
|
|
|
|
|
|
|
auto aTensorTy = a.getType().cast<RankedTensorType>();
|
|
|
|
|
auto bTensorTy = b.getType().cast<RankedTensorType>();
|
|
|
|
|
auto cTensorTy = c.getType().cast<RankedTensorType>();
|
|
|
|
|
auto dTensorTy = d.getType().cast<RankedTensorType>();
|
|
|
|
|
|
|
|
|
|
auto aShape = aTensorTy.getShape();
|
|
|
|
|
auto dShape = dTensorTy.getShape();
|
|
|
|
|
|
|
|
|
|
int NK = aShape[1];
|
|
|
|
|
// shape / shape_per_cta
|
|
|
|
|
auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(aTensorTy);
|
|
|
|
|
auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(aTensorTy);
|
|
|
|
|
int numRepM = getNumRepM(aTensorTy, dShape[0]);
|
|
|
|
|
int numRepN = getNumRepN(aTensorTy, dShape[1]);
|
|
|
|
|
int numRepK = getNumRepK(aTensorTy, aShape[1]);
|
|
|
|
|
|
|
|
|
|
ValueTable ha =
|
|
|
|
|
getValuesFromDotOperandLayoutStruct(loadedA, numRepM, numRepK);
|
|
|
|
|
ValueTable hb = getValuesFromDotOperandLayoutStruct(
|
|
|
|
|
b, std::max(numRepN / 2, 1), numRepK);
|
|
|
|
|
loadedB, std::max(numRepN / 2, 1), numRepK);
|
|
|
|
|
|
|
|
|
|
const int fcSize = 4 * numRepM * numRepN;
|
|
|
|
|
SmallVector<Value> fc(fcSize, c);
|
|
|
|
|
SmallVector<Value> fc(fcSize, loadedC);
|
|
|
|
|
|
|
|
|
|
auto callMma = [&](unsigned m, unsigned n, unsigned k) {
|
|
|
|
|
unsigned colsPerThread = numRepN * 2;
|
|
|
|
@@ -2855,10 +2907,10 @@ struct MMA16816ConversionHelper {
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::function<void(int, int)>
|
|
|
|
|
getLoadMatrixFn(Value tensor, Value llTensor, int wpt, int kOrder,
|
|
|
|
|
ArrayRef<int> instrShape, ArrayRef<int> matShape,
|
|
|
|
|
Value warpId, ValueTable &vals) {
|
|
|
|
|
|
|
|
|
|
getLoadMatrixFn(Value tensor, Value llTensor, MmaEncodingAttr mmaLayout,
|
|
|
|
|
int wpt, int kOrder, ArrayRef<int> instrShape,
|
|
|
|
|
ArrayRef<int> matShape, Value warpId,
|
|
|
|
|
ValueTable &vals) const {
|
|
|
|
|
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
|
|
|
|
// We assume that the input operand of Dot comes from a shared layout.
|
|
|
|
|
// TODO(Superjomn) Consider other layouts if needed later.
|
|
|
|
@@ -2928,7 +2980,7 @@ private:
|
|
|
|
|
// i \in [0, n0) and j \in [0, n1)
|
|
|
|
|
// There should be \param n0 * \param n1 elements in the output Struct.
|
|
|
|
|
Value composeValuesToDotOperandLayoutStruct(const ValueTable &vals, int n0,
|
|
|
|
|
int n1) {
|
|
|
|
|
int n1) const {
|
|
|
|
|
std::vector<Value> elems;
|
|
|
|
|
for (unsigned m = 0; m < n0; ++m)
|
|
|
|
|
for (unsigned k = 0; k < n1; ++k) {
|
|
|
|
@@ -2940,7 +2992,7 @@ private:
|
|
|
|
|
|
|
|
|
|
assert(!elems.empty());
|
|
|
|
|
|
|
|
|
|
Type fp16Ty = aTensorTy.getElementType();
|
|
|
|
|
Type fp16Ty = type::f16Ty(ctx);
|
|
|
|
|
Type fp16x2Ty = vec_ty(fp16Ty, 2);
|
|
|
|
|
Type structTy = LLVM::LLVMStructType::getLiteral(
|
|
|
|
|
ctx, SmallVector<Type>(elems.size(), fp16x2Ty));
|
|
|
|
@@ -2948,7 +3000,8 @@ private:
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ValueTable getValuesFromDotOperandLayoutStruct(Value value, int n0, int n1) {
|
|
|
|
|
ValueTable getValuesFromDotOperandLayoutStruct(Value value, int n0,
|
|
|
|
|
int n1) const {
|
|
|
|
|
auto elems = ConvertTritonGPUOpToLLVMPatternBase::getElementsFromStruct(
|
|
|
|
|
loc, value, rewriter);
|
|
|
|
|
|
|
|
|
@@ -2966,18 +3019,79 @@ private:
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
/// Lowers a `convert_layout` whose source tensor has a shared encoding and
/// whose result has a dot-operand encoding, by loading the operand through
/// MMA16816ConversionHelper.
///
/// The dot-operand index selects the loader: 0 -> loadA ($a), 1 -> loadB ($b),
/// 2 -> loadC ($c).
///
/// \returns success() after replacing `op` with the loaded value; failure()
/// when the dot-operand layout's parent is not an MMA layout or the operand
/// index is not one of those handled here. Both checks run before any IR is
/// created or replaced.
LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
    triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto loc = op.getLoc();
  Value src = op.src();
  Value dst = op.result();
  auto dstTensorTy = dst.getType().cast<RankedTensorType>();

  // The dispatcher only routes shared-encoded sources here; keep that
  // precondition checked in debug builds.
  assert(src.getType()
             .cast<RankedTensorType>()
             .getEncoding()
             .isa<SharedEncodingAttr>() &&
         "source of shared -> dot_operand conversion must be shared layout");

  auto dotOperandLayout =
      dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();

  // A missing or non-MMA parent cannot be lowered by the MMA helper. Report
  // failure rather than relying on an assert that vanishes in release
  // builds.
  MmaEncodingAttr mmaLayout =
      dotOperandLayout.getParent().dyn_cast_or_null<MmaEncodingAttr>();
  if (!mmaLayout)
    return failure();

  // Reject unknown operand indices up front so `op` is never replaced with a
  // null Value.
  const auto opIdx = dotOperandLayout.getOpIdx();
  if (opIdx != 0 && opIdx != 1 && opIdx != 2)
    return failure();

  MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
                                     rewriter, getTypeConverter(), loc);

  Value res;
  if (opIdx == 0) {
    // operand $a
    res = mmaHelper.loadA(src, adaptor.src());
  } else if (opIdx == 1) {
    // operand $b
    res = mmaHelper.loadB(src, adaptor.src());
  } else {
    // operand $c
    res = mmaHelper.loadC(src);
  }

  rewriter.replaceOp(op, res);
  return success();
}
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
|
|
|
|
|
DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adaptor,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
|
auto loc = op.getLoc();
|
|
|
|
|
MMA16816ConversionHelper mmaHelper(op, getThreadId(rewriter, loc), adapter,
|
|
|
|
|
auto mmaLayout = op.getResult()
|
|
|
|
|
.getType()
|
|
|
|
|
.cast<RankedTensorType>()
|
|
|
|
|
.getEncoding()
|
|
|
|
|
.cast<MmaEncodingAttr>();
|
|
|
|
|
MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
|
|
|
|
|
rewriter, getTypeConverter(), loc);
|
|
|
|
|
|
|
|
|
|
auto A = mmaHelper.loadA();
|
|
|
|
|
auto B = mmaHelper.loadB();
|
|
|
|
|
auto C = mmaHelper.loadC();
|
|
|
|
|
Value A = op.a();
|
|
|
|
|
Value B = op.b();
|
|
|
|
|
Value C = op.c();
|
|
|
|
|
auto ATensorTy = A.getType().cast<RankedTensorType>();
|
|
|
|
|
auto BTensorTy = B.getType().cast<RankedTensorType>();
|
|
|
|
|
|
|
|
|
|
return mmaHelper.convertDot(A, B, C);
|
|
|
|
|
Value loadedA, loadedB, loadedC;
|
|
|
|
|
// We support two kinds of operand layouts: 1. both $a, $b are dot_operand
|
|
|
|
|
// layout, 2. both of them are shared layout.
|
|
|
|
|
if (ATensorTy.getEncoding().isa<DotOperandEncodingAttr>()) {
|
|
|
|
|
assert(BTensorTy.getEncoding().isa<DotOperandEncodingAttr>() &&
|
|
|
|
|
"Both $a and %b should be DotOperand layout.");
|
|
|
|
|
loadedA = adaptor.a();
|
|
|
|
|
loadedB = adaptor.b();
|
|
|
|
|
} else {
|
|
|
|
|
loadedA = mmaHelper.loadA(op.a(), adaptor.a());
|
|
|
|
|
loadedB = mmaHelper.loadB(op.b(), adaptor.b());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO[Superjomn]: Process C as a mma layout.
|
|
|
|
|
// Currently, C is simply treated as a Splat Op, and the data layout does not
|
|
|
|
|
// matter.
|
|
|
|
|
loadedC = mmaHelper.loadC(op.c());
|
|
|
|
|
|
|
|
|
|
return mmaHelper.convertDot(A, B, C, op.d(), loadedA, loadedB, loadedC, op,
|
|
|
|
|
adaptor);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// ====================== mma codegen end ============================
|
|
|
|
|