[Triton-MLIR][BACKEND] Port the mma<v1> conversion (#815)

This PR does the following:

- port the mma<v1>-related code, supporting dot conversion and
convert_layout[shared->dot_op<mma<v1>>]
- add a lit test for dot v1
Author: Yan Chunwei
Date: 2022-11-01 09:42:14 +08:00
Committed by: GitHub
Parent: cb1b87a688
Commit: 031c2ae77b
3 changed files with 690 additions and 92 deletions


@@ -56,6 +56,12 @@ Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) {
IntegerAttr::get(i32ty, v));
}
Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) {
auto type = type::f32Ty(rewriter.getContext());
return rewriter.create<LLVM::ConstantOp>(loc, type,
rewriter.getF32FloatAttr(v));
}
// Create an index-type constant.
Value createIndexConstant(OpBuilder &builder, Location loc,
@@ -90,7 +96,8 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
#define fmin(...) rewriter.create<LLVM::MinNumOp>(loc, __VA_ARGS__)
#define and_(...) rewriter.create<LLVM::AndOp>(loc, __VA_ARGS__)
#define xor_(...) rewriter.create<LLVM::XOrOp>(loc, __VA_ARGS__)
#define bitcast(...) rewriter.create<LLVM::BitcastOp>(loc, __VA_ARGS__)
#define bitcast(val__, type__) \
rewriter.create<LLVM::BitcastOp>(loc, type__, val__)
#define gep(...) rewriter.create<LLVM::GEPOp>(loc, __VA_ARGS__)
#define ptr_ty(...) LLVM::LLVMPointerType::get(__VA_ARGS__)
#define insert_val(...) rewriter.create<LLVM::InsertValueOp>(loc, __VA_ARGS__)
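// Call-site sketch for the reworked bitcast macro (illustrative only; `val`
// and `vecTy` are assumed to be in scope):
//   before: Value vec = bitcast(vecTy, val);  // destination type first
//   after:  Value vec = bitcast(val, vecTy);  // value first, then target type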
@@ -112,11 +119,13 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
#define i32_ty rewriter.getIntegerType(32)
#define f16_ty rewriter.getF16Type()
#define i8_ty rewriter.getIntegerType(8)
#define f32_ty rewriter.getF32Type()
#define vec_ty(type, num) VectorType::get(num, type)
#define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__)
#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
#define struct_ty(...) LLVM::LLVMStructType::getLiteral(__VA_ARGS__)
#define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__)
// Creator for constant
#define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__)
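// Note that struct_ty now captures a local `ctx` implicitly, so every call
// site must have one in scope (hence the new local `ctx` introduced in
// lowerDistributedToDistributed below). A minimal sketch, with hypothetical
// `n` and `llvmElemTy`:
//   auto *ctx = llvmElemTy.getContext(); // required by the macro expansion
//   Type structTy = struct_ty(SmallVector<Type>(n, llvmElemTy));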
@@ -685,7 +694,7 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
if (tensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
auto tensorTy = resType.cast<RankedTensorType>();
auto srcType = typeConverter->convertType(elemType);
auto llSrc = bitcast(srcType, constVal);
auto llSrc = bitcast(constVal, srcType);
size_t elemsPerThread = getElemsPerThread(tensorTy);
llvm::SmallVector<Value> elems(elemsPerThread, llSrc);
llvm::SmallVector<Type> elemTypes(elems.size(), srcType);
@@ -964,7 +973,7 @@ struct LoadOpConversion
rewriter, loc, this->getTypeConverter()->getIndexType(), s);
v = insert_element(vecTy, v, falseVal, sVal);
}
v = bitcast(IntegerType::get(getContext(), width), v);
v = bitcast(v, IntegerType::get(getContext(), width));
PTXInstr::Operand *opr{};
if (otherIsSplatConstInt)
@@ -1001,9 +1010,8 @@ struct LoadOpConversion
} else {
curr = ret;
}
curr = bitcast(
LLVM::getFixedVectorType(valueElemTy, width / valueElemNbits),
curr);
curr = bitcast(curr, LLVM::getFixedVectorType(valueElemTy,
width / valueElemNbits));
rets.push_back(curr);
}
int tmp = width / valueElemNbits;
@@ -1101,7 +1109,7 @@ struct StoreOpConversion
Value elem = valueElems[elemOffset];
if (elem.getType().isInteger(1))
elem = rewriter.create<LLVM::SExtOp>(loc, type::i8Ty(ctx), elem);
elem = bitcast(valueElemTy, elem);
elem = bitcast(elem, valueElemTy);
Type u32Ty = typeConverter->convertType(type::u32Ty(ctx));
llWord =
@@ -1109,7 +1117,7 @@ struct StoreOpConversion
rewriter.create<LLVM::ConstantOp>(
loc, u32Ty, IntegerAttr::get(u32Ty, elemIdx)));
}
llWord = bitcast(valArgTy, llWord);
llWord = bitcast(llWord, valArgTy);
std::string constraint =
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
asmArgs.emplace_back(llWord, constraint);
@@ -1328,7 +1336,7 @@ Value ReduceOpConversion::shflSync(ConversionPatternRewriter &rewriter,
if (bits == 64) {
Type vecTy = vec_ty(f32_ty, 2);
Value vec = bitcast(vecTy, val);
Value vec = bitcast(val, vecTy);
Value val0 = extract_element(f32_ty, vec, i32_val(0));
Value val1 = extract_element(f32_ty, vec, i32_val(1));
val0 = shflSync(rewriter, loc, val0, i);
@@ -1336,7 +1344,7 @@ Value ReduceOpConversion::shflSync(ConversionPatternRewriter &rewriter,
vec = undef(vecTy);
vec = insert_element(vecTy, vec, val0, i32_val(0));
vec = insert_element(vecTy, vec, val1, i32_val(1));
return bitcast(val.getType(), vec);
return bitcast(vec, val.getType());
}
PTXBuilder builder;
@@ -1363,7 +1371,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
smemBase = bitcast(elemPtrTy, smemBase);
smemBase = bitcast(smemBase, elemPtrTy);
auto smemShape = getScratchConfigForReduce(op);
@@ -1430,7 +1438,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
barrier();
SmallVector<Value> resultVals(resultElems);
for (size_t i = 0; i < resultElems; i++) {
for (unsigned i = 0; i < resultElems; ++i) {
SmallVector<Value> readIdx = resultIndices[i];
readIdx.insert(readIdx.begin() + axis, ints[0]);
Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
@@ -1469,7 +1477,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
smemBase = bitcast(elemPtrTy, smemBase);
smemBase = bitcast(smemBase, elemPtrTy);
auto order = srcLayout.getOrder();
unsigned sizeIntraWarps = threadsPerWarp[axis];
@@ -1569,7 +1577,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
barrier();
SmallVector<Value> resultVals(resultElems);
for (size_t i = 0; i < resultElems; i++) {
for (size_t i = 0; i < resultElems; ++i) {
SmallVector<Value> readIdx = resultIndices[i];
readIdx.insert(readIdx.begin() + axis, i32_val(0));
Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
@@ -2136,7 +2144,7 @@ void ConvertLayoutOpConversion::processReplica(
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
Value ptr = gep(elemPtrTy, smemBase, offset);
auto vecTy = vec_ty(llvmElemTy, vec);
ptr = bitcast(ptr_ty(vecTy, 3), ptr);
ptr = bitcast(ptr, ptr_ty(vecTy, 3));
if (stNotRd) {
Value valVec = undef(vecTy);
for (unsigned v = 0; v < vec; ++v) {
@@ -2175,7 +2183,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
smemBase = bitcast(elemPtrTy, smemBase);
smemBase = bitcast(smemBase, elemPtrTy);
auto shape = dstTy.getShape();
unsigned rank = dstTy.getRank();
SmallVector<unsigned> numReplicates(rank);
@@ -2234,7 +2242,8 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
}
SmallVector<Type> types(outElems, llvmElemTy);
Type structTy = struct_ty(getContext(), types);
auto *ctx = llvmElemTy.getContext();
Type structTy = struct_ty(types);
Value result = getStructFromElements(loc, outVals, rewriter, structTy);
rewriter.replaceOp(op, result);
@@ -2294,7 +2303,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
Value minVecVal = idx_val(minVec);
Value smemBase = getSharedMemoryBase(loc, rewriter, dst);
auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
smemBase = bitcast(elemPtrTy, smemBase);
smemBase = bitcast(smemBase, elemPtrTy);
unsigned numWordsEachRep = product<unsigned>(wordsInEachRep);
SmallVector<Value> wordVecs(numWordsEachRep);
// TODO: We should get less barriers if it is handled by membar pass
@@ -2350,7 +2359,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
// step 3: store
Value smemAddr = gep(elemPtrTy, smemBase, offset);
smemAddr = bitcast(ptr_ty(wordTy, 3), smemAddr);
smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3));
store(wordVecs[linearWordIdx], smemAddr);
}
}
@@ -2693,7 +2702,7 @@ public:
for (int e = 0; e < 4; ++e)
i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m],
i8Elems[m][e], i32_val(e));
i32Elems[m] = bitcast(i32_ty, i8v4Elems[m]);
i32Elems[m] = bitcast(i8v4Elems[m], i32_ty);
}
} else { // k first
Value offset = i32_val(sOffsetElem);
@@ -2711,7 +2720,7 @@ public:
for (int e = 0; e < 4; ++e)
i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m],
i8Elems[m][e], i32_val(e));
i32Elems[m] = bitcast(i32_ty, i8v4Elems[m]);
i32Elems[m] = bitcast(i8v4Elems[m], i32_ty);
}
}
@@ -2823,10 +2832,7 @@ private:
ConversionPatternRewriter &rewriter) const;
/// Convert to mma.m8n8k4
LogicalResult convertMMA884(triton::DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
assert(false && "Not implemented yet.");
return failure();
}
ConversionPatternRewriter &rewriter) const;
LogicalResult convertFMADot(triton::DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -2835,48 +2841,127 @@ private:
}
};
struct DotOpConversionHelper {
// Helper for the conversion of DotOp with mma<version=1>, i.e. sm < 80
struct DotOpMmaV1ConversionHelper {
MmaEncodingAttr mmaLayout;
ArrayRef<unsigned> wpt;
using ValueTable = std::map<std::pair<int, int>, std::pair<Value, Value>>;
explicit DotOpMmaV1ConversionHelper(MmaEncodingAttr mmaLayout)
: mmaLayout(mmaLayout), wpt(mmaLayout.getWarpsPerCTA()) {}
int getRepM(int M) const {
return std::max<int>(M / (wpt[0] * instrShape[0]), 1);
}
int getRepN(int N) const {
return std::max<int>(N / (wpt[1] * instrShape[1]), 1);
}
int getRepK(int K) const { return std::max<int>(K / instrShape[2], 1); }
static ArrayRef<unsigned> getMmaInstrShape() { return instrShape; }
static Type getMmaRetType(TensorType operand) {
auto *ctx = operand.getContext();
Type fp32Ty = type::f32Ty(ctx);
// f16*f16+f32->f32
return struct_ty(SmallVector<Type>{8, fp32Ty});
}
// number of fp16x2 elements for $a.
int numElemsPerThreadA(RankedTensorType tensorTy) const {
auto shape = tensorTy.getShape();
auto order = getOrder();
bool isARow = order[0] != 0;
bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes
int packSize0 = (isARow || isAVec4) ? 1 : 2;
SmallVector<int> fpw({2, 2, 1});
int repM = 2 * packSize0;
int repK = 1;
int spwM = fpw[0] * 4 * repM;
SmallVector<int> rep({repM, 0, repK}); // pad N with 0
SmallVector<int> spw({spwM, 0, 1}); // pad N with 0
int NK = shape[1];
unsigned numM = rep[0] * shape[0] / (spw[0] * wpt[0]);
// NOTE: we couldn't get the vec from the shared layout here.
// int vecA = sharedLayout.getVec();
// TODO[Superjomn]: consider the case when vecA > 4
bool vecGt4 = false;
int elemsPerLd = vecGt4 ? 4 : 2;
return (numM / 2) * (NK / 4) * elemsPerLd;
}
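// Hand-worked check against the lit test at the end of this commit (128x32
// fp16 $a, wpt = [2, 2]; the numbers are illustrative, not asserted):
// getOrder() returns the fixed mmaOrder {0, 1}, so isARow = false and
// packSize0 = 2; then rep[0] = 4, spw[0] = 2 * 4 * 4 = 32,
// numM = 4 * 128 / (32 * 2) = 8, NK = 32, and the result is
// (8 / 2) * (32 / 4) * 2 = 64 fp16x2 values per thread.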
// number of fp16x2 elements for $b.
int numElemsPerThreadB(RankedTensorType tensorTy) const {
auto shape = tensorTy.getShape();
auto order = getOrder();
bool isBRow = order[0] != 0;
bool isBVec4 = isBRow && shape[order[0]] <= 16;
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
SmallVector<int> fpw({2, 2, 1});
SmallVector<int> rep({0, 2 * packSize1, 1}); // pad M with 0
SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
// NOTE: we couldn't get the vec from the shared layout here.
// int vecB = sharedLayout.getVec();
// TODO[Superjomn]: consider the case when vecB > 4
bool vecGt4 = false;
int elemsPerLd = vecGt4 ? 4 : 2;
int NK = shape[0];
unsigned numN = rep[1] * shape[1] / (spw[1] * wpt[0]);
return (numN / 2) * (NK / 4) * elemsPerLd;
}
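// The same hand-worked check for the test's 32x256 fp16 $b (illustrative):
// under the fixed mmaOrder {0, 1}, isBRow = false, so packSize1 = 1,
// rep[1] = 2, spw[1] = 2 * 4 * 2 = 16, numN = 2 * 256 / (16 * 2) = 16,
// NK = 32, giving (16 / 2) * (32 / 4) * 2 = 128 fp16x2 values per thread.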
// Loading $a from smem to registers, returns a LLVM::Struct.
Value loadA(Value A, Value llA, Value thread, Value smem, Location loc,
ConversionPatternRewriter &rewriter) const;
// Loading $b from smem to registers, returns a LLVM::Struct.
Value loadB(Value B, Value llB, Value thread, Value smem, Location loc,
ConversionPatternRewriter &rewriter) const;
// Loading $c to registers, returns a LLVM::Struct.
Value loadC(Value C, Value llC, ConversionPatternRewriter &rewriter) const;
static ArrayRef<unsigned> getOrder() { return mmaOrder; }
// Compute the offset of the matrix to load.
// Returns offsetAM, offsetAK, offsetBN, offsetBK.
// NOTE: in the convert_layout[shared->dot_op] usage, the M (from $a) and N
// (from $b) information cannot be retrieved at the same time, so we set the
// missing values to 0 and read only the desired fields from the composed
// result. This retains the original code structure of the convert_mma884
// method for easier debugging.
std::tuple<Value, Value, Value, Value>
computeOffsets(Value threadId, bool isARow, bool isBRow, ArrayRef<int> fpw,
ArrayRef<int> spw, ArrayRef<int> rep,
ConversionPatternRewriter &rewriter, Location loc) const;
// Extract the values belonging to $a or $b from an LLVM struct; the shape is n0 x n1.
ValueTable extractLoadedOperand(Value llStruct, int n0, int n1,
ConversionPatternRewriter &rewriter) const;
private:
static constexpr unsigned instrShape[] = {16, 16, 4};
static constexpr unsigned mmaOrder[] = {0, 1};
};
// Helper for the conversion of DotOp with mma<version=2>, i.e. sm >= 80
struct DotOpMmaV2ConversionHelper {
using TensorCoreType = DotOpConversion::TensorCoreType;
MmaEncodingAttr mmaLayout;
MLIRContext *ctx{};
explicit DotOpConversionHelper(MmaEncodingAttr mmaLayout)
explicit DotOpMmaV2ConversionHelper(MmaEncodingAttr mmaLayout)
: mmaLayout(mmaLayout) {
ctx = mmaLayout.getContext();
}
// Load SplatLike C which contains a constVal. It simply returns 4 fp32
// constVal.
SmallVector<Value> loadSplatLikeC(Value C, Location loc,
ConversionPatternRewriter &rewriter) const {
assert(isSplatLike(C));
int numRes = getMmaInstrShape()[0] * getMmaInstrShape()[1] / 32;
if (auto constv = llvm::dyn_cast<arith::ConstantOp>(C.getDefiningOp())) {
if (auto attr = constv.getValue().dyn_cast<SplatElementsAttr>()) {
Type elemType = attr.getElementType();
if (elemType.isInteger(32)) {
int v = attr.getSplatValue<int>();
return SmallVector<Value>(numRes, i32_val(v));
} else if (elemType.isInteger(8)) {
int v = attr.getSplatValue<int8_t>();
auto newv = rewriter.create<arith::ConstantOp>(
loc, elemType, IntegerAttr::get(elemType, v));
return SmallVector<Value>(numRes, newv);
} else if (elemType.isF32()) {
float v = attr.getSplatValue<float>();
auto newv = rewriter.create<arith::ConstantOp>(
loc, elemType, FloatAttr::get(elemType, v));
return SmallVector<Value>(numRes, newv);
}
}
}
assert(false && "Not supported type.");
return {};
}
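// Arithmetic sketch: with the m16n8k16 instruction shape (assumed here; the
// actual shape depends on the deduced mmaType), numRes = 16 * 8 / 32 = 4,
// i.e. four accumulator values per thread.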
void deduceMmaType(DotOp op) const { mmaType = getMmaType(op); }
void deduceMmaType(Type operandTy) const {
mmaType = getTensorCoreTypeFromOperand(operandTy);
@@ -2884,8 +2969,8 @@ struct DotOpConversionHelper {
// Get the M and N of mat instruction shape.
static std::tuple<int, int> getMatShapeMN() {
// According to DotOpConversionHelper::mmaMatShape, all the matrix shape's
// M,N are {8,8}
// According to DotOpMmaV2ConversionHelper::mmaMatShape, all the matrix
// shape's M,N are {8,8}
return {8, 8};
}
@@ -3143,7 +3228,7 @@ struct MMA16816ConversionHelper {
Value thread, lane, warp, warpMN, warpN, warpM;
DotOpConversionHelper helper;
DotOpMmaV2ConversionHelper helper;
ConversionPatternRewriter &rewriter;
TypeConverter *typeConverter;
Location loc;
@@ -3203,22 +3288,25 @@ struct MMA16816ConversionHelper {
static int getNumRepM(Type operand, int M, int wpt) {
auto tensorCoreType =
DotOpConversionHelper::getTensorCoreTypeFromOperand(operand);
int mmaInstrM = DotOpConversionHelper::getMmaInstrShape(tensorCoreType)[0];
DotOpMmaV2ConversionHelper::getTensorCoreTypeFromOperand(operand);
int mmaInstrM =
DotOpMmaV2ConversionHelper::getMmaInstrShape(tensorCoreType)[0];
return std::max<int>(M / (wpt * mmaInstrM), 1);
}
static int getNumRepN(Type operand, int N, int wpt) {
auto tensorCoreType =
DotOpConversionHelper::getTensorCoreTypeFromOperand(operand);
int mmaInstrN = DotOpConversionHelper::getMmaInstrShape(tensorCoreType)[1];
DotOpMmaV2ConversionHelper::getTensorCoreTypeFromOperand(operand);
int mmaInstrN =
DotOpMmaV2ConversionHelper::getMmaInstrShape(tensorCoreType)[1];
return std::max<int>(N / (wpt * mmaInstrN), 1);
}
static int getNumRepK_(Type operand, int K) {
auto tensorCoreType =
DotOpConversionHelper::getTensorCoreTypeFromOperand(operand);
int mmaInstrK = DotOpConversionHelper::getMmaInstrShape(tensorCoreType)[2];
DotOpMmaV2ConversionHelper::getTensorCoreTypeFromOperand(operand);
int mmaInstrK =
DotOpMmaV2ConversionHelper::getMmaInstrShape(tensorCoreType)[2];
return std::max<int>(K / mmaInstrK, 1);
}
@@ -3304,7 +3392,7 @@ struct MMA16816ConversionHelper {
// Loading $c to registers, returns a Value.
Value loadC(Value tensor, Value llTensor) const {
auto tensorTy = tensor.getType().cast<RankedTensorType>();
auto [repM, repN] = DotOpConversionHelper::getRepMN(tensorTy);
auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(tensorTy);
size_t fcSize = 4 * repM * repN;
assert(tensorTy.getEncoding().isa<MmaEncodingAttr>() &&
@@ -3371,7 +3459,7 @@ struct MMA16816ConversionHelper {
return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
};
for (int i = 0; i < 4; i++)
for (int i = 0; i < 4; ++i)
fc[m * colsPerThread + 4 * n + i] =
extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(i));
};
@@ -3427,7 +3515,7 @@ private:
Type smemPtrTy = helper.getShemPtrTy();
for (int i = 0; i < numPtrs; ++i) {
ptrs[i] =
bitcast(smemPtrTy, gep(smemPtrTy, llTensor, ValueRange({offs[i]})));
bitcast(gep(smemPtrTy, llTensor, ValueRange({offs[i]})), smemPtrTy);
}
auto [ha0, ha1, ha2, ha3] = loader.loadX4(
@@ -3492,7 +3580,7 @@ private:
int offset{};
ValueTable vals;
for (int i = 0; i < n0; i++) {
for (int i = 0; i < n0; ++i) {
for (int j = 0; j < n1; j++) {
vals[{2 * i, 2 * j}] = elems[offset++];
vals[{2 * i, 2 * j + 1}] = elems[offset++];
@@ -3514,20 +3602,37 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
auto dotOperandLayout =
dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
MmaEncodingAttr mmaLayout =
dotOperandLayout.getParent().dyn_cast_or_null<MmaEncodingAttr>();
assert(mmaLayout);
MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
rewriter, getTypeConverter(), op.getLoc());
Value res;
if (dotOperandLayout.getOpIdx() == 0) {
// operand $a
res = mmaHelper.loadA(src, adaptor.src());
} else if (dotOperandLayout.getOpIdx() == 1) {
// operand $b
res = mmaHelper.loadB(src, adaptor.src());
if (mmaLayout.getVersion() == 2) {
MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
rewriter, getTypeConverter(),
op.getLoc());
if (dotOperandLayout.getOpIdx() == 0) {
// operand $a
res = mmaHelper.loadA(src, adaptor.src());
} else if (dotOperandLayout.getOpIdx() == 1) {
// operand $b
res = mmaHelper.loadB(src, adaptor.src());
}
} else if (mmaLayout.getVersion() == 1) {
DotOpMmaV1ConversionHelper helper(mmaLayout);
if (dotOperandLayout.getOpIdx() == 0) {
// operand $a
res = helper.loadA(src, adaptor.src(), getThreadId(rewriter, loc),
adaptor.src(), loc, rewriter);
} else if (dotOperandLayout.getOpIdx() == 1) {
// operand $b
res = helper.loadB(src, adaptor.src(), getThreadId(rewriter, loc),
adaptor.src(), loc, rewriter);
}
} else {
assert(false && "Unsupported mma layout found");
}
rewriter.replaceOp(op, res);
@@ -3571,6 +3676,424 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adaptor,
adaptor);
}
// The old code is ported here largely as-is, to keep the diff small and make
// debugging and profiling easier.
LogicalResult
DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto *ctx = op.getContext();
auto loc = op.getLoc();
Value A = op.a();
Value B = op.b();
Value D = op.getResult();
auto mmaLayout = D.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<MmaEncodingAttr>();
auto ATensorTy = A.getType().cast<RankedTensorType>();
auto BTensorTy = B.getType().cast<RankedTensorType>();
auto DTensorTy = D.getType().cast<RankedTensorType>();
auto AShape = ATensorTy.getShape();
auto BShape = BTensorTy.getShape();
auto DShape = DTensorTy.getShape();
auto wpt = mmaLayout.getWarpsPerCTA();
bool transA = op.transA();
bool transB = op.transB();
bool isARow = !transA;
bool isBRow = !transB;
bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
bool isBVec4 = isBRow && BShape[isBRow] <= 16;
int packSize0 = (isARow || isAVec4) ? 1 : 2;
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
SmallVector<int> fpw({2, 2, 1});
SmallVector<int> rep({2 * packSize0, 2 * packSize1, 1});
SmallVector<int> spw({fpw[0] * 4 * rep[0], fpw[1] * 4 * rep[1], 1});
Value loadedA = adaptor.a();
Value loadedB = adaptor.b();
Value loadedC = adaptor.c();
DotOpMmaV1ConversionHelper helper(mmaLayout);
unsigned numM = rep[0] * DShape[0] / (spw[0] * wpt[0]);
unsigned numN = rep[1] * DShape[1] / (spw[1] * wpt[0]);
unsigned NK = AShape[1];
auto has = helper.extractLoadedOperand(loadedA, numM / 2, NK, rewriter);
auto hbs = helper.extractLoadedOperand(loadedB, numN / 2, NK, rewriter);
size_t accSize = numM * numN;
// initialize accumulators
SmallVector<Value> acc = getElementsFromStruct(loc, loadedC, rewriter);
auto callMMA = [&](unsigned m, unsigned n, unsigned k) {
auto ha = has[{m, k}];
auto hb = hbs[{n, k}];
std::vector<size_t> idx{{
(m * 2 + 0) + (n * 4 + 0) * numM, // row0
(m * 2 + 0) + (n * 4 + 1) * numM,
(m * 2 + 1) + (n * 4 + 0) * numM, // row1
(m * 2 + 1) + (n * 4 + 1) * numM,
(m * 2 + 0) + (n * 4 + 2) * numM, // row2
(m * 2 + 0) + (n * 4 + 3) * numM,
(m * 2 + 1) + (n * 4 + 2) * numM, // row3
(m * 2 + 1) + (n * 4 + 3) * numM,
}};
PTXBuilder builder;
auto *resOprs = builder.newListOperand(8, "=f");
auto *AOprs = builder.newListOperand({
{ha.first, "f"},
{ha.second, "f"},
});
auto *BOprs = builder.newListOperand({
{hb.first, "f"},
{hb.second, "f"},
});
auto *COprs = builder.newListOperand();
for (int i = 0; i < 8; ++i)
COprs->listAppend(builder.newOperand(acc[idx[i]], std::to_string(i)));
auto mma = builder.create("mma.sync.aligned.m8n8k4")
->o(isARow ? "row" : "col")
.o(isBRow ? "row" : "col")
.o(".f32.f16.f16.f32");
mma(resOprs, AOprs, BOprs, COprs);
Value res = builder.launch(rewriter, loc, helper.getMmaRetType(ATensorTy));
auto getIntAttr = [&](int v) {
return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
};
for (unsigned i = 0; i < 8; i++)
acc[idx[i]] = extract_val(f32_ty, res, getIntAttr(i));
};
for (unsigned k = 0; k < NK; k += 4)
for (unsigned m = 0; m < numM / 2; ++m)
for (unsigned n = 0; n < numN / 2; ++n) {
callMMA(m, n, k);
}
// replace with new packed result
Type structTy = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(acc.size(), type::f32Ty(ctx)));
Value res = getStructFromElements(loc, acc, rewriter, structTy);
rewriter.replaceOp(op, res);
return success();
}
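// For reference, the instruction composed by the PTXBuilder above has this
// shape (a sketch; the $-operand numbering assumes the order in which the
// operand lists are appended):
//   mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32
//       {$0,$1,$2,$3,$4,$5,$6,$7}, {$8,$9}, {$10,$11},
//       {$12,$13,$14,$15,$16,$17,$18,$19};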
Value DotOpMmaV1ConversionHelper::loadA(
Value tensor, Value llTensor, Value thread, Value smem, Location loc,
ConversionPatternRewriter &rewriter) const {
auto *ctx = rewriter.getContext();
auto tensorTy = tensor.getType().cast<RankedTensorType>();
auto shape = tensorTy.getShape();
auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
auto order = sharedLayout.getOrder();
bool isARow = order[0] != 0;
bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes
int packSize0 = (isARow || isAVec4) ? 1 : 2;
SmallVector<int> fpw({2, 2, 1});
int repM = 2 * packSize0;
int repK = 1;
int spwM = fpw[0] * 4 * repM;
SmallVector<int> rep({repM, 0, repK}); // pad N with 0
SmallVector<int> spw({spwM, 0, 1}); // pad N with 0
int vecA = sharedLayout.getVec();
int strideAM = isARow ? shape[1] : 1;
int strideAK = isARow ? 1 : shape[0];
int strideA0 = isARow ? strideAK : strideAM;
int strideA1 = isARow ? strideAM : strideAK;
int strideRepM = wpt[0] * fpw[0] * 8;
int strideRepK = 1;
auto [offsetAM, offsetAK, _0, _1] =
computeOffsets(thread, isARow, false, fpw, spw, rep, rewriter, loc);
// swizzling
int perPhaseA = sharedLayout.getPerPhase();
int maxPhaseA = sharedLayout.getMaxPhase();
int stepA0 = isARow ? strideRepK : strideRepM;
int numPtrA = std::max(2 * perPhaseA * maxPhaseA / stepA0, 1);
int NK = shape[1];
// pre-compute pointer lanes
Value offA0 = isARow ? offsetAK : offsetAM;
Value offA1 = isARow ? offsetAM : offsetAK;
Value phaseA = urem(udiv(offA1, i32_val(perPhaseA)), i32_val(maxPhaseA));
SmallVector<Value> offA(numPtrA);
for (int i = 0; i < numPtrA; i++) {
Value offA0I = add(offA0, i32_val(i * (isARow ? 4 : strideRepM)));
offA0I = udiv(offA0I, i32_val(vecA));
offA0I = xor_(offA0I, phaseA);
offA0I = mul(offA0I, i32_val(vecA));
offA[i] =
add(mul(offA0I, i32_val(strideA0)), mul(offA1, i32_val(strideA1)));
}
Type f16x2Ty = vec_ty(f16_ty, 2);
// prepare arguments
SmallVector<Value> ptrA(numPtrA);
std::map<std::pair<int, int>, std::pair<Value, Value>> has;
for (int i = 0; i < numPtrA; i++)
ptrA[i] = gep(ptr_ty(f16_ty), smem, offA[i]);
auto instrShape = getMmaInstrShape();
unsigned numM = rep[0] * shape[0] / (spw[0] * wpt[0]);
Type f16PtrTy = ptr_ty(f16_ty);
auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) {
vals[{m, k}] = {val0, val1};
};
auto loadA = [&](int m, int k) {
int offidx = (isARow ? k / 4 : m) % numPtrA;
Value thePtrA = gep(f16PtrTy, smem, offA[offidx]);
int stepAM = isARow ? m : m / numPtrA * numPtrA;
int stepAK = isARow ? k / (numPtrA * vecA) * (numPtrA * vecA) : k;
Value pa = gep(f16PtrTy, thePtrA,
i32_val(stepAM * strideRepM * strideAM + stepAK * strideAK));
Type aPtrTy = ptr_ty(vec_ty(i32_ty, std::max<int>(vecA / 2, 1)), 3);
Value ha = load(bitcast(pa, aPtrTy));
// record lds that need to be moved
Value ha00 = bitcast(extract_element(ha, i32_val(0)), f16x2Ty);
Value ha01 = bitcast(extract_element(ha, i32_val(1)), f16x2Ty);
ld(has, m, k, ha00, ha01);
if (vecA > 4) {
Value ha10 = bitcast(extract_element(ha, i32_val(2)), f16x2Ty);
Value ha11 = bitcast(extract_element(ha, i32_val(3)), f16x2Ty);
if (isARow)
ld(has, m, k + 4, ha10, ha11);
else
ld(has, m + 1, k, ha10, ha11);
}
};
for (unsigned k = 0; k < NK; k += 4)
for (unsigned m = 0; m < numM / 2; ++m)
if (!has.count({m, k}))
loadA(m, k);
SmallVector<Value> elems;
elems.reserve(has.size() * 2);
for (auto item : has) { // `has` is a std::map, so its keys are ordered.
elems.push_back(item.second.first);
elems.push_back(item.second.second);
}
Type resTy = struct_ty(SmallVector<Type>(elems.size(), f16x2Ty));
Value res = getStructFromElements(loc, elems, rewriter, resTy);
return res;
}
Value DotOpMmaV1ConversionHelper::loadB(
Value tensor, Value llTensor, Value thread, Value smem, Location loc,
ConversionPatternRewriter &rewriter) const {
auto *ctx = rewriter.getContext();
auto tensorTy = tensor.getType().cast<RankedTensorType>();
auto shape = tensorTy.getShape();
auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
auto order = sharedLayout.getOrder();
bool isBRow = order[0] != 0;
bool isBVec4 = isBRow && shape[order[0]] <= 16;
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
SmallVector<int> fpw({2, 2, 1});
SmallVector<int> rep({0, 2 * packSize1, 1}); // pad M with 0
SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
int vecB = sharedLayout.getVec();
int strideBN = isBRow ? 1 : shape[0];
int strideBK = isBRow ? shape[1] : 1;
int strideB0 = isBRow ? strideBN : strideBK;
int strideB1 = isBRow ? strideBK : strideBN;
int strideRepN = wpt[1] * fpw[1] * 8;
int strideRepK = 1;
// swizzling
int perPhaseB = sharedLayout.getPerPhase();
int maxPhaseB = sharedLayout.getMaxPhase();
int stepB0 = isBRow ? strideRepN : strideRepK;
int numPtrB = std::max(2 * perPhaseB * maxPhaseB / stepB0, 1);
int NK = shape[0];
auto [_0, _1, offsetBN, offsetBK] =
computeOffsets(thread, false, isBRow, fpw, spw, rep, rewriter, loc);
Value offB0 = isBRow ? offsetBN : offsetBK;
Value offB1 = isBRow ? offsetBK : offsetBN;
Value phaseB = urem(udiv(offB1, i32_val(perPhaseB)), i32_val(maxPhaseB));
SmallVector<Value> offB(numPtrB);
for (int i = 0; i < numPtrB; ++i) {
Value offB0I = add(offB0, i32_val(i * (isBRow ? strideRepN : 4)));
offB0I = udiv(offB0I, i32_val(vecB));
offB0I = xor_(offB0I, phaseB);
offB0I = mul(offB0I, i32_val(vecB));
offB[i] =
add(mul(offB0I, i32_val(strideB0)), mul(offB1, i32_val(strideB1)));
}
Type f16PtrTy = ptr_ty(f16_ty);
Type f16x2Ty = vec_ty(f16_ty, 2);
SmallVector<Value> ptrB(numPtrB);
ValueTable hbs;
for (int i = 0; i < numPtrB; ++i)
ptrB[i] = gep(ptr_ty(f16_ty), smem, offB[i]);
auto ld = [&](decltype(hbs) &vals, int m, int k, Value val0, Value val1) {
vals[{m, k}] = {val0, val1};
};
auto loadB = [&](int n, int K) {
int offidx = (isBRow ? n : K / 4) % numPtrB;
Value thePtrB = ptrB[offidx];
int stepBN = isBRow ? n / numPtrB * numPtrB : n;
int stepBK = isBRow ? K : K / (numPtrB * vecB) * (numPtrB * vecB);
Value pb = gep(f16PtrTy, thePtrB,
i32_val(stepBN * strideRepN * strideBN + stepBK * strideBK));
Value hb =
load(bitcast(pb, ptr_ty(vec_ty(i32_ty, std::max(vecB / 2, 1)), 3)));
// record lds that need to be moved
Value hb00 = bitcast(extract_element(hb, i32_val(0)), f16x2Ty);
Value hb01 = bitcast(extract_element(hb, i32_val(1)), f16x2Ty);
ld(hbs, n, K, hb00, hb01);
if (vecB > 4) {
Value hb10 = bitcast(extract_element(hb, i32_val(2)), f16x2Ty);
Value hb11 = bitcast(extract_element(hb, i32_val(3)), f16x2Ty);
if (isBRow)
ld(hbs, n + 1, K, hb10, hb11);
else
ld(hbs, n, K + 4, hb10, hb11);
}
};
unsigned numN = rep[1] * shape[1] / (spw[1] * wpt[0]);
for (unsigned k = 0; k < NK; k += 4)
for (unsigned n = 0; n < numN / 2; ++n) {
if (!hbs.count({n, k}))
loadB(n, k);
}
SmallVector<Value> elems;
for (auto &item : hbs) { // `hbs` is a std::map, so its keys are ordered.
elems.push_back(item.second.first);
elems.push_back(item.second.second);
}
Type resTy = struct_ty(SmallVector<Type>(elems.size(), f16x2Ty));
Value res = getStructFromElements(loc, elems, rewriter, resTy);
return res;
}
Value DotOpMmaV1ConversionHelper::loadC(
Value tensor, Value llTensor, ConversionPatternRewriter &rewriter) const {
return llTensor;
}
std::tuple<Value, Value, Value, Value>
DotOpMmaV1ConversionHelper::computeOffsets(Value threadId, bool isARow,
bool isBRow, ArrayRef<int> fpw,
ArrayRef<int> spw, ArrayRef<int> rep,
ConversionPatternRewriter &rewriter,
Location loc) const {
auto *ctx = rewriter.getContext();
Value _1 = i32_val(1);
Value _3 = i32_val(3);
Value _4 = i32_val(4);
Value _16 = i32_val(16);
Value _32 = i32_val(32);
Value lane = urem(threadId, _32);
Value warp = udiv(threadId, _32);
// warp offset
Value warp0 = urem(warp, i32_val(wpt[0]));
Value warp12 = udiv(warp, i32_val(wpt[0]));
Value warp1 = urem(warp12, i32_val(wpt[1]));
Value warpMOff = mul(warp0, i32_val(spw[0]));
Value warpNOff = mul(warp1, i32_val(spw[1]));
// Quad offset
Value quadMOff = mul(udiv(and_(lane, _16), _4), i32_val(fpw[0]));
Value quadNOff = mul(udiv(and_(lane, _16), _4), i32_val(fpw[1]));
// Pair offset
Value pairMOff = udiv(urem(lane, _16), _4);
pairMOff = urem(pairMOff, i32_val(fpw[0]));
pairMOff = mul(pairMOff, _4);
Value pairNOff = udiv(urem(lane, _16), _4);
pairNOff = udiv(pairNOff, i32_val(fpw[0]));
pairNOff = urem(pairNOff, i32_val(fpw[1]));
pairNOff = mul(pairNOff, _4);
// scale
pairMOff = mul(pairMOff, i32_val(rep[0] / 2));
quadMOff = mul(quadMOff, i32_val(rep[0] / 2));
pairNOff = mul(pairNOff, i32_val(rep[1] / 2));
quadNOff = mul(quadNOff, i32_val(rep[1] / 2));
// Quad pair offset
Value laneMOff = add(pairMOff, quadMOff);
Value laneNOff = add(pairNOff, quadNOff);
// A offset
Value offsetAM = add(warpMOff, laneMOff);
Value offsetAK = and_(lane, _3);
// B offset
Value offsetBN = add(warpNOff, laneNOff);
Value offsetBK = and_(lane, _3);
// i indices
Value offsetCM = add(and_(lane, _1), offsetAM);
if (isARow) {
offsetAM = add(offsetAM, urem(threadId, _4));
offsetAK = i32_val(0);
}
if (!isBRow) {
offsetBN = add(offsetBN, urem(threadId, _4));
offsetBK = i32_val(0);
}
return std::make_tuple(offsetAM, offsetAK, offsetBN, offsetBK);
}
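// Worked lane example (illustrative; fpw = {2, 2, 1}, rep = {2, 4, 1},
// spw = {16, 32, 1}, wpt = {2, 2}, threadId = 5): lane = 5, warp = 0, so the
// warp offsets are 0; quadMOff = ((5 & 16) / 4) * 2 = 0,
// pairMOff = (((5 % 16) / 4) % 2) * 4 = 4, scaled by rep[0] / 2 = 1, giving
// laneMOff = 4 and offsetAM = 4; offsetAK = 5 & 3 = 1 (with isARow = false).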
DotOpMmaV1ConversionHelper::ValueTable
DotOpMmaV1ConversionHelper::extractLoadedOperand(
Value llStruct, int n0, int n1, ConversionPatternRewriter &rewriter) const {
ValueTable rcds;
SmallVector<Value> elems =
ConvertTritonGPUOpToLLVMPatternBase::getElementsFromStruct(
llStruct.getLoc(), llStruct, rewriter);
int offset = 0;
for (int i = 0; i < n0; ++i)
for (int k = 0; k < n1; k += 4) {
rcds[{i, k}] = std::make_pair(elems[offset], elems[offset + 1]);
offset += 2;
}
return rcds;
}
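// Consistency check against the $a numbers worked out earlier (illustrative):
// with n0 = numM / 2 = 4 and n1 = NK = 32, the table holds keys {i, k} for
// k in {0, 4, ..., 28}, two values per key, consuming 4 * 8 * 2 = 64 struct
// elements, exactly the count numElemsPerThreadA computed.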
/// ====================== mma codegen end ============================
Value convertSplatLikeOpWithMmaLayout(const MmaEncodingAttr &layout,
@@ -3579,9 +4102,10 @@ Value convertSplatLikeOpWithMmaLayout(const MmaEncodingAttr &layout,
TypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
Location loc) {
auto tensorTy = resType.cast<RankedTensorType>();
auto shape = tensorTy.getShape();
if (layout.getVersion() == 2) {
auto tensorTy = resType.cast<RankedTensorType>();
auto [repM, repN] = DotOpConversionHelper::getRepMN(tensorTy);
auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(tensorTy);
size_t fcSize = 4 * repM * repN;
auto structTy = LLVM::LLVMStructType::getLiteral(
@@ -3589,6 +4113,18 @@ Value convertSplatLikeOpWithMmaLayout(const MmaEncodingAttr &layout,
return getStructFromElements(loc, SmallVector<Value>(fcSize, constVal),
rewriter, structTy);
}
if (layout.getVersion() == 1) {
DotOpMmaV1ConversionHelper helper(layout);
int repM = helper.getRepM(shape[0]);
int repN = helper.getRepN(shape[1]);
// According to the v1 mma layout, each thread processes 8 elements.
int elems = 8 * repM * repN;
auto structTy = LLVM::LLVMStructType::getLiteral(
rewriter.getContext(), SmallVector<Type>(elems, elemType));
return getStructFromElements(loc, SmallVector<Value>(elems, constVal),
rewriter, structTy);
}
assert(false && "Unsupported mma layout found");
}
@@ -3620,6 +4156,7 @@ public:
llvm::Optional<Type> convertTritonTensorType(RankedTensorType type) {
auto ctx = type.getContext();
Attribute layout = type.getEncoding();
auto shape = type.getShape();
if (layout &&
(layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
layout.isa<MmaEncodingAttr>())) {
@@ -3632,12 +4169,21 @@ public:
return LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);
} else if (auto mmaLayout = layout.dyn_cast_or_null<MmaEncodingAttr>()) {
if (mmaLayout.getVersion() == 2) {
auto [repM, repN] = DotOpConversionHelper::getRepMN(type);
auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(type);
size_t fcSize = 4 * repM * repN;
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(fcSize, type.getElementType()));
}
if (mmaLayout.getVersion() == 1) {
DotOpMmaV1ConversionHelper helper(mmaLayout);
int repM = helper.getRepM(shape[0]);
int repN = helper.getRepN(shape[1]);
int elems = 8 * repM * repN;
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(elems, type.getElementType()));
}
llvm::errs()
<< "Unexpected mma layout detected in TritonToLLVMTypeConverter";
return llvm::None;
@@ -3645,9 +4191,9 @@ public:
} else if (auto dot_op_layout =
layout.dyn_cast_or_null<DotOperandEncodingAttr>()) {
auto mmaLayout = dot_op_layout.getParent().cast<MmaEncodingAttr>();
auto wpt = mmaLayout.getWarpsPerCTA();
Type elemTy = type.getElementType();
if (mmaLayout.getVersion() == 2) {
auto wpt = mmaLayout.getWarpsPerCTA();
Type elemTy = type.getElementType();
if (dot_op_layout.getOpIdx() == 0) { // $a
int elems =
@@ -3660,8 +4206,22 @@ public:
int elems =
MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt);
Type x2Ty = vec_ty(elemTy, 2);
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(elems, x2Ty));
return struct_ty(SmallVector<Type>(elems, x2Ty));
}
}
if (mmaLayout.getVersion() == 1) {
DotOpMmaV1ConversionHelper helper(mmaLayout);
if (dot_op_layout.getOpIdx() == 0) { // $a
int elems = helper.numElemsPerThreadA(type);
Type x2Ty = vec_ty(elemTy, 2);
return struct_ty(SmallVector<Type>(elems, x2Ty));
}
if (dot_op_layout.getOpIdx() == 1) { // $b
int elems = helper.numElemsPerThreadB(type);
Type x2Ty = vec_ty(elemTy, 2);
return struct_ty(SmallVector<Type>(elems, x2Ty));
}
}


@@ -270,10 +270,23 @@ unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
size_t rank = shape.size();
assert(rank == 2 && "Unexpected rank of mma layout");
assert(getVersion() == 2 && "mmaLayout version = 1 is not implemented yet");
unsigned elemsCol = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
unsigned elemsRow = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
return elemsCol * elemsRow;
assert((getVersion() == 1 || getVersion() == 2) &&
"Only version 1 and 2 is supported");
int res = 0;
if (getVersion() == 1) {
unsigned mmasRow = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]);
unsigned mmasCol = ceil<unsigned>(shape[1], 16 * getWarpsPerCTA()[1]);
// Each warp-level mma884 performs an m16xn16xk4 mma, producing an m16xn16
// matrix as result.
res = mmasRow * mmasCol * (16 * 16 / 32);
} else if (getVersion() == 2) {
unsigned elemsCol = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
unsigned elemsRow = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
res = elemsCol * elemsRow;
}
return res;
}
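// Hand-worked v1 example (shape 128x256 with warpsPerCTA = [2, 2], matching
// the lit test below; illustrative): mmasRow = ceil(128 / 32) = 4,
// mmasCol = ceil(256 / 32) = 8, so res = 4 * 8 * 8 = 256 elements per thread.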
unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {


@@ -738,3 +738,28 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
return
}
}
// -----
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#mma = #triton_gpu.mma<{version = 1, warpsPerCTA = [2, 2]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
%a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
// CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16
%a_mat = triton_gpu.convert_layout %a : (tensor<128x32xf16, #shared>) -> tensor<128x32xf16, #dot_operand_a>
%b_mat = triton_gpu.convert_layout %b : (tensor<32x256xf16, #shared>) -> tensor<32x256xf16, #dot_operand_b>
%28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #mma>
// TODO[goostavz]: uncomment the following lines after convert_layout[mma<v1> -> blocked] is ready.
// %38 = triton_gpu.convert_layout %28 : (tensor<128x256xf32, #mma>) -> tensor<128x256xf32, #blocked>
// %30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>, #blocked>
// %36 = tt.broadcast %30 : (tensor<128x1x!tt.ptr<f32>, #blocked>) -> tensor<128x256x!tt.ptr<f32>, #blocked>
// tt.store %36, %38 : tensor<128x256xf32, #blocked>
return
}
}