[triton-mlir][BACKEND] decouple loading from mma codegen in dot conversion (#764)
This PR decouples operand loading from the MMA codegen to make the code ready for the ongoing `DotOperandEncodingAttr` migration. The existing DotOp conversion is composed of two procedures:

1. Loading the $a, $b, and $c operands from smem to registers.
2. Performing the MMA instruction codegen.

In the latest design, the first stage should become part of the `convert_layout(shared_layout) -> dot_operand_layout` lowering, which is why the decoupling is necessary.

In more detail, this PR introduces an `MMA16816ConversionHelper` class with `loadA`, `loadB`, and `loadC` methods that load $a, $b, and $c from smem to registers. Both `loadA` and `loadB` return an `LLVM::Struct` that should be compatible with the new `DotOperandEncodingAttr` conversion. The value layout for $a and $b is as follows:

```c++
// The layout is a list of Values with coordinates (i,j), ordered as follows:
// [
//  (0,0), (0,1), (1,0), (1,1), # i=0, j=0
//  (0,2), (0,3), (1,2), (1,3), # i=0, j=1
//  (0,4), (0,5), (1,4), (1,5), # i=0, j=2
//  ...
//  (2,0), (2,1), (3,0), (3,1), # i=1, j=0
//  (2,2), (2,3), (3,2), (3,3), # i=1, j=1
//  (2,4), (2,5), (3,4), (3,5), # i=1, j=2
//  ...
// ]
// i \in [0, n0) and j \in [0, n1)
```

The `convertDot` method takes the loaded $a, $b, and $c ($a and $b are `LLVM::Struct` values, while $c is a scalar Value), extracts the elements from each `LLVM::Struct` following the layout above, and passes them to the MMA inline asm.
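To make the packing order concrete, here is a minimal, self-contained sketch that mirrors the quad-emission loop in `composeValuesToDotOperandLayoutStruct` from this PR, but uses plain coordinate pairs in place of MLIR Values so it can be compiled and run on its own:

```c++
// Sketch: enumerate the struct-element coordinates in the order the helper
// packs them. For each (m, k) repetition a 2x2 quad is emitted in the order
// (2m,2k), (2m,2k+1), (2m+1,2k), (2m+1,2k+1).
#include <cstdio>
#include <utility>
#include <vector>

int main() {
  const int n0 = 2, n1 = 3; // i in [0, n0), j in [0, n1)
  std::vector<std::pair<int, int>> order;
  for (int m = 0; m < n0; ++m)
    for (int k = 0; k < n1; ++k) {
      order.push_back({2 * m, 2 * k});
      order.push_back({2 * m, 2 * k + 1});
      order.push_back({2 * m + 1, 2 * k});
      order.push_back({2 * m + 1, 2 * k + 1});
    }
  // Prints: (0,0) (0,1) (1,0) (1,1) (0,2) (0,3) (1,2) (1,3) ...
  for (auto &ij : order)
    std::printf("(%d,%d) ", ij.first, ij.second);
  std::printf("\n");
}
```

Each `(m, k)` repetition thus contributes four packed values, so the resulting struct holds `4 * n0 * n1` elements; `getValuesFromDotOperandLayoutStruct` walks the same order to rebuild the `(i, j) -> Value` table on the consumer side.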
@@ -337,12 +337,13 @@ static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
 
 struct ConvertTritonGPUOpToLLVMPatternBase {
   static SmallVector<Value>
-  getElementsFromStruct(Location loc, Value llvmStruct, unsigned elems,
+  getElementsFromStruct(Location loc, Value llvmStruct,
                         ConversionPatternRewriter &rewriter) {
-    SmallVector<Value> results(elems);
-    for (unsigned i = 0; i < elems; ++i) {
-      Type type =
-          llvmStruct.getType().cast<LLVM::LLVMStructType>().getBody()[i];
+    ArrayRef<Type> types =
+        llvmStruct.getType().cast<LLVM::LLVMStructType>().getBody();
+    SmallVector<Value> results(types.size());
+    for (unsigned i = 0; i < types.size(); ++i) {
+      Type type = types[i];
       results[i] = extract_val(type, llvmStruct, rewriter.getI64ArrayAttr(i));
     }
     return results;
@@ -715,8 +716,7 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
 
     auto shape = value.getType().cast<RankedTensorType>().getShape();
     // Here, we assume that all inputs should have a blockedLayout
     unsigned valueElems = layout.getElemsPerThread(shape);
-    auto valueVals = getElementsFromStruct(loc, llValue, valueElems, rewriter);
+    auto valueVals = getElementsFromStruct(loc, llValue, rewriter);
     return valueVals;
   }
 
@@ -977,7 +977,7 @@ struct BroadcastOpConversion
 
     unsigned srcElems = srcLayout.getElemsPerThread(srcShape);
     auto elemTy = resultTy.getElementType();
-    auto srcVals = getElementsFromStruct(loc, src, srcElems, rewriter);
+    auto srcVals = getElementsFromStruct(loc, src, rewriter);
     unsigned resultElems = resultLayout.getElemsPerThread(resultShape);
     SmallVector<Value> resultVals(resultElems);
     for (unsigned i = 0; i < srcElems; ++i) {
@@ -1025,8 +1025,7 @@ struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
         this->getTypeConverter()->convertType(resultTy.getElementType());
     SmallVector<Type> types(elems, elemTy);
     Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
-    auto vals =
-        this->getElementsFromStruct(loc, adaptor.src(), elems, rewriter);
+    auto vals = this->getElementsFromStruct(loc, adaptor.src(), rewriter);
     Value view = getStructFromElements(loc, vals, rewriter, structTy);
     rewriter.replaceOp(op, view);
     return success();
@@ -1308,9 +1307,8 @@ struct AddPtrOpConversion
         getTypeConverter()->convertType(resultTensorTy.getElementType());
     SmallVector<Type> types(elems, elemTy);
     Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
-    auto ptrs = getElementsFromStruct(loc, adaptor.ptr(), elems, rewriter);
-    auto offsets =
-        getElementsFromStruct(loc, adaptor.offset(), elems, rewriter);
+    auto ptrs = getElementsFromStruct(loc, adaptor.ptr(), rewriter);
+    auto offsets = getElementsFromStruct(loc, adaptor.offset(), rewriter);
     SmallVector<Value> resultVals(elems);
     for (unsigned i = 0; i < elems; ++i) {
       resultVals[i] = gep(elemTy, ptrs[i], offsets[i]);
@@ -1416,9 +1414,9 @@ public:
 
     auto *concreteThis = static_cast<const ConcreteT *>(this);
     auto lhss = this->getElementsFromStruct(loc, concreteThis->getLhs(adaptor),
-                                            elems, rewriter);
+                                            rewriter);
     auto rhss = this->getElementsFromStruct(loc, concreteThis->getRhs(adaptor),
-                                            elems, rewriter);
+                                            rewriter);
     SmallVector<Value> resultVals(elems);
     for (unsigned i = 0; i < elems; ++i) {
       resultVals[i] = concreteThis->createDestOp(op, rewriter, elemTy, lhss[i],
@@ -1761,7 +1759,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
   // Potentially we need to store for multiple CTAs in this replication
   unsigned accumNumReplicates = product<unsigned>(numReplicates);
   unsigned elems = getElemsPerThread(srcLayout, srcTy.getShape());
-  auto vals = getElementsFromStruct(loc, adaptor.src(), elems, rewriter);
+  auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
   unsigned inVec = 0;
   unsigned outVec = 0;
   auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);
@@ -1824,7 +1822,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
   unsigned perPhase = dstSharedLayout.getPerPhase();
   unsigned maxPhase = dstSharedLayout.getMaxPhase();
   unsigned numElems = getElemsPerThread(srcBlockedLayout, srcShape);
-  auto inVals = getElementsFromStruct(loc, adaptor.src(), numElems, rewriter);
+  auto inVals = getElementsFromStruct(loc, adaptor.src(), rewriter);
   unsigned srcAccumSizeInThreads =
       product<unsigned>(srcBlockedLayout.getSizePerThread());
   auto elemTy = srcTy.getElementType();
@@ -2661,75 +2659,205 @@ private:
   DotOp dot;
 };
 
-LogicalResult
-DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
-                                 ConversionPatternRewriter &rewriter) const {
-  Location loc = op->getLoc();
-  MLIRContext *ctx = op->getContext();
-  // D = A * B + C
-  Value A = op.a();
-  Value B = op.b();
-  Value C = op.c();
-  Value D = op.getResult();
-  bool allowTF32 = op.allowTF32();
+// This class helps to adapt the existing DotOpConversion to the latest
+// DotOpOperand layout design. It decouples the exising implementation to two
+// parts:
+// 1. loading the specific operand matrix(for $a, $b, $c) from smem
+// 2. passing the loaded value and perform the mma codegen
+struct MMA16816ConversionHelper {
+  Value A, B, C, D;
+  RankedTensorType aTensorTy, bTensorTy, dTensorTy;
+  ArrayRef<int64_t> aShape, bShape, dShape;
+  MmaEncodingAttr mmaLayout;
+  ArrayRef<unsigned int> wpt;
 
-  auto aTensorTy = A.getType().cast<RankedTensorType>();
-  auto bTensorTy = B.getType().cast<RankedTensorType>();
-  auto dTensorTy = D.getType().cast<RankedTensorType>();
+  int mmaInstrM{-1}, mmaInstrN{-1}, mmaInstrK{-1};
+  int matShapeM{-1}, matShapeN{-1}, matShapeK{-1};
+  int numRepM{-1}, numRepN{-1}, numRepK{-1};
+  Value thread, lane, warp, warpMN, warpN, warpM;
+  size_t aElemBytes{}, bElemBytes{};
 
-  auto aShape = aTensorTy.getShape();
-  auto bShape = bTensorTy.getShape();
-  auto dShape = dTensorTy.getShape();
+  DotOpConversionHelper helper;
+  triton::DotOp op;
+  DotOpAdaptor adapter;
+  ConversionPatternRewriter &rewriter;
+  TypeConverter *typeConverter;
+  Location loc;
+  MLIRContext *ctx{};
 
-  auto mmaLayout = dTensorTy.getEncoding().cast<MmaEncodingAttr>();
+  using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;
 
-  auto wpt = mmaLayout.getWarpsPerCTA();
+  MMA16816ConversionHelper(triton::DotOp op, Value thread, DotOpAdaptor adapter,
+                           ConversionPatternRewriter &rewriter,
+                           TypeConverter *typeConverter, Location loc)
+      : helper(op), op(op), adapter(adapter), rewriter(rewriter),
+        typeConverter(typeConverter), loc(loc), ctx(op.getContext()),
+        thread(thread) {
+    A = op.a();
+    B = op.b();
+    C = op.c();
+    D = op.getResult();
+
+    // TODO(Superjomn) Process C->is_trans_a() logic
+    aTensorTy = A.getType().cast<RankedTensorType>();
+    bTensorTy = B.getType().cast<RankedTensorType>();
+    dTensorTy = D.getType().cast<RankedTensorType>();
 
-  DotOpConversionHelper helper(op);
+    aShape = aTensorTy.getShape();
+    bShape = bTensorTy.getShape();
+    dShape = dTensorTy.getShape();
 
-  int NK = aShape[1];
+    mmaLayout = dTensorTy.getEncoding().cast<MmaEncodingAttr>();
 
-  auto mmaInstrShape = helper.getMmaInstrShape();
-  const int mmaInstrM = mmaInstrShape[0];
-  const int mmaInstrN = mmaInstrShape[1];
-  const int mmaInstrK = mmaInstrShape[2];
+    wpt = mmaLayout.getWarpsPerCTA();
 
-  auto matShape = helper.getMmaMatShape();
-  const int matShapeM = matShape[0];
-  const int matShapeN = matShape[1];
-  const int matShapeK = matShape[2];
+    auto mmaInstrShape = helper.getMmaInstrShape();
+    mmaInstrM = mmaInstrShape[0];
+    mmaInstrN = mmaInstrShape[1];
+    mmaInstrK = mmaInstrShape[2];
 
-  // shape / shape_per_cta
-  const int numRepM = std::max<int>(dShape[0] / (wpt[0] * mmaInstrM), 1);
-  const int numRepN = std::max<int>(dShape[1] / (wpt[1] * mmaInstrN), 1);
-  const int numRepK = std::max<int>(NK / mmaInstrK, 1);
+    auto matShape = helper.getMmaMatShape();
+    matShapeM = matShape[0];
+    matShapeN = matShape[1];
+    matShapeK = matShape[2];
 
-  Value _32 = i32_val(32);
-  Value thread = getThreadId(rewriter, loc);
-  Value lane = urem(thread, _32);
-  Value warp = udiv(thread, _32);
-  Value warpMN = udiv(warp, i32_val(wpt[0]));
-  Value warpM = urem(warp, i32_val(wpt[0]));
-  Value warpN = urem(warpMN, i32_val(wpt[1]));
+    int NK = aShape[1];
+    // shape / shape_per_cta
+    numRepM = std::max<int>(dShape[0] / (wpt[0] * mmaInstrM), 1);
+    numRepN = std::max<int>(dShape[1] / (wpt[1] * mmaInstrN), 1);
+    numRepK = std::max<int>(NK / mmaInstrK, 1);
 
-  size_t aElemBytes = aTensorTy.getElementTypeBitWidth() / 8;
-  size_t bElemBytes = bTensorTy.getElementTypeBitWidth() / 8;
+    Value _32 = i32_val(32);
+    lane = urem(thread, _32);
+    warp = udiv(thread, _32);
+    warpMN = udiv(warp, i32_val(wpt[0]));
+    warpM = urem(warp, i32_val(wpt[0]));
+    warpN = urem(warpMN, i32_val(wpt[1]));
 
-  std::map<std::pair<unsigned, unsigned>, Value> ha;
-  std::map<std::pair<unsigned, unsigned>, Value> hb;
+    aElemBytes = aTensorTy.getElementTypeBitWidth() / 8;
+    bElemBytes = bTensorTy.getElementTypeBitWidth() / 8;
+  }
 
-  // the original register_lds2, but discard the prefetch logic.
-  auto ld2 = [](decltype(ha) &vals, int mn, int k, Value val) {
-    vals[{mn, k}] = val;
-  };
+  // Loading $a from smem to registers, returns a LLVM::Struct.
+  Value loadA() {
+    ValueTable ha;
+    std::function<void(int, int)> loadFn;
+    if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
+      // load from smem
+      loadFn = getLoadMatrixFn(
+          A, adapter.a() /*llTensor*/, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
+          1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/,
+          {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
+    } else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
+      // load from registers, used in gemm fuse
+      // TODO(Superjomn) Port the logic.
+      assert(false && "Loading A from register is not supported yet.");
+    } else {
+      assert(false && "A's layout is not supported.");
+    }
+
+    // step1. Perform loading.
+    for (unsigned m = 0; m < numRepM; ++m)
+      for (unsigned k = 0; k < numRepK; ++k)
+        loadFn(2 * m, 2 * k);
+
+    // step2. Format the values to LLVM::Struct to passing to mma codegen.
+    Value result = composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK);
+
+    // TODO[Superjomn]: Replace the convert_layout op with the result once the
+    // DotOperandEncodingAttr is ready.
+    return result;
+  }
+
+  // Loading $b from smem to registers, returns a LLVM::Struct.
+  Value loadB() {
+    ValueTable hb;
+    auto loadFn = getLoadMatrixFn(
+        B, adapter.b() /*llTensor*/, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
+        0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/,
+        {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
+
+    for (unsigned n = 0; n < std::max(numRepN / 2, 1); ++n) {
+      for (unsigned k = 0; k < numRepK; ++k)
+        loadFn(2 * n, 2 * k);
+    }
+
+    Value result = composeValuesToDotOperandLayoutStruct(
+        hb, std::max(numRepN / 2, 1), numRepK);
+    return result;
+  }
+
+  // Loading $c from smem(?) to registers, returns a Value.
+  // NOTE Only SplatLike tensor is supported now.
+  Value loadC() {
+    // Currently, we only support a SplatLike C. For the other cases, e.g., C in
+    // shared layout or blocked layout, we will support them by expanding
+    // convert_layout.
+    auto hc = helper.loadSplatLikeC(C, loc, rewriter);
+    assert(hc.size() == 4UL && "Only splat-like C is supported now");
+    return hc[0];
+  }
+
+  // Conduct the Dot conversion.
+  // Input the \param a, \param b, \param c, all of them are result of loading.
+  LogicalResult convertDot(Value a, Value b, Value c) {
+    ValueTable ha = getValuesFromDotOperandLayoutStruct(a, numRepM, numRepK);
+    ValueTable hb = getValuesFromDotOperandLayoutStruct(
+        b, std::max(numRepN / 2, 1), numRepK);
+
+    const int fcSize = 4 * numRepM * numRepN;
+    SmallVector<Value> fc(fcSize, c);
+
+    auto callMma = [&](unsigned m, unsigned n, unsigned k) {
+      unsigned colsPerThread = numRepN * 2;
+      PTXBuilder builder;
+      auto &mma = *builder.create(helper.getMmaInstr().str());
+      auto retArgs = builder.newListOperand(4, "=r");
+      auto aArgs = builder.newListOperand({
+          {ha[{m, k}], "r"},
+          {ha[{m + 1, k}], "r"},
+          {ha[{m, k + 1}], "r"},
+          {ha[{m + 1, k + 1}], "r"},
+      });
+      auto bArgs =
+          builder.newListOperand({{hb[{n, k}], "r"}, {hb[{n, k + 1}], "r"}});
+      auto cArgs = builder.newListOperand();
+      for (int i = 0; i < 4; ++i) {
+        cArgs->listAppend(builder.newOperand(fc[m * colsPerThread + 4 * n + i],
+                                             std::to_string(i)));
+        // reuse the output registers
+      }
+      mma(retArgs, aArgs, bArgs, cArgs);
+      Value mmaOut = builder.launch(rewriter, loc, helper.getMmaRetType());
+
+      auto getIntAttr = [&](int v) {
+        return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
+      };
+
+      for (int i = 0; i < 4; i++)
+        fc[m * colsPerThread + 4 * n + i] =
+            extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(i));
+    };
+
+    for (unsigned k = 0; k < numRepK; ++k)
+      for (unsigned m = 0; m < numRepM; ++m)
+        for (unsigned n = 0; n < numRepN; ++n)
+          callMma(2 * m, n, 2 * k);
+
+    // replace with new packed result
+    Type structTy = LLVM::LLVMStructType::getLiteral(
+        ctx, SmallVector<Type>(fc.size(), type::f32Ty(ctx)));
+    Value res = getStructFromElements(loc, fc, rewriter, structTy);
+    rewriter.replaceOp(op, res);
+
+    return success();
+  }
 
-  // Load A or B matrix.
-  auto getLoadMatrixFn =
-      [&](Value tensor, Value llTensor, int wpt, int kOrder,
-          ArrayRef<int> instrShape, ArrayRef<int> matShape, Value warpId,
-          decltype(ha) &vals) -> std::function<void(int, int)> {
+private:
+  std::function<void(int, int)>
+  getLoadMatrixFn(Value tensor, Value llTensor, int wpt, int kOrder,
+                  ArrayRef<int> instrShape, ArrayRef<int> matShape,
+                  Value warpId, ValueTable &vals) {
     auto tensorTy = tensor.getType().cast<RankedTensorType>();
     // We assumes that the input operand of Dot should be from shared layout.
     // TODO(Superjomn) Consider other layouts if needed later.
@@ -2739,25 +2867,31 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
     const int elemBytes = tensorTy.getElementTypeBitWidth() / 8;
     auto order = sharedLayout.getOrder();
 
+    MMA16816SmemLoader loader(wpt, sharedLayout.getOrder(), kOrder,
+                              tensorTy.getShape() /*tileShape*/, instrShape,
+                              matShape, perPhase, maxPhase, elemBytes, rewriter,
+                              typeConverter, loc);
+    SmallVector<Value> offs = loader.computeOffsets(warpId, lane);
+
+    const int numPtrs = loader.getNumPtr();
+    SmallVector<Value> ptrs(numPtrs);
+
+    Type smemPtrTy = helper.getShemPtrTy();
+    for (int i = 0; i < numPtrs; ++i) {
+      ptrs[i] =
+          bitcast(smemPtrTy, gep(smemPtrTy, llTensor, ValueRange({offs[i]})));
+    }
+
     bool needTrans = kOrder != order[0];
 
+    // the original register_lds2, but discard the prefetch logic.
+    auto ld2 = [](ValueTable &vals, int mn, int k, Value val) {
+      vals[{mn, k}] = val;
+    };
+
     // (a, b) is the coordinate.
-    auto load = [=, &vals, &helper, &ld2](int a, int b) {
-      MMA16816SmemLoader loader(wpt, sharedLayout.getOrder(), kOrder,
-                                tensorTy.getShape() /*tileShape*/, instrShape,
-                                matShape, perPhase, maxPhase, elemBytes,
-                                rewriter, typeConverter, loc);
-      SmallVector<Value> offs = loader.computeOffsets(warpId, lane);
-
-      const int numPtrs = loader.getNumPtr();
-
-      SmallVector<Value> ptrs(numPtrs);
-
-      Type smemPtrTy = helper.getShemPtrTy();
-      for (int i = 0; i < numPtrs; ++i) {
-        ptrs[i] =
-            bitcast(smemPtrTy, gep(smemPtrTy, llTensor, ValueRange({offs[i]})));
-      }
-
+    auto load = [=, &vals, &ld2](int a, int b) {
       auto [ha0, ha1, ha2, ha3] = loader.loadX4(
           (kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? b : a /*mat1*/, offs,
           ptrs, helper.getMatType(), helper.getShemPtrTy());
@@ -2775,89 +2909,74 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
     };
 
     return load;
-  };
+  }
 
-  std::function<void(int, int)> loadA;
-  if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
-    // load from smem
-    loadA = getLoadMatrixFn(
-        A, adapter.a() /*llTensor*/, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
-        1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/,
-        {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
-  } else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
-    // load from registers, used in gemm fuse
-    // TODO(Superjomn) Port the logic.
-    assert(false && "Loading A from register is not supported yet.");
-  } else {
-    assert(false && "A's layout is not supported.");
-  }
-
-  std::function<void(int, int)> loadB = getLoadMatrixFn(
-      B, adapter.b() /*llTensor*/, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
-      0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/,
-      {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
-
-  const int fcSize = 4 * numRepM * numRepN;
-  SmallVector<Value> fc(fcSize);
-
-  // Currently, we only support a SplatLike C. For the other cases, e.g., C in
-  // shared layout or blocked layout, we will support them by expanding
-  // convert_layout.
-  auto hc = helper.loadSplatLikeC(C, loc, rewriter);
-  assert(hc.size() == 4UL && "Only splat-like C is supported now");
-  for (int i = 0; i < fc.size(); i++)
-    fc[i] = hc[0];
-
-  auto callMma = [&](unsigned m, unsigned n, unsigned k) {
-    unsigned colsPerThread = numRepN * 2;
-    PTXBuilder builder;
-    auto &mma = *builder.create(helper.getMmaInstr().str());
-    auto retArgs = builder.newListOperand(4, "=r");
-    auto aArgs = builder.newListOperand({
-        {ha[{m, k}], "r"},
-        {ha[{m + 1, k}], "r"},
-        {ha[{m, k + 1}], "r"},
-        {ha[{m + 1, k + 1}], "r"},
-    });
-    auto bArgs =
-        builder.newListOperand({{hb[{n, k}], "r"}, {hb[{n, k + 1}], "r"}});
-    auto cArgs = builder.newListOperand();
-    for (int i = 0; i < 4; ++i) {
-      cArgs->listAppend(builder.newOperand(fc[m * colsPerThread + 4 * n + i],
-                                           std::to_string(i)));
-      // reuse the output registers
-    }
-    mma(retArgs, aArgs, bArgs, cArgs);
-    Value mmaOut = builder.launch(rewriter, loc, helper.getMmaRetType());
-
-    auto getIntAttr = [&](int v) {
-      return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
-    };
-
-    for (int i = 0; i < 4; i++)
-      fc[m * colsPerThread + 4 * n + i] =
-          extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(i));
-  };
-
-  // Main program
-  for (unsigned k = 0; k < numRepK; ++k) {
-    for (unsigned m = 0; m < numRepM; ++m)
-      loadA(2 * m, 2 * k);
-    for (unsigned n = 0; n < numRepN; n += 2)
-      loadB(n, 2 * k);
-    for (unsigned m = 0; m < numRepM; ++m)
-      for (unsigned n = 0; n < numRepN; ++n) {
-        callMma(2 * m, n, 2 * k);
-      }
-  }
+  // Compose a map of Values to a LLVM::Struct.
+  // The layout is a list of Value with coordinate of (i,j), the order is as
+  // the follows:
+  // [
+  //  (0,0), (0,1), (1,0), (1,1), # i=0, j=0
+  //  (0,2), (0,3), (1,2), (1,3), # i=0, j=1
+  //  (0,4), (0,5), (1,4), (1,5), # i=0, j=2
+  //  ...
+  //  (2,0), (2,1), (3,0), (3,1), # i=1, j=0
+  //  (2,2), (2,3), (3,2), (3,3), # i=1, j=1
+  //  (2,4), (2,5), (2,4), (2,5), # i=1, j=2
+  //  ...
+  // ]
+  // i \in [0, n0) and j \in [0, n1)
+  // There should be \param n0 * \param n1 elements in the output Struct.
+  Value composeValuesToDotOperandLayoutStruct(const ValueTable &vals, int n0,
+                                              int n1) {
+    std::vector<Value> elems;
+    for (unsigned m = 0; m < n0; ++m)
+      for (unsigned k = 0; k < n1; ++k) {
+        elems.push_back(vals.at({2 * m, 2 * k}));
+        elems.push_back(vals.at({2 * m, 2 * k + 1}));
+        elems.push_back(vals.at({2 * m + 1, 2 * k}));
+        elems.push_back(vals.at({2 * m + 1, 2 * k + 1}));
+      }
+
+    assert(!elems.empty());
+
+    Type fp16Ty = aTensorTy.getElementType();
+    Type fp16x2Ty = vec_ty(fp16Ty, 2);
+    Type structTy = LLVM::LLVMStructType::getLiteral(
+        ctx, SmallVector<Type>(elems.size(), fp16x2Ty));
+    auto result = getStructFromElements(loc, elems, rewriter, structTy);
+    return result;
+  }
 
-  // replace with new packed result
-  Type structTy = LLVM::LLVMStructType::getLiteral(
-      ctx, SmallVector<Type>(fc.size(), type::f32Ty(ctx)));
-  Value res = getStructFromElements(loc, fc, rewriter, structTy);
-  rewriter.replaceOp(op, res);
-
-  return success();
+  ValueTable getValuesFromDotOperandLayoutStruct(Value value, int n0, int n1) {
+    auto elems = ConvertTritonGPUOpToLLVMPatternBase::getElementsFromStruct(
+        loc, value, rewriter);
+
+    int offset{};
+    ValueTable vals;
+    for (int i = 0; i < n0; i++) {
+      for (int j = 0; j < n1; j++) {
+        vals[{2 * i, 2 * j}] = elems[offset++];
+        vals[{2 * i, 2 * j + 1}] = elems[offset++];
+        vals[{2 * i + 1, 2 * j}] = elems[offset++];
+        vals[{2 * i + 1, 2 * j + 1}] = elems[offset++];
+      }
+    }
+    return vals;
+  }
+};
+
+LogicalResult
+DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
+                                 ConversionPatternRewriter &rewriter) const {
+  auto loc = op.getLoc();
+  MMA16816ConversionHelper mmaHelper(op, getThreadId(rewriter, loc), adapter,
+                                     rewriter, getTypeConverter(), loc);
+
+  auto A = mmaHelper.loadA();
+  auto B = mmaHelper.loadB();
+  auto C = mmaHelper.loadC();
+
+  return mmaHelper.convertDot(A, B, C);
 }
 
 /// ====================== mma codegen end ============================
@@ -3012,9 +3131,9 @@ struct InsertSliceAsyncOpConversion
 
     auto inOrder = srcBlockedLayout.getOrder();
     auto outOrder = resSharedLayout.getOrder();
-    // If perPhase * maxPhase > threadsPerCTA, we need to swizzle over elements
-    // across phases.
-    // If perPhase * maxPhase == threadsPerCTA, swizzle is not allowd
+    // If perPhase * maxPhase > threadsPerCTA, we need to swizzle over
+    // elements across phases. If perPhase * maxPhase == threadsPerCTA,
+    // swizzle is not allowd
     auto numSwizzleRows = std::max<unsigned>(
         (perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1);
     // A sharedLayout encoding has a "vec" parameter.