[triton-mlir][BACKEND] decouple loading from mma codegen in dot conversion (#764)

This PR decouples the operand loading from the mma codegen to make it
ready for the ongoing `DotOperandEncodingAttr` migration.
The existing DotOp conversion is composed of the following two
procedures:
1. Loading the $a, $b, $c operands from smem into registers
2. Conducting the MMA instruction codegen.

In the latest design, however, the first stage should be part of the
`convert_layout(shared_layout) -> dot_operand_layout` lowering, which is why
the decoupling is necessary.

In more detail, this PR introduces an `MMA16816ConversionHelper` class with
`loadA`, `loadB` and `loadC` methods that load $a, $b and $c from smem into
registers. Both `loadA` and `loadB` return an `LLVM::Struct` that should be
compatible with the new `DotOperandEncodingAttr` conversion.
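
After the decoupling, the new `convertMMA16816` (see the diff below) reduces to
loading each operand and delegating to `convertDot`:
```c++
// From the updated DotOpConversion::convertMMA16816 in this PR:
MMA16816ConversionHelper mmaHelper(op, getThreadId(rewriter, loc), adapter,
                                   rewriter, getTypeConverter(), loc);
auto A = mmaHelper.loadA(); // LLVM::Struct of packed operand values
auto B = mmaHelper.loadB(); // LLVM::Struct of packed operand values
auto C = mmaHelper.loadC(); // a single Value (splat-like C only, for now)
return mmaHelper.convertDot(A, B, C);
```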

The layout of the loaded $a and $b values is as follows:
```c++
// The layout is a list of Values with coordinates (i,j); the order is as
// follows:
// [
//  (0,0), (0,1), (1,0), (1,1), # i=0, j=0
//  (0,2), (0,3), (1,2), (1,3), # i=0, j=1
//  (0,4), (0,5), (1,4), (1,5), # i=0, j=2
//  ...
//  (2,0), (2,1), (3,0), (3,1), # i=1, j=0
//  (2,2), (2,3), (3,2), (3,3), # i=1, j=1
//  (2,4), (2,5), (3,4), (3,5), # i=1, j=2
//  ...
// ]
// i \in [0, n0) and j \in [0, n1)
```
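
This order is produced by the packing loop in
`composeValuesToDotOperandLayoutStruct` (see the diff below). A minimal
standalone sketch of the enumeration (plain C++, no MLIR types;
`dumpDotOperandOrder` is a hypothetical name, not part of the PR):
```c++
#include <cstdio>

// Emit the (i,j) coordinates in the struct's element order for a given
// (n0, n1); each (m, k) step contributes a 2x2 block of coordinates.
void dumpDotOperandOrder(int n0, int n1) {
  for (int m = 0; m < n0; ++m)
    for (int k = 0; k < n1; ++k)
      // same push order as composeValuesToDotOperandLayoutStruct
      std::printf("(%d,%d) (%d,%d) (%d,%d) (%d,%d)  # i=%d, j=%d\n",
                  2 * m, 2 * k, 2 * m, 2 * k + 1,
                  2 * m + 1, 2 * k, 2 * m + 1, 2 * k + 1, m, k);
}

int main() { dumpDotOperandOrder(2, 3); } // reproduces the list above
```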
The `convertDot` method takes the loaded $a, $b and $c ($a and $b are
`LLVM::Struct`s while $c is a scalar Value), extracts the elements from the
structs following the layout above, and passes them to the MMA inline asm.
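
The extraction is performed by `getValuesFromDotOperandLayoutStruct` (also in
this diff), which is the exact inverse of the packing. A self-contained
round-trip sketch with plain ints standing in for `Value`s (names here are
illustrative, not from the PR):
```c++
#include <cassert>
#include <map>
#include <utility>
#include <vector>

using Table = std::map<std::pair<int, int>, int>;

int main() {
  const int n0 = 2, n1 = 3;
  Table src;
  for (int i = 0; i < 2 * n0; ++i)
    for (int j = 0; j < 2 * n1; ++j)
      src[{i, j}] = 100 * i + j; // arbitrary distinct payloads

  // Pack: same traversal as composeValuesToDotOperandLayoutStruct.
  std::vector<int> elems;
  for (int m = 0; m < n0; ++m)
    for (int k = 0; k < n1; ++k) {
      elems.push_back(src.at({2 * m, 2 * k}));
      elems.push_back(src.at({2 * m, 2 * k + 1}));
      elems.push_back(src.at({2 * m + 1, 2 * k}));
      elems.push_back(src.at({2 * m + 1, 2 * k + 1}));
    }

  // Unpack: same traversal as getValuesFromDotOperandLayoutStruct.
  Table dst;
  int offset = 0;
  for (int i = 0; i < n0; ++i)
    for (int j = 0; j < n1; ++j) {
      dst[{2 * i, 2 * j}] = elems[offset++];
      dst[{2 * i, 2 * j + 1}] = elems[offset++];
      dst[{2 * i + 1, 2 * j}] = elems[offset++];
      dst[{2 * i + 1, 2 * j + 1}] = elems[offset++];
    }
  assert(dst == src); // extraction is the exact inverse of packing
}
```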
Author: Yan Chunwei
Date:   2022-10-12 10:45:17 +08:00 (committed by GitHub)
Commit: 1baa4e125f (parent 623c99609f)


@@ -337,12 +337,13 @@ static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
 struct ConvertTritonGPUOpToLLVMPatternBase {
   static SmallVector<Value>
-  getElementsFromStruct(Location loc, Value llvmStruct, unsigned elems,
+  getElementsFromStruct(Location loc, Value llvmStruct,
                         ConversionPatternRewriter &rewriter) {
-    SmallVector<Value> results(elems);
-    for (unsigned i = 0; i < elems; ++i) {
-      Type type =
-          llvmStruct.getType().cast<LLVM::LLVMStructType>().getBody()[i];
+    ArrayRef<Type> types =
+        llvmStruct.getType().cast<LLVM::LLVMStructType>().getBody();
+    SmallVector<Value> results(types.size());
+    for (unsigned i = 0; i < types.size(); ++i) {
+      Type type = types[i];
       results[i] = extract_val(type, llvmStruct, rewriter.getI64ArrayAttr(i));
     }
     return results;
@@ -715,8 +716,7 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
     auto shape = value.getType().cast<RankedTensorType>().getShape();
     // Here, we assume that all inputs should have a blockedLayout
     unsigned valueElems = layout.getElemsPerThread(shape);
-    auto valueVals = getElementsFromStruct(loc, llValue, valueElems, rewriter);
+    auto valueVals = getElementsFromStruct(loc, llValue, rewriter);
     return valueVals;
   }
@@ -977,7 +977,7 @@ struct BroadcastOpConversion
     unsigned srcElems = srcLayout.getElemsPerThread(srcShape);
     auto elemTy = resultTy.getElementType();
-    auto srcVals = getElementsFromStruct(loc, src, srcElems, rewriter);
+    auto srcVals = getElementsFromStruct(loc, src, rewriter);
     unsigned resultElems = resultLayout.getElemsPerThread(resultShape);
     SmallVector<Value> resultVals(resultElems);
     for (unsigned i = 0; i < srcElems; ++i) {
@@ -1025,8 +1025,7 @@ struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
         this->getTypeConverter()->convertType(resultTy.getElementType());
     SmallVector<Type> types(elems, elemTy);
     Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
-    auto vals =
-        this->getElementsFromStruct(loc, adaptor.src(), elems, rewriter);
+    auto vals = this->getElementsFromStruct(loc, adaptor.src(), rewriter);
     Value view = getStructFromElements(loc, vals, rewriter, structTy);
     rewriter.replaceOp(op, view);
     return success();
@@ -1308,9 +1307,8 @@ struct AddPtrOpConversion
         getTypeConverter()->convertType(resultTensorTy.getElementType());
     SmallVector<Type> types(elems, elemTy);
     Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
-    auto ptrs = getElementsFromStruct(loc, adaptor.ptr(), elems, rewriter);
-    auto offsets =
-        getElementsFromStruct(loc, adaptor.offset(), elems, rewriter);
+    auto ptrs = getElementsFromStruct(loc, adaptor.ptr(), rewriter);
+    auto offsets = getElementsFromStruct(loc, adaptor.offset(), rewriter);
     SmallVector<Value> resultVals(elems);
     for (unsigned i = 0; i < elems; ++i) {
       resultVals[i] = gep(elemTy, ptrs[i], offsets[i]);
@@ -1416,9 +1414,9 @@ public:
     auto *concreteThis = static_cast<const ConcreteT *>(this);
     auto lhss = this->getElementsFromStruct(loc, concreteThis->getLhs(adaptor),
-                                            elems, rewriter);
+                                            rewriter);
     auto rhss = this->getElementsFromStruct(loc, concreteThis->getRhs(adaptor),
-                                            elems, rewriter);
+                                            rewriter);
     SmallVector<Value> resultVals(elems);
     for (unsigned i = 0; i < elems; ++i) {
       resultVals[i] = concreteThis->createDestOp(op, rewriter, elemTy, lhss[i],
@@ -1761,7 +1759,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
   // Potentially we need to store for multiple CTAs in this replication
   unsigned accumNumReplicates = product<unsigned>(numReplicates);
   unsigned elems = getElemsPerThread(srcLayout, srcTy.getShape());
-  auto vals = getElementsFromStruct(loc, adaptor.src(), elems, rewriter);
+  auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
   unsigned inVec = 0;
   unsigned outVec = 0;
   auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);
@@ -1824,7 +1822,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
   unsigned perPhase = dstSharedLayout.getPerPhase();
   unsigned maxPhase = dstSharedLayout.getMaxPhase();
   unsigned numElems = getElemsPerThread(srcBlockedLayout, srcShape);
-  auto inVals = getElementsFromStruct(loc, adaptor.src(), numElems, rewriter);
+  auto inVals = getElementsFromStruct(loc, adaptor.src(), rewriter);
   unsigned srcAccumSizeInThreads =
       product<unsigned>(srcBlockedLayout.getSizePerThread());
   auto elemTy = srcTy.getElementType();
@@ -2661,75 +2659,205 @@ private:
   DotOp dot;
 };
 
-LogicalResult
-DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
-                                 ConversionPatternRewriter &rewriter) const {
-  Location loc = op->getLoc();
-  MLIRContext *ctx = op->getContext();
-  // D = A * B + C
-  Value A = op.a();
-  Value B = op.b();
-  Value C = op.c();
-  Value D = op.getResult();
-  bool allowTF32 = op.allowTF32();
-
-  auto aTensorTy = A.getType().cast<RankedTensorType>();
-  auto bTensorTy = B.getType().cast<RankedTensorType>();
-  auto dTensorTy = D.getType().cast<RankedTensorType>();
-
-  auto aShape = aTensorTy.getShape();
-  auto bShape = bTensorTy.getShape();
-  auto dShape = dTensorTy.getShape();
-
-  auto mmaLayout = dTensorTy.getEncoding().cast<MmaEncodingAttr>();
-
-  auto wpt = mmaLayout.getWarpsPerCTA();
-
-  DotOpConversionHelper helper(op);
-
-  int NK = aShape[1];
-
-  auto mmaInstrShape = helper.getMmaInstrShape();
-  const int mmaInstrM = mmaInstrShape[0];
-  const int mmaInstrN = mmaInstrShape[1];
-  const int mmaInstrK = mmaInstrShape[2];
-
-  auto matShape = helper.getMmaMatShape();
-  const int matShapeM = matShape[0];
-  const int matShapeN = matShape[1];
-  const int matShapeK = matShape[2];
-
-  // shape / shape_per_cta
-  const int numRepM = std::max<int>(dShape[0] / (wpt[0] * mmaInstrM), 1);
-  const int numRepN = std::max<int>(dShape[1] / (wpt[1] * mmaInstrN), 1);
-  const int numRepK = std::max<int>(NK / mmaInstrK, 1);
-
-  Value _32 = i32_val(32);
-  Value thread = getThreadId(rewriter, loc);
-  Value lane = urem(thread, _32);
-  Value warp = udiv(thread, _32);
-  Value warpMN = udiv(warp, i32_val(wpt[0]));
-  Value warpM = urem(warp, i32_val(wpt[0]));
-  Value warpN = urem(warpMN, i32_val(wpt[1]));
-
-  size_t aElemBytes = aTensorTy.getElementTypeBitWidth() / 8;
-  size_t bElemBytes = bTensorTy.getElementTypeBitWidth() / 8;
-
-  std::map<std::pair<unsigned, unsigned>, Value> ha;
-  std::map<std::pair<unsigned, unsigned>, Value> hb;
-
-  // the original register_lds2, but discard the prefetch logic.
-  auto ld2 = [](decltype(ha) &vals, int mn, int k, Value val) {
-    vals[{mn, k}] = val;
-  };
+// This class helps to adapt the existing DotOpConversion to the latest
+// DotOpOperand layout design. It decouples the existing implementation into
+// two parts:
+// 1. loading the specific operand matrix (for $a, $b, $c) from smem
+// 2. passing the loaded values on to the mma codegen
+struct MMA16816ConversionHelper {
+  Value A, B, C, D;
+  RankedTensorType aTensorTy, bTensorTy, dTensorTy;
+  ArrayRef<int64_t> aShape, bShape, dShape;
+  MmaEncodingAttr mmaLayout;
+  ArrayRef<unsigned int> wpt;
+
+  int mmaInstrM{-1}, mmaInstrN{-1}, mmaInstrK{-1};
+  int matShapeM{-1}, matShapeN{-1}, matShapeK{-1};
+  int numRepM{-1}, numRepN{-1}, numRepK{-1};
+  Value thread, lane, warp, warpMN, warpN, warpM;
+  size_t aElemBytes{}, bElemBytes{};
+
+  DotOpConversionHelper helper;
+  triton::DotOp op;
+  DotOpAdaptor adapter;
+  ConversionPatternRewriter &rewriter;
+  TypeConverter *typeConverter;
+  Location loc;
+  MLIRContext *ctx{};
+
+  using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;
+
+  MMA16816ConversionHelper(triton::DotOp op, Value thread, DotOpAdaptor adapter,
+                           ConversionPatternRewriter &rewriter,
+                           TypeConverter *typeConverter, Location loc)
+      : helper(op), op(op), adapter(adapter), rewriter(rewriter),
+        typeConverter(typeConverter), loc(loc), ctx(op.getContext()),
+        thread(thread) {
+    A = op.a();
+    B = op.b();
+    C = op.c();
+    D = op.getResult();
+    // TODO(Superjomn) Process C->is_trans_a() logic
+    aTensorTy = A.getType().cast<RankedTensorType>();
+    bTensorTy = B.getType().cast<RankedTensorType>();
+    dTensorTy = D.getType().cast<RankedTensorType>();
+
+    aShape = aTensorTy.getShape();
+    bShape = bTensorTy.getShape();
+    dShape = dTensorTy.getShape();
+
+    mmaLayout = dTensorTy.getEncoding().cast<MmaEncodingAttr>();
+
+    wpt = mmaLayout.getWarpsPerCTA();
+
+    auto mmaInstrShape = helper.getMmaInstrShape();
+    mmaInstrM = mmaInstrShape[0];
+    mmaInstrN = mmaInstrShape[1];
+    mmaInstrK = mmaInstrShape[2];
+
+    auto matShape = helper.getMmaMatShape();
+    matShapeM = matShape[0];
+    matShapeN = matShape[1];
+    matShapeK = matShape[2];
+
+    int NK = aShape[1];
+    // shape / shape_per_cta
+    numRepM = std::max<int>(dShape[0] / (wpt[0] * mmaInstrM), 1);
+    numRepN = std::max<int>(dShape[1] / (wpt[1] * mmaInstrN), 1);
+    numRepK = std::max<int>(NK / mmaInstrK, 1);
+
+    Value _32 = i32_val(32);
+    lane = urem(thread, _32);
+    warp = udiv(thread, _32);
+    warpMN = udiv(warp, i32_val(wpt[0]));
+    warpM = urem(warp, i32_val(wpt[0]));
+    warpN = urem(warpMN, i32_val(wpt[1]));
+
+    aElemBytes = aTensorTy.getElementTypeBitWidth() / 8;
+    bElemBytes = bTensorTy.getElementTypeBitWidth() / 8;
+  }
+
+  // Loading $a from smem to registers, returns a LLVM::Struct.
+  Value loadA() {
+    ValueTable ha;
+    std::function<void(int, int)> loadFn;
+    if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
+      // load from smem
+      loadFn = getLoadMatrixFn(
+          A, adapter.a() /*llTensor*/, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
+          1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShape*/,
+          {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
+    } else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
+      // load from registers, used in gemm fuse
+      // TODO(Superjomn) Port the logic.
+      assert(false && "Loading A from register is not supported yet.");
+    } else {
+      assert(false && "A's layout is not supported.");
+    }
+
+    // step1. Perform loading.
+    for (unsigned m = 0; m < numRepM; ++m)
+      for (unsigned k = 0; k < numRepK; ++k)
+        loadFn(2 * m, 2 * k);
+
+    // step2. Format the values into an LLVM::Struct to pass to mma codegen.
+    Value result = composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK);
+
+    // TODO[Superjomn]: Replace the convert_layout op with the result once the
+    // DotOperandEncodingAttr is ready.
+    return result;
+  }
+
+  // Loading $b from smem to registers, returns a LLVM::Struct.
+  Value loadB() {
+    ValueTable hb;
+    auto loadFn = getLoadMatrixFn(
+        B, adapter.b() /*llTensor*/, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
+        0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/,
+        {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
+
+    for (unsigned n = 0; n < std::max(numRepN / 2, 1); ++n) {
+      for (unsigned k = 0; k < numRepK; ++k)
+        loadFn(2 * n, 2 * k);
+    }
+
+    Value result = composeValuesToDotOperandLayoutStruct(
+        hb, std::max(numRepN / 2, 1), numRepK);
+    return result;
+  }
+
+  // Loading $c from smem(?) to registers, returns a Value.
+  // NOTE Only SplatLike tensor is supported now.
+  Value loadC() {
+    // Currently, we only support a SplatLike C. For the other cases, e.g., C in
+    // shared layout or blocked layout, we will support them by expanding
+    // convert_layout.
+    auto hc = helper.loadSplatLikeC(C, loc, rewriter);
+    assert(hc.size() == 4UL && "Only splat-like C is supported now");
+    return hc[0];
+  }
+
+  // Conduct the Dot conversion.
+  // The inputs \param a, \param b, \param c are all results of loading.
+  LogicalResult convertDot(Value a, Value b, Value c) {
+    ValueTable ha = getValuesFromDotOperandLayoutStruct(a, numRepM, numRepK);
+    ValueTable hb = getValuesFromDotOperandLayoutStruct(
+        b, std::max(numRepN / 2, 1), numRepK);
+
+    const int fcSize = 4 * numRepM * numRepN;
+    SmallVector<Value> fc(fcSize, c);
+
+    auto callMma = [&](unsigned m, unsigned n, unsigned k) {
+      unsigned colsPerThread = numRepN * 2;
+      PTXBuilder builder;
+      auto &mma = *builder.create(helper.getMmaInstr().str());
+      auto retArgs = builder.newListOperand(4, "=r");
+      auto aArgs = builder.newListOperand({
+          {ha[{m, k}], "r"},
+          {ha[{m + 1, k}], "r"},
+          {ha[{m, k + 1}], "r"},
+          {ha[{m + 1, k + 1}], "r"},
+      });
+      auto bArgs =
+          builder.newListOperand({{hb[{n, k}], "r"}, {hb[{n, k + 1}], "r"}});
+      auto cArgs = builder.newListOperand();
+      for (int i = 0; i < 4; ++i) {
+        cArgs->listAppend(builder.newOperand(fc[m * colsPerThread + 4 * n + i],
+                                             std::to_string(i)));
+        // reuse the output registers
+      }
+      mma(retArgs, aArgs, bArgs, cArgs);
+      Value mmaOut = builder.launch(rewriter, loc, helper.getMmaRetType());
+
+      auto getIntAttr = [&](int v) {
+        return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
+      };
+
+      for (int i = 0; i < 4; i++)
+        fc[m * colsPerThread + 4 * n + i] =
+            extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(i));
+    };
+
+    for (unsigned k = 0; k < numRepK; ++k)
+      for (unsigned m = 0; m < numRepM; ++m)
+        for (unsigned n = 0; n < numRepN; ++n)
+          callMma(2 * m, n, 2 * k);
+
+    // replace with new packed result
+    Type structTy = LLVM::LLVMStructType::getLiteral(
+        ctx, SmallVector<Type>(fc.size(), type::f32Ty(ctx)));
+    Value res = getStructFromElements(loc, fc, rewriter, structTy);
+    rewriter.replaceOp(op, res);
+
+    return success();
+  }
+
+private:
-  // Load A or B matrix.
-  auto getLoadMatrixFn =
-      [&](Value tensor, Value llTensor, int wpt, int kOrder,
-          ArrayRef<int> instrShape, ArrayRef<int> matShape, Value warpId,
-          decltype(ha) &vals) -> std::function<void(int, int)> {
+  std::function<void(int, int)>
+  getLoadMatrixFn(Value tensor, Value llTensor, int wpt, int kOrder,
+                  ArrayRef<int> instrShape, ArrayRef<int> matShape,
+                  Value warpId, ValueTable &vals) {
     auto tensorTy = tensor.getType().cast<RankedTensorType>();
     // We assume that the input operand of Dot should be from shared layout.
    // TODO(Superjomn) Consider other layouts if needed later.
@@ -2739,25 +2867,31 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
     const int elemBytes = tensorTy.getElementTypeBitWidth() / 8;
     auto order = sharedLayout.getOrder();
 
-    MMA16816SmemLoader loader(wpt, sharedLayout.getOrder(), kOrder,
-                              tensorTy.getShape() /*tileShape*/, instrShape,
-                              matShape, perPhase, maxPhase, elemBytes, rewriter,
-                              typeConverter, loc);
-    SmallVector<Value> offs = loader.computeOffsets(warpId, lane);
-    const int numPtrs = loader.getNumPtr();
-    SmallVector<Value> ptrs(numPtrs);
-
-    Type smemPtrTy = helper.getShemPtrTy();
-    for (int i = 0; i < numPtrs; ++i) {
-      ptrs[i] =
-          bitcast(smemPtrTy, gep(smemPtrTy, llTensor, ValueRange({offs[i]})));
-    }
-
     bool needTrans = kOrder != order[0];
 
+    // the original register_lds2, but discard the prefetch logic.
+    auto ld2 = [](ValueTable &vals, int mn, int k, Value val) {
+      vals[{mn, k}] = val;
+    };
+
     // (a, b) is the coordinate.
-    auto load = [=, &vals, &helper, &ld2](int a, int b) {
+    auto load = [=, &vals, &ld2](int a, int b) {
+      MMA16816SmemLoader loader(wpt, sharedLayout.getOrder(), kOrder,
+                                tensorTy.getShape() /*tileShape*/, instrShape,
+                                matShape, perPhase, maxPhase, elemBytes,
+                                rewriter, typeConverter, loc);
+      SmallVector<Value> offs = loader.computeOffsets(warpId, lane);
+      const int numPtrs = loader.getNumPtr();
+      SmallVector<Value> ptrs(numPtrs);
+
+      Type smemPtrTy = helper.getShemPtrTy();
+      for (int i = 0; i < numPtrs; ++i) {
+        ptrs[i] =
+            bitcast(smemPtrTy, gep(smemPtrTy, llTensor, ValueRange({offs[i]})));
+      }
+
       auto [ha0, ha1, ha2, ha3] = loader.loadX4(
           (kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? b : a /*mat1*/, offs,
           ptrs, helper.getMatType(), helper.getShemPtrTy());
@@ -2775,89 +2909,74 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
     };
     return load;
-  };
-
-  std::function<void(int, int)> loadA;
-  if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
-    // load from smem
-    loadA = getLoadMatrixFn(
-        A, adapter.a() /*llTensor*/, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
-        1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShape*/,
-        {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
-  } else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
-    // load from registers, used in gemm fuse
-    // TODO(Superjomn) Port the logic.
-    assert(false && "Loading A from register is not supported yet.");
-  } else {
-    assert(false && "A's layout is not supported.");
-  }
-
-  std::function<void(int, int)> loadB = getLoadMatrixFn(
-      B, adapter.b() /*llTensor*/, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
-      0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/,
-      {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
-
-  const int fcSize = 4 * numRepM * numRepN;
-  SmallVector<Value> fc(fcSize);
-
-  // Currently, we only support a SplatLike C. For the other cases, e.g., C in
-  // shared layout or blocked layout, we will support them by expanding
-  // convert_layout.
-  auto hc = helper.loadSplatLikeC(C, loc, rewriter);
-  assert(hc.size() == 4UL && "Only splat-like C is supported now");
-  for (int i = 0; i < fc.size(); i++)
-    fc[i] = hc[0];
-
-  auto callMma = [&](unsigned m, unsigned n, unsigned k) {
-    unsigned colsPerThread = numRepN * 2;
-    PTXBuilder builder;
-    auto &mma = *builder.create(helper.getMmaInstr().str());
-    auto retArgs = builder.newListOperand(4, "=r");
-    auto aArgs = builder.newListOperand({
-        {ha[{m, k}], "r"},
-        {ha[{m + 1, k}], "r"},
-        {ha[{m, k + 1}], "r"},
-        {ha[{m + 1, k + 1}], "r"},
-    });
-    auto bArgs =
-        builder.newListOperand({{hb[{n, k}], "r"}, {hb[{n, k + 1}], "r"}});
-    auto cArgs = builder.newListOperand();
-    for (int i = 0; i < 4; ++i) {
-      cArgs->listAppend(builder.newOperand(fc[m * colsPerThread + 4 * n + i],
-                                           std::to_string(i)));
-      // reuse the output registers
-    }
-    mma(retArgs, aArgs, bArgs, cArgs);
-    Value mmaOut = builder.launch(rewriter, loc, helper.getMmaRetType());
-
-    auto getIntAttr = [&](int v) {
-      return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
-    };
-
-    for (int i = 0; i < 4; i++)
-      fc[m * colsPerThread + 4 * n + i] =
-          extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(i));
-  };
-
-  // Main program
-  for (unsigned k = 0; k < numRepK; ++k) {
-    for (unsigned m = 0; m < numRepM; ++m)
-      loadA(2 * m, 2 * k);
-    for (unsigned n = 0; n < numRepN; n += 2)
-      loadB(n, 2 * k);
-    for (unsigned m = 0; m < numRepM; ++m)
-      for (unsigned n = 0; n < numRepN; ++n) {
-        callMma(2 * m, n, 2 * k);
-      }
-  }
-
-  // replace with new packed result
-  Type structTy = LLVM::LLVMStructType::getLiteral(
-      ctx, SmallVector<Type>(fc.size(), type::f32Ty(ctx)));
-  Value res = getStructFromElements(loc, fc, rewriter, structTy);
-  rewriter.replaceOp(op, res);
-
-  return success();
+  }
+
+  // Compose a map of Values to an LLVM::Struct.
+  // The layout is a list of Values with coordinates (i,j); the order is as
+  // follows:
+  // [
+  //  (0,0), (0,1), (1,0), (1,1), # i=0, j=0
+  //  (0,2), (0,3), (1,2), (1,3), # i=0, j=1
+  //  (0,4), (0,5), (1,4), (1,5), # i=0, j=2
+  //  ...
+  //  (2,0), (2,1), (3,0), (3,1), # i=1, j=0
+  //  (2,2), (2,3), (3,2), (3,3), # i=1, j=1
+  //  (2,4), (2,5), (3,4), (3,5), # i=1, j=2
+  //  ...
+  // ]
+  // i \in [0, n0) and j \in [0, n1)
+  // There should be \param n0 * \param n1 elements in the output Struct.
+  Value composeValuesToDotOperandLayoutStruct(const ValueTable &vals, int n0,
+                                              int n1) {
+    std::vector<Value> elems;
+    for (unsigned m = 0; m < n0; ++m)
+      for (unsigned k = 0; k < n1; ++k) {
+        elems.push_back(vals.at({2 * m, 2 * k}));
+        elems.push_back(vals.at({2 * m, 2 * k + 1}));
+        elems.push_back(vals.at({2 * m + 1, 2 * k}));
+        elems.push_back(vals.at({2 * m + 1, 2 * k + 1}));
+      }
+
+    assert(!elems.empty());
+
+    Type fp16Ty = aTensorTy.getElementType();
+    Type fp16x2Ty = vec_ty(fp16Ty, 2);
+    Type structTy = LLVM::LLVMStructType::getLiteral(
+        ctx, SmallVector<Type>(elems.size(), fp16x2Ty));
+    auto result = getStructFromElements(loc, elems, rewriter, structTy);
+    return result;
+  }
+
+  ValueTable getValuesFromDotOperandLayoutStruct(Value value, int n0, int n1) {
+    auto elems = ConvertTritonGPUOpToLLVMPatternBase::getElementsFromStruct(
+        loc, value, rewriter);
+
+    int offset{};
+    ValueTable vals;
+    for (int i = 0; i < n0; i++) {
+      for (int j = 0; j < n1; j++) {
+        vals[{2 * i, 2 * j}] = elems[offset++];
+        vals[{2 * i, 2 * j + 1}] = elems[offset++];
+        vals[{2 * i + 1, 2 * j}] = elems[offset++];
+        vals[{2 * i + 1, 2 * j + 1}] = elems[offset++];
+      }
+    }
+    return vals;
+  }
+};
+
+LogicalResult
+DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
+                                 ConversionPatternRewriter &rewriter) const {
+  auto loc = op.getLoc();
+  MMA16816ConversionHelper mmaHelper(op, getThreadId(rewriter, loc), adapter,
+                                     rewriter, getTypeConverter(), loc);
+
+  auto A = mmaHelper.loadA();
+  auto B = mmaHelper.loadB();
+  auto C = mmaHelper.loadC();
+
+  return mmaHelper.convertDot(A, B, C);
 }
 
 /// ====================== mma codegen end ============================
@@ -3012,9 +3131,9 @@ struct InsertSliceAsyncOpConversion
     auto inOrder = srcBlockedLayout.getOrder();
     auto outOrder = resSharedLayout.getOrder();
-    // If perPhase * maxPhase > threadsPerCTA, we need to swizzle over elements
-    // across phases.
-    // If perPhase * maxPhase == threadsPerCTA, swizzle is not allowed
+    // If perPhase * maxPhase > threadsPerCTA, we need to swizzle over
+    // elements across phases. If perPhase * maxPhase == threadsPerCTA,
+    // swizzle is not allowed
     auto numSwizzleRows = std::max<unsigned>(
         (perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1);
     // A sharedLayout encoding has a "vec" parameter.