[Triton-MLIR][BACKEND] Port the mma<v1> conversion (#815)
This PR does:
- port the mma<v1> related code, and support dot conversion and convert_layout[shared->dot_op<mma<v1>>]
- add a lit test for dot v1
@@ -56,6 +56,12 @@ Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) {
                                           IntegerAttr::get(i32ty, v));
 }
 
+Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) {
+  auto type = type::f32Ty(rewriter.getContext());
+  return rewriter.create<LLVM::ConstantOp>(loc, type,
+                                           rewriter.getF32FloatAttr(v));
+}
+
 // Create a index type constant.
 Value createIndexConstant(OpBuilder &builder, Location loc,
 
@@ -90,7 +96,8 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
 #define fmin(...) rewriter.create<LLVM::MinNumOp>(loc, __VA_ARGS__)
 #define and_(...) rewriter.create<LLVM::AndOp>(loc, __VA_ARGS__)
 #define xor_(...) rewriter.create<LLVM::XOrOp>(loc, __VA_ARGS__)
-#define bitcast(...) rewriter.create<LLVM::BitcastOp>(loc, __VA_ARGS__)
+#define bitcast(val__, type__)                                                 \
+  rewriter.create<LLVM::BitcastOp>(loc, type__, val__)
 #define gep(...) rewriter.create<LLVM::GEPOp>(loc, __VA_ARGS__)
 #define ptr_ty(...) LLVM::LLVMPointerType::get(__VA_ARGS__)
 #define insert_val(...) rewriter.create<LLVM::InsertValueOp>(loc, __VA_ARGS__)
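The bitcast macro change above flips the argument order from type-first to value-first. A minimal before/after sketch of a call site (loc and rewriter come from the enclosing conversion pattern, as everywhere else in this file):

    // before: bitcast(vecTy, val) -> rewriter.create<LLVM::BitcastOp>(loc, vecTy, val)
    // after:  bitcast(val, vecTy) -> rewriter.create<LLVM::BitcastOp>(loc, vecTy, val)
    Value vec = bitcast(val, vecTy);

Every bitcast call site in the hunks below is rewritten to match the new value-first signature.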
@@ -112,11 +119,13 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
 #define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
 #define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
 #define i32_ty rewriter.getIntegerType(32)
 #define f16_ty rewriter.getF16Type()
+#define i8_ty rewriter.getIntegerType(8)
 #define f32_ty rewriter.getF32Type()
 #define vec_ty(type, num) VectorType::get(num, type)
+#define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__)
 #define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
-#define struct_ty(...) LLVM::LLVMStructType::getLiteral(__VA_ARGS__)
+#define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__)
 
 // Creator for constant
 #define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__)
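Note that struct_ty now captures a variable named ctx from the enclosing scope instead of taking it as an argument. A minimal sketch of the new call-site shape, mirroring the lowerDistributedToDistributed change later in this diff:

    auto *ctx = llvmElemTy.getContext();
    // expands to LLVM::LLVMStructType::getLiteral(ctx, types)
    Type structTy = struct_ty(types);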
@@ -685,7 +694,7 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
   if (tensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
     auto tensorTy = resType.cast<RankedTensorType>();
     auto srcType = typeConverter->convertType(elemType);
-    auto llSrc = bitcast(srcType, constVal);
+    auto llSrc = bitcast(constVal, srcType);
     size_t elemsPerThread = getElemsPerThread(tensorTy);
     llvm::SmallVector<Value> elems(elemsPerThread, llSrc);
     llvm::SmallVector<Type> elemTypes(elems.size(), srcType);
@@ -964,7 +973,7 @@ struct LoadOpConversion
             rewriter, loc, this->getTypeConverter()->getIndexType(), s);
         v = insert_element(vecTy, v, falseVal, sVal);
       }
-      v = bitcast(IntegerType::get(getContext(), width), v);
+      v = bitcast(v, IntegerType::get(getContext(), width));
 
       PTXInstr::Operand *opr{};
       if (otherIsSplatConstInt)
@@ -1001,9 +1010,8 @@ struct LoadOpConversion
       } else {
         curr = ret;
       }
-      curr = bitcast(
-          LLVM::getFixedVectorType(valueElemTy, width / valueElemNbits),
-          curr);
+      curr = bitcast(curr, LLVM::getFixedVectorType(valueElemTy,
+                                                    width / valueElemNbits));
       rets.push_back(curr);
     }
     int tmp = width / valueElemNbits;
@@ -1101,7 +1109,7 @@ struct StoreOpConversion
         Value elem = valueElems[elemOffset];
         if (elem.getType().isInteger(1))
           elem = rewriter.create<LLVM::SExtOp>(loc, type::i8Ty(ctx), elem);
-        elem = bitcast(valueElemTy, elem);
+        elem = bitcast(elem, valueElemTy);
 
         Type u32Ty = typeConverter->convertType(type::u32Ty(ctx));
         llWord =
@@ -1109,7 +1117,7 @@ struct StoreOpConversion
             rewriter.create<LLVM::ConstantOp>(
                 loc, u32Ty, IntegerAttr::get(u32Ty, elemIdx)));
       }
-      llWord = bitcast(valArgTy, llWord);
+      llWord = bitcast(llWord, valArgTy);
       std::string constraint =
           (width == 64) ? "l" : ((width == 32) ? "r" : "c");
       asmArgs.emplace_back(llWord, constraint);
@@ -1328,7 +1336,7 @@ Value ReduceOpConversion::shflSync(ConversionPatternRewriter &rewriter,
 
   if (bits == 64) {
     Type vecTy = vec_ty(f32_ty, 2);
-    Value vec = bitcast(vecTy, val);
+    Value vec = bitcast(val, vecTy);
     Value val0 = extract_element(f32_ty, vec, i32_val(0));
     Value val1 = extract_element(f32_ty, vec, i32_val(1));
     val0 = shflSync(rewriter, loc, val0, i);
@@ -1336,7 +1344,7 @@ Value ReduceOpConversion::shflSync(ConversionPatternRewriter &rewriter,
     vec = undef(vecTy);
     vec = insert_element(vecTy, vec, val0, i32_val(0));
     vec = insert_element(vecTy, vec, val1, i32_val(1));
-    return bitcast(val.getType(), vec);
+    return bitcast(vec, val.getType());
   }
 
   PTXBuilder builder;
@@ -1363,7 +1371,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
   auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
   auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
   Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
-  smemBase = bitcast(elemPtrTy, smemBase);
+  smemBase = bitcast(smemBase, elemPtrTy);
 
   auto smemShape = getScratchConfigForReduce(op);
 
@@ -1430,7 +1438,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
 
   barrier();
   SmallVector<Value> resultVals(resultElems);
-  for (size_t i = 0; i < resultElems; i++) {
+  for (unsigned i = 0; i < resultElems; ++i) {
     SmallVector<Value> readIdx = resultIndices[i];
     readIdx.insert(readIdx.begin() + axis, ints[0]);
     Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
@@ -1469,7 +1477,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
   auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
   auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
   Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
-  smemBase = bitcast(elemPtrTy, smemBase);
+  smemBase = bitcast(smemBase, elemPtrTy);
 
   auto order = srcLayout.getOrder();
   unsigned sizeIntraWarps = threadsPerWarp[axis];
@@ -1569,7 +1577,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
 
   barrier();
   SmallVector<Value> resultVals(resultElems);
-  for (size_t i = 0; i < resultElems; i++) {
+  for (size_t i = 0; i < resultElems; ++i) {
     SmallVector<Value> readIdx = resultIndices[i];
     readIdx.insert(readIdx.begin() + axis, i32_val(0));
     Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
@@ -2136,7 +2144,7 @@ void ConvertLayoutOpConversion::processReplica(
       auto elemPtrTy = ptr_ty(llvmElemTy, 3);
       Value ptr = gep(elemPtrTy, smemBase, offset);
       auto vecTy = vec_ty(llvmElemTy, vec);
-      ptr = bitcast(ptr_ty(vecTy, 3), ptr);
+      ptr = bitcast(ptr, ptr_ty(vecTy, 3));
       if (stNotRd) {
         Value valVec = undef(vecTy);
         for (unsigned v = 0; v < vec; ++v) {
@@ -2175,7 +2183,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
   auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
   Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
   auto elemPtrTy = ptr_ty(llvmElemTy, 3);
-  smemBase = bitcast(elemPtrTy, smemBase);
+  smemBase = bitcast(smemBase, elemPtrTy);
   auto shape = dstTy.getShape();
   unsigned rank = dstTy.getRank();
   SmallVector<unsigned> numReplicates(rank);
@@ -2234,7 +2242,8 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
   }
 
   SmallVector<Type> types(outElems, llvmElemTy);
-  Type structTy = struct_ty(getContext(), types);
+  auto *ctx = llvmElemTy.getContext();
+  Type structTy = struct_ty(types);
   Value result = getStructFromElements(loc, outVals, rewriter, structTy);
   rewriter.replaceOp(op, result);
 
@@ -2294,7 +2303,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
   Value minVecVal = idx_val(minVec);
   Value smemBase = getSharedMemoryBase(loc, rewriter, dst);
   auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
-  smemBase = bitcast(elemPtrTy, smemBase);
+  smemBase = bitcast(smemBase, elemPtrTy);
   unsigned numWordsEachRep = product<unsigned>(wordsInEachRep);
   SmallVector<Value> wordVecs(numWordsEachRep);
   // TODO: We should get less barriers if it is handled by membar pass
@@ -2350,7 +2359,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
 
       // step 3: store
       Value smemAddr = gep(elemPtrTy, smemBase, offset);
-      smemAddr = bitcast(ptr_ty(wordTy, 3), smemAddr);
+      smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3));
      store(wordVecs[linearWordIdx], smemAddr);
     }
   }
@@ -2693,7 +2702,7 @@ public:
         for (int e = 0; e < 4; ++e)
           i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m],
                                         i8Elems[m][e], i32_val(e));
-        i32Elems[m] = bitcast(i32_ty, i8v4Elems[m]);
+        i32Elems[m] = bitcast(i8v4Elems[m], i32_ty);
       }
     } else { // k first
       Value offset = i32_val(sOffsetElem);
@@ -2711,7 +2720,7 @@ public:
        for (int e = 0; e < 4; ++e)
          i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m],
                                        i8Elems[m][e], i32_val(e));
-        i32Elems[m] = bitcast(i32_ty, i8v4Elems[m]);
+        i32Elems[m] = bitcast(i8v4Elems[m], i32_ty);
      }
    }
 
@@ -2823,10 +2832,7 @@ private:
                               ConversionPatternRewriter &rewriter) const;
   /// Convert to mma.m8n8k4
   LogicalResult convertMMA884(triton::DotOp op, OpAdaptor adaptor,
-                              ConversionPatternRewriter &rewriter) const {
-    assert(false && "Not implemented yet.");
-    return failure();
-  }
+                              ConversionPatternRewriter &rewriter) const;
 
   LogicalResult convertFMADot(triton::DotOp op, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
@@ -2835,48 +2841,127 @@ private:
   }
 };
 
-struct DotOpConversionHelper {
+// Helper for conversion of DotOp with mma<version=1>, that is sm<80
+struct DotOpMmaV1ConversionHelper {
+  MmaEncodingAttr mmaLayout;
+  ArrayRef<unsigned> wpt;
+
+  using ValueTable = std::map<std::pair<int, int>, std::pair<Value, Value>>;
+
+  explicit DotOpMmaV1ConversionHelper(MmaEncodingAttr mmaLayout)
+      : mmaLayout(mmaLayout), wpt(mmaLayout.getWarpsPerCTA()) {}
+
+  int getRepM(int M) const {
+    return std::max<int>(M / (wpt[0] * instrShape[0]), 1);
+  }
+  int getRepN(int N) const {
+    return std::max<int>(N / (wpt[1] * instrShape[1]), 1);
+  }
+  int getRepK(int K) const { return std::max<int>(K / instrShape[2], 1); }
+
+  static ArrayRef<unsigned> getMmaInstrShape() { return instrShape; }
+
+  static Type getMmaRetType(TensorType operand) {
+    auto *ctx = operand.getContext();
+    Type fp32Ty = type::f32Ty(ctx);
+    // f16*f16+f32->f32
+    return struct_ty(SmallVector<Type>{8, fp32Ty});
+  }
+
+  // number of fp16x2 elements for $a.
+  int numElemsPerThreadA(RankedTensorType tensorTy) const {
+    auto shape = tensorTy.getShape();
+    auto order = getOrder();
+
+    bool isARow = order[0] != 0;
+    bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes
+    int packSize0 = (isARow || isAVec4) ? 1 : 2;
+
+    SmallVector<int> fpw({2, 2, 1});
+    int repM = 2 * packSize0;
+    int repK = 1;
+    int spwM = fpw[0] * 4 * repM;
+    SmallVector<int> rep({repM, 0, repK}); // pad N with 0
+    SmallVector<int> spw({spwM, 0, 1});    // pad N with 0
+
+    int NK = shape[1];
+    unsigned numM = rep[0] * shape[0] / (spw[0] * wpt[0]);
+
+    // NOTE: We couldn't get the vec from the shared layout.
+    // int vecA = sharedLayout.getVec();
+    // TODO[Superjomn]: Consider the case when vecA > 4
+    bool vecGt4 = false;
+    int elemsPerLd = vecGt4 ? 4 : 2;
+    return (numM / 2) * (NK / 4) * elemsPerLd;
+  }
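A worked pass through numElemsPerThreadA for the 128x32 f16 $a operand of the lit test below, with warpsPerCTA = [2, 2] (pure arithmetic on the code above):

    // order = getOrder() = {0, 1}     => isARow = false
    // isAVec4 = !isARow && 128 <= 16  => false, so packSize0 = 2
    // repM = 2 * packSize0 = 4; spwM = fpw[0] * 4 * repM = 32
    // numM = rep[0] * shape[0] / (spw[0] * wpt[0]) = 4 * 128 / (32 * 2) = 8
    // NK = shape[1] = 32; elemsPerLd = 2
    // result = (numM / 2) * (NK / 4) * elemsPerLd = 4 * 8 * 2 = 64 fp16x2 values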
+
+  // number of fp16x2 elements for $b.
+  int numElemsPerThreadB(RankedTensorType tensorTy) const {
+    auto shape = tensorTy.getShape();
+    auto order = getOrder();
+    bool isBRow = order[0] != 0;
+    bool isBVec4 = isBRow && shape[order[0]] <= 16;
+    int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
+    SmallVector<int> fpw({2, 2, 1});
+    SmallVector<int> rep({0, 2 * packSize1, 1});       // pad M with 0
+    SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
+    // NOTE: We couldn't get the vec from the shared layout.
+    // int vecB = sharedLayout.getVec();
+    // TODO[Superjomn]: Consider the case when vecB > 4
+    bool vecGt4 = false;
+    int elemsPerLd = vecGt4 ? 4 : 2;
+    int NK = shape[0];
+
+    unsigned numN = rep[1] * shape[1] / (spw[1] * wpt[0]);
+    return (numN / 2) * (NK / 4) * elemsPerLd;
+  }
+
+  // Loading $a from smem to registers, returns a LLVM::Struct.
+  Value loadA(Value A, Value llA, Value thread, Value smem, Location loc,
+              ConversionPatternRewriter &rewriter) const;
+
+  // Loading $b from smem to registers, returns a LLVM::Struct.
+  Value loadB(Value B, Value llB, Value thread, Value smem, Location loc,
+              ConversionPatternRewriter &rewriter) const;
+
+  // Loading $c to registers, returns a LLVM::Struct.
+  Value loadC(Value C, Value llC, ConversionPatternRewriter &rewriter) const;
+
+  static ArrayRef<unsigned> getOrder() { return mmaOrder; }
+
+  // Compute the offset of the matrix to load.
+  // Returns offsetAM, offsetAK, offsetBN, offsetBK.
+  // NOTE: the M (from $a) and N (from $b) information cannot be retrieved at
+  // the same time when used in convert_layout[shared->dot_op]; we leave the
+  // missing entries as 0 and read only the desired values from the composed
+  // result. This retains the original code structure of the convertMMA884
+  // method for easier debugging.
+  std::tuple<Value, Value, Value, Value>
+  computeOffsets(Value threadId, bool isARow, bool isBRow, ArrayRef<int> fpw,
+                 ArrayRef<int> spw, ArrayRef<int> rep,
+                 ConversionPatternRewriter &rewriter, Location loc) const;
+
+  // Extract the values belonging to $a or $b from an LLVM struct; the shape
+  // is n0 x n1.
+  ValueTable extractLoadedOperand(Value llStruct, int n0, int n1,
+                                  ConversionPatternRewriter &rewriter) const;
+
+private:
+  static constexpr unsigned instrShape[] = {16, 16, 4};
+  static constexpr unsigned mmaOrder[] = {0, 1};
+};
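A quick sketch of how the helper is meant to be used, with shapes borrowed from the lit test (this mirrors the call sites added further down in this diff):

    DotOpMmaV1ConversionHelper helper(mmaLayout); // wpt = {2, 2}
    int repM = helper.getRepM(128);               // max(128 / (2 * 16), 1) = 4
    int repN = helper.getRepN(256);               // max(256 / (2 * 16), 1) = 8
    // each thread then owns 8 * repM * repN accumulator values; see
    // convertSplatLikeOpWithMmaLayout and the type converter changes below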
 
+// Helper for conversion of DotOp with mma<version=2>, that is sm>=80
+struct DotOpMmaV2ConversionHelper {
   using TensorCoreType = DotOpConversion::TensorCoreType;
 
   MmaEncodingAttr mmaLayout;
   MLIRContext *ctx{};
 
-  explicit DotOpConversionHelper(MmaEncodingAttr mmaLayout)
+  explicit DotOpMmaV2ConversionHelper(MmaEncodingAttr mmaLayout)
       : mmaLayout(mmaLayout) {
     ctx = mmaLayout.getContext();
   }
 
+  // Load SplatLike C which contains a constVal. It simply returns 4 fp32
+  // constVal.
+  SmallVector<Value> loadSplatLikeC(Value C, Location loc,
+                                    ConversionPatternRewriter &rewriter) const {
+    assert(isSplatLike(C));
+
+    int numRes = getMmaInstrShape()[0] * getMmaInstrShape()[1] / 32;
+    if (auto constv = llvm::dyn_cast<arith::ConstantOp>(C.getDefiningOp())) {
+      if (auto attr = constv.getValue().dyn_cast<SplatElementsAttr>()) {
+        Type elemType = attr.getElementType();
+        if (elemType.isInteger(32)) {
+          int v = attr.getSplatValue<int>();
+          return SmallVector<Value>(numRes, i32_val(v));
+        } else if (elemType.isInteger(8)) {
+          int v = attr.getSplatValue<int8_t>();
+          auto newv = rewriter.create<arith::ConstantOp>(
+              loc, elemType, IntegerAttr::get(elemType, v));
+          return SmallVector<Value>(numRes, newv);
+        } else if (elemType.isF32()) {
+          float v = attr.getSplatValue<float>();
+          auto newv = rewriter.create<arith::ConstantOp>(
+              loc, elemType, FloatAttr::get(elemType, v));
+          return SmallVector<Value>(numRes, newv);
+        }
+      }
+    }
+
+    assert(false && "Not supported type.");
+    return {};
+  }
+
   void deduceMmaType(DotOp op) const { mmaType = getMmaType(op); }
   void deduceMmaType(Type operandTy) const {
     mmaType = getTensorCoreTypeFromOperand(operandTy);
@@ -2884,8 +2969,8 @@ struct DotOpConversionHelper {
 
   // Get the M and N of mat instruction shape.
   static std::tuple<int, int> getMatShapeMN() {
-    // According to DotOpConversionHelper::mmaMatShape, all the matrix shape's
-    // M,N are {8,8}
+    // According to DotOpMmaV2ConversionHelper::mmaMatShape, all the matrix
+    // shape's M,N are {8,8}
     return {8, 8};
   }
 
@@ -3143,7 +3228,7 @@ struct MMA16816ConversionHelper {
 
   Value thread, lane, warp, warpMN, warpN, warpM;
 
-  DotOpConversionHelper helper;
+  DotOpMmaV2ConversionHelper helper;
   ConversionPatternRewriter &rewriter;
   TypeConverter *typeConverter;
   Location loc;
@@ -3203,22 +3288,25 @@ struct MMA16816ConversionHelper {
 
   static int getNumRepM(Type operand, int M, int wpt) {
     auto tensorCoreType =
-        DotOpConversionHelper::getTensorCoreTypeFromOperand(operand);
-    int mmaInstrM = DotOpConversionHelper::getMmaInstrShape(tensorCoreType)[0];
+        DotOpMmaV2ConversionHelper::getTensorCoreTypeFromOperand(operand);
+    int mmaInstrM =
+        DotOpMmaV2ConversionHelper::getMmaInstrShape(tensorCoreType)[0];
     return std::max<int>(M / (wpt * mmaInstrM), 1);
   }
 
   static int getNumRepN(Type operand, int N, int wpt) {
     auto tensorCoreType =
-        DotOpConversionHelper::getTensorCoreTypeFromOperand(operand);
-    int mmaInstrN = DotOpConversionHelper::getMmaInstrShape(tensorCoreType)[1];
+        DotOpMmaV2ConversionHelper::getTensorCoreTypeFromOperand(operand);
+    int mmaInstrN =
+        DotOpMmaV2ConversionHelper::getMmaInstrShape(tensorCoreType)[1];
     return std::max<int>(N / (wpt * mmaInstrN), 1);
   }
 
   static int getNumRepK_(Type operand, int K) {
     auto tensorCoreType =
-        DotOpConversionHelper::getTensorCoreTypeFromOperand(operand);
-    int mmaInstrK = DotOpConversionHelper::getMmaInstrShape(tensorCoreType)[2];
+        DotOpMmaV2ConversionHelper::getTensorCoreTypeFromOperand(operand);
+    int mmaInstrK =
+        DotOpMmaV2ConversionHelper::getMmaInstrShape(tensorCoreType)[2];
     return std::max<int>(K / mmaInstrK, 1);
   }
 
@@ -3304,7 +3392,7 @@ struct MMA16816ConversionHelper {
   // Loading $c to registers, returns a Value.
   Value loadC(Value tensor, Value llTensor) const {
     auto tensorTy = tensor.getType().cast<RankedTensorType>();
-    auto [repM, repN] = DotOpConversionHelper::getRepMN(tensorTy);
+    auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(tensorTy);
     size_t fcSize = 4 * repM * repN;
 
     assert(tensorTy.getEncoding().isa<MmaEncodingAttr>() &&
@@ -3371,7 +3459,7 @@ struct MMA16816ConversionHelper {
       return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
     };
 
-    for (int i = 0; i < 4; i++)
+    for (int i = 0; i < 4; ++i)
      fc[m * colsPerThread + 4 * n + i] =
          extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(i));
   };
@@ -3427,7 +3515,7 @@ private:
    Type smemPtrTy = helper.getShemPtrTy();
    for (int i = 0; i < numPtrs; ++i) {
      ptrs[i] =
-          bitcast(smemPtrTy, gep(smemPtrTy, llTensor, ValueRange({offs[i]})));
+          bitcast(gep(smemPtrTy, llTensor, ValueRange({offs[i]})), smemPtrTy);
    }
 
    auto [ha0, ha1, ha2, ha3] = loader.loadX4(
@@ -3492,7 +3580,7 @@ private:
 
    int offset{};
    ValueTable vals;
-    for (int i = 0; i < n0; i++) {
+    for (int i = 0; i < n0; ++i) {
      for (int j = 0; j < n1; j++) {
        vals[{2 * i, 2 * j}] = elems[offset++];
        vals[{2 * i, 2 * j + 1}] = elems[offset++];
@@ -3514,20 +3602,37 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
 
   auto dotOperandLayout =
       dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
 
   MmaEncodingAttr mmaLayout =
       dotOperandLayout.getParent().dyn_cast_or_null<MmaEncodingAttr>();
   assert(mmaLayout);
 
-  MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
-                                     rewriter, getTypeConverter(), op.getLoc());
-
   Value res;
-  if (dotOperandLayout.getOpIdx() == 0) {
-    // operand $a
-    res = mmaHelper.loadA(src, adaptor.src());
-  } else if (dotOperandLayout.getOpIdx() == 1) {
-    // operand $b
-    res = mmaHelper.loadB(src, adaptor.src());
+  if (mmaLayout.getVersion() == 2) {
+    MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
+                                       rewriter, getTypeConverter(),
+                                       op.getLoc());
+
+    if (dotOperandLayout.getOpIdx() == 0) {
+      // operand $a
+      res = mmaHelper.loadA(src, adaptor.src());
+    } else if (dotOperandLayout.getOpIdx() == 1) {
+      // operand $b
+      res = mmaHelper.loadB(src, adaptor.src());
+    }
+  } else if (mmaLayout.getVersion() == 1) {
+    DotOpMmaV1ConversionHelper helper(mmaLayout);
+    if (dotOperandLayout.getOpIdx() == 0) {
+      // operand $a
+      res = helper.loadA(src, adaptor.src(), getThreadId(rewriter, loc),
+                         adaptor.src(), loc, rewriter);
+    } else if (dotOperandLayout.getOpIdx() == 1) {
+      // operand $b
+      res = helper.loadB(src, adaptor.src(), getThreadId(rewriter, loc),
+                         adaptor.src(), loc, rewriter);
+    }
+  } else {
+    assert(false && "Unsupported mma layout found");
   }
 
   rewriter.replaceOp(op, res);
@@ -3571,6 +3676,424 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adaptor,
                                  adaptor);
 }
 
+// Simply port the old code here to avoid a large diff and to make debugging
+// and profiling easier.
+LogicalResult
+DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
+                               ConversionPatternRewriter &rewriter) const {
+  auto *ctx = op.getContext();
+  auto loc = op.getLoc();
+
+  Value A = op.a();
+  Value B = op.b();
+  Value D = op.getResult();
+  auto mmaLayout = D.getType()
+                       .cast<RankedTensorType>()
+                       .getEncoding()
+                       .cast<MmaEncodingAttr>();
+
+  auto ATensorTy = A.getType().cast<RankedTensorType>();
+  auto BTensorTy = B.getType().cast<RankedTensorType>();
+  auto DTensorTy = D.getType().cast<RankedTensorType>();
+  auto AShape = ATensorTy.getShape();
+  auto BShape = BTensorTy.getShape();
+  auto DShape = DTensorTy.getShape();
+  auto wpt = mmaLayout.getWarpsPerCTA();
+
+  bool transA = op.transA();
+  bool transB = op.transB();
+
+  bool isARow = !transA;
+  bool isBRow = !transB;
+  bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
+  bool isBVec4 = isBRow && BShape[isBRow] <= 16;
+  int packSize0 = (isARow || isAVec4) ? 1 : 2;
+  int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
+  SmallVector<int> fpw({2, 2, 1});
+  SmallVector<int> rep({2 * packSize0, 2 * packSize1, 1});
+  SmallVector<int> spw({fpw[0] * 4 * rep[0], fpw[1] * 4 * rep[1], 1});
+
+  Value loadedA = adaptor.a();
+  Value loadedB = adaptor.b();
+  Value loadedC = adaptor.c();
+  DotOpMmaV1ConversionHelper helper(mmaLayout);
+
+  unsigned numM = rep[0] * DShape[0] / (spw[0] * wpt[0]);
+  unsigned numN = rep[1] * DShape[1] / (spw[1] * wpt[0]);
+  unsigned NK = AShape[1];
+
+  auto has = helper.extractLoadedOperand(loadedA, numM / 2, NK, rewriter);
+  auto hbs = helper.extractLoadedOperand(loadedB, numN / 2, NK, rewriter);
+
+  size_t accSize = numM * numN;
+
+  // initialize accumulators
+  SmallVector<Value> acc = getElementsFromStruct(loc, loadedC, rewriter);
+
+  auto callMMA = [&](unsigned m, unsigned n, unsigned k) {
+    auto ha = has[{m, k}];
+    auto hb = hbs[{n, k}];
+    std::vector<size_t> idx{{
+        (m * 2 + 0) + (n * 4 + 0) * numM, // row0
+        (m * 2 + 0) + (n * 4 + 1) * numM,
+        (m * 2 + 1) + (n * 4 + 0) * numM, // row1
+        (m * 2 + 1) + (n * 4 + 1) * numM,
+        (m * 2 + 0) + (n * 4 + 2) * numM, // row2
+        (m * 2 + 0) + (n * 4 + 3) * numM,
+        (m * 2 + 1) + (n * 4 + 2) * numM, // row3
+        (m * 2 + 1) + (n * 4 + 3) * numM,
+    }};
+
+    PTXBuilder builder;
+
+    auto *resOprs = builder.newListOperand(8, "=f");
+    auto *AOprs = builder.newListOperand({
+        {ha.first, "f"},
+        {ha.second, "f"},
+    });
+
+    auto *BOprs = builder.newListOperand({
+        {hb.first, "f"},
+        {hb.second, "f"},
+    });
+    auto *COprs = builder.newListOperand();
+    for (int i = 0; i < 8; ++i)
+      COprs->listAppend(builder.newOperand(acc[idx[i]], std::to_string(i)));
+
+    auto mma = builder.create("mma.sync.aligned.m8n8k4")
+                   ->o(isARow ? "row" : "col")
+                   .o(isBRow ? "row" : "col")
+                   .o(".f32.f16.f16.f32");
+
+    mma(resOprs, AOprs, BOprs, COprs);
+
+    Value res = builder.launch(rewriter, loc, helper.getMmaRetType(ATensorTy));
+
+    auto getIntAttr = [&](int v) {
+      return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
+    };
+    for (unsigned i = 0; i < 8; i++)
+      acc[idx[i]] = extract_val(f32_ty, res, getIntAttr(i));
+  };
+
+  for (unsigned k = 0; k < NK; k += 4)
+    for (unsigned m = 0; m < numM / 2; ++m)
+      for (unsigned n = 0; n < numN / 2; ++n) {
+        callMMA(m, n, k);
+      }
+
+  // replace with new packed result
+  Type structTy = LLVM::LLVMStructType::getLiteral(
+      ctx, SmallVector<Type>(acc.size(), type::f32Ty(ctx)));
+  Value res = getStructFromElements(loc, acc, rewriter, structTy);
+  rewriter.replaceOp(op, res);
+
+  return success();
+}
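For reference, each callMMA invocation emits one m8n8k4 tensor-core instruction of roughly the following shape (register names illustrative; the row/col suffixes come from isARow/isBRow):

    mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32
        {%f0, %f1, %f2, %f3, %f4, %f5, %f6, %f7}, // 8 x f32 accumulators ("=f")
        {%ha0, %ha1},                             // $a: two f16x2 registers
        {%hb0, %hb1},                             // $b: two f16x2 registers
        {%f0, %f1, %f2, %f3, %f4, %f5, %f6, %f7}; // $c, tied to the outputs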
+
+Value DotOpMmaV1ConversionHelper::loadA(
+    Value tensor, Value llTensor, Value thread, Value smem, Location loc,
+    ConversionPatternRewriter &rewriter) const {
+  auto *ctx = rewriter.getContext();
+  auto tensorTy = tensor.getType().cast<RankedTensorType>();
+  auto shape = tensorTy.getShape();
+  auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
+  auto order = sharedLayout.getOrder();
+
+  bool isARow = order[0] != 0;
+  bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes
+  int packSize0 = (isARow || isAVec4) ? 1 : 2;
+
+  SmallVector<int> fpw({2, 2, 1});
+  int repM = 2 * packSize0;
+  int repK = 1;
+  int spwM = fpw[0] * 4 * repM;
+  SmallVector<int> rep({repM, 0, repK}); // pad N with 0
+  SmallVector<int> spw({spwM, 0, 1});    // pad N with 0
+
+  int vecA = sharedLayout.getVec();
+
+  int strideAM = isARow ? shape[1] : 1;
+  int strideAK = isARow ? 1 : shape[0];
+  int strideA0 = isARow ? strideAK : strideAM;
+  int strideA1 = isARow ? strideAM : strideAK;
+
+  int strideRepM = wpt[0] * fpw[0] * 8;
+  int strideRepK = 1;
+
+  auto [offsetAM, offsetAK, _0, _1] =
+      computeOffsets(thread, isARow, false, fpw, spw, rep, rewriter, loc);
+
+  // swizzling
+  int perPhaseA = sharedLayout.getPerPhase();
+  int maxPhaseA = sharedLayout.getMaxPhase();
+  int stepA0 = isARow ? strideRepK : strideRepM;
+  int numPtrA = std::max(2 * perPhaseA * maxPhaseA / stepA0, 1);
+  int NK = shape[1];
+
+  // pre-compute pointer lanes
+  Value offA0 = isARow ? offsetAK : offsetAM;
+  Value offA1 = isARow ? offsetAM : offsetAK;
+  Value phaseA = urem(udiv(offA1, i32_val(perPhaseA)), i32_val(maxPhaseA));
+  SmallVector<Value> offA(numPtrA);
+
+  for (int i = 0; i < numPtrA; i++) {
+    Value offA0I = add(offA0, i32_val(i * (isARow ? 4 : strideRepM)));
+    offA0I = udiv(offA0I, i32_val(vecA));
+    offA0I = xor_(offA0I, phaseA);
+    offA0I = xor_(offA0I, i32_val(vecA));
+    offA[i] =
+        add(mul(offA0I, i32_val(strideA0)), mul(offA1, i32_val(strideA1)));
+  }
+
+  Type f16x2Ty = vec_ty(f16_ty, 2);
+  // Each thread gets 8 elements as result
+  Type retTy =
+      LLVM::LLVMStructType::getLiteral(ctx, SmallVector(8, type::f32Ty(ctx)));
+
+  // prepare arguments
+  SmallVector<Value> ptrA(numPtrA);
+
+  std::map<std::pair<int, int>, std::pair<Value, Value>> has;
+  for (int i = 0; i < numPtrA; i++)
+    ptrA[i] = gep(ptr_ty(f16_ty), smem, offA[i]);
+
+  auto instrShape = getMmaInstrShape();
+  unsigned numM = rep[0] * shape[0] / (spw[0] * wpt[0]);
+
+  Type f16PtrTy = ptr_ty(f16_ty);
+
+  auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) {
+    vals[{m, k}] = {val0, val1};
+  };
+  auto loadA = [&](int m, int k) {
+    int offidx = (isARow ? k / 4 : m) % numPtrA;
+    Value thePtrA = gep(f16PtrTy, smem, offA[offidx]);
+
+    int stepAM = isARow ? m : m / numPtrA * numPtrA;
+    int stepAK = isARow ? k / (numPtrA * vecA) * (numPtrA * vecA) : k;
+    Value pa = gep(f16PtrTy, thePtrA,
+                   i32_val(stepAM * strideRepM * strideAM + stepAK * strideAK));
+    Type aPtrTy = ptr_ty(vec_ty(i32_ty, std::max<int>(vecA / 2, 1)), 3);
+    Value ha = load(bitcast(pa, aPtrTy));
+    // record lds that need to be moved
+    Value ha00 = bitcast(extract_element(ha, i32_val(0)), f16x2Ty);
+    Value ha01 = bitcast(extract_element(ha, i32_val(1)), f16x2Ty);
+    ld(has, m, k, ha00, ha01);
+
+    if (vecA > 4) {
+      Value ha10 = bitcast(extract_element(ha, i32_val(2)), f16x2Ty);
+      Value ha11 = bitcast(extract_element(ha, i32_val(3)), f16x2Ty);
+      if (isARow)
+        ld(has, m, k + 4, ha10, ha11);
+      else
+        ld(has, m + 1, k, ha10, ha11);
+    }
+  };
+
+  for (unsigned k = 0; k < NK; k += 4)
+    for (unsigned m = 0; m < numM / 2; ++m)
+      if (!has.count({m, k}))
+        loadA(m, k);
+
+  SmallVector<Value> elems;
+  elems.reserve(has.size() * 2);
+  auto vecTy = vec_ty(f16_ty, 2);
+  for (auto item : has) { // `has` is a std::map, so the keys are ordered.
+    elems.push_back(item.second.first);
+    elems.push_back(item.second.second);
+  }
+
+  Type resTy = struct_ty(SmallVector<Type>(elems.size(), f16x2Ty));
+  Value res = getStructFromElements(loc, elems, rewriter, resTy);
+  return res;
+}
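The struct loadA returns is flat: two f16x2 vectors per (m, k) key of has, in std::map key order. For the 128x32 $a of the lit test that would be the 64 entries computed in the numElemsPerThreadA walk-through above, i.e. roughly:

    ; illustrative result type, entry count taken from the earlier example
    !llvm.struct<(vector<2xf16>, vector<2xf16>, ..., vector<2xf16>)> ; 64 entries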
+
+Value DotOpMmaV1ConversionHelper::loadB(
+    Value tensor, Value llTensor, Value thread, Value smem, Location loc,
+    ConversionPatternRewriter &rewriter) const {
+  auto *ctx = rewriter.getContext();
+  auto tensorTy = tensor.getType().cast<RankedTensorType>();
+  auto shape = tensorTy.getShape();
+  auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
+  auto order = sharedLayout.getOrder();
+  bool isBRow = order[0] != 0;
+  bool isBVec4 = isBRow && shape[order[0]] <= 16;
+  int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
+  SmallVector<int> fpw({2, 2, 1});
+  SmallVector<int> rep({0, 2 * packSize1, 1});       // pad M with 0
+  SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
+  int vecB = sharedLayout.getVec();
+  int strideBN = isBRow ? 1 : shape[0];
+  int strideBK = isBRow ? shape[1] : 1;
+  int strideB0 = isBRow ? strideBN : strideBK;
+  int strideB1 = isBRow ? strideBK : strideBN;
+  int strideRepN = wpt[1] * fpw[1] * 8;
+  int strideRepK = 1;
+
+  // swizzling
+  int perPhaseA = sharedLayout.getPerPhase();
+  int maxPhaseA = sharedLayout.getMaxPhase();
+  int perPhaseB = sharedLayout.getPerPhase();
+  int maxPhaseB = sharedLayout.getMaxPhase();
+  int stepB0 = isBRow ? strideRepN : strideRepK;
+  int numPtrB = std::max(2 * perPhaseB * maxPhaseB / stepB0, 1);
+  int NK = shape[0];
+
+  auto [_0, _1, offsetBN, offsetBK] =
+      computeOffsets(thread, false, isBRow, fpw, spw, rep, rewriter, loc);
+
+  Value offB0 = isBRow ? offsetBN : offsetBK;
+  Value offB1 = isBRow ? offsetBK : offsetBN;
+  Value phaseB = urem(udiv(offB1, i32_val(perPhaseB)), i32_val(maxPhaseB));
+  SmallVector<Value> offB(numPtrB);
+  for (int i = 0; i < numPtrB; ++i) {
+    Value offB0I = add(offB0, i32_val(i * (isBRow ? strideRepN : 4)));
+    offB0I = udiv(offB0I, i32_val(vecB));
+    offB0I = xor_(offB0I, phaseB);
+    offB0I = mul(offB0I, i32_val(vecB));
+    offB[i] =
+        add(mul(offB0I, i32_val(strideB0)), mul(offB1, i32_val(strideB1)));
+  }
+
+  Type f16PtrTy = ptr_ty(f16_ty);
+  Type f16x2Ty = vec_ty(f16_ty, 2);
+
+  SmallVector<Value> ptrB(numPtrB);
+  ValueTable hbs;
+  for (int i = 0; i < numPtrB; ++i)
+    ptrB[i] = gep(ptr_ty(f16_ty), smem, offB[i]);
+
+  auto ld = [&](decltype(hbs) &vals, int m, int k, Value val0, Value val1) {
+    vals[{m, k}] = {val0, val1};
+  };
+
+  auto loadB = [&](int n, int K) {
+    int offidx = (isBRow ? n : K / 4) % numPtrB;
+    Value thePtrB = ptrB[offidx];
+
+    int stepBN = isBRow ? n / numPtrB * numPtrB : n;
+    int stepBK = isBRow ? K : K / (numPtrB * vecB) * (numPtrB * vecB);
+    Value pb = gep(f16PtrTy, thePtrB,
+                   i32_val(stepBN * strideRepN * strideBN + stepBK * strideBK));
+    Value hb =
+        load(bitcast(pb, ptr_ty(vec_ty(i32_ty, std::max(vecB / 2, 1)), 3)));
+    // record lds that need to be moved
+    Value hb00 = bitcast(extract_element(hb, i32_val(0)), f16x2Ty);
+    Value hb01 = bitcast(extract_element(hb, i32_val(1)), f16x2Ty);
+    ld(hbs, n, K, hb00, hb01);
+    if (vecB > 4) {
+      Value hb10 = bitcast(extract_element(hb, i32_val(2)), f16x2Ty);
+      Value hb11 = bitcast(extract_element(hb, i32_val(3)), f16x2Ty);
+      if (isBRow)
+        ld(hbs, n + 1, K, hb10, hb11);
+      else
+        ld(hbs, n, K + 4, hb10, hb11);
+    }
+  };
+
+  unsigned numN = rep[1] * shape[1] / (spw[1] * wpt[0]);
+  for (unsigned k = 0; k < NK; k += 4)
+    for (unsigned n = 0; n < numN / 2; ++n) {
+      if (!hbs.count({n, k}))
+        loadB(n, k);
+    }
+
+  SmallVector<Value> elems;
+  for (auto &item : hbs) { // `hbs` is a std::map, so the keys are ordered.
+    elems.push_back(item.second.first);
+    elems.push_back(item.second.second);
+  }
+  Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2);
+  Type resTy = struct_ty(SmallVector<Type>(elems.size(), fp16x2Ty));
+  Value res = getStructFromElements(loc, elems, rewriter, resTy);
+  return res;
+}
+
+Value DotOpMmaV1ConversionHelper::loadC(
+    Value tensor, Value llTensor, ConversionPatternRewriter &rewriter) const {
+  return llTensor;
+}
+
+std::tuple<Value, Value, Value, Value>
+DotOpMmaV1ConversionHelper::computeOffsets(Value threadId, bool isARow,
+                                           bool isBRow, ArrayRef<int> fpw,
+                                           ArrayRef<int> spw, ArrayRef<int> rep,
+                                           ConversionPatternRewriter &rewriter,
+                                           Location loc) const {
+  auto *ctx = rewriter.getContext();
+  Value _1 = i32_val(1);
+  Value _3 = i32_val(3);
+  Value _4 = i32_val(4);
+  Value _16 = i32_val(16);
+  Value _32 = i32_val(32);
+
+  Value lane = urem(threadId, _32);
+  Value warp = udiv(threadId, _32);
+
+  // warp offset
+  Value warp0 = urem(warp, i32_val(wpt[0]));
+  Value warp12 = udiv(warp, i32_val(wpt[0]));
+  Value warp1 = urem(warp12, i32_val(wpt[1]));
+  Value warpMOff = mul(warp0, i32_val(spw[0]));
+  Value warpNOff = mul(warp1, i32_val(spw[1]));
+  // Quad offset
+  Value quadMOff = mul(udiv(and_(lane, _16), _4), i32_val(fpw[0]));
+  Value quadNOff = mul(udiv(and_(lane, _16), _4), i32_val(fpw[1]));
+  // Pair offset
+  Value pairMOff = udiv(urem(lane, _16), _4);
+  pairMOff = urem(pairMOff, i32_val(fpw[0]));
+  pairMOff = mul(pairMOff, _4);
+  Value pairNOff = udiv(urem(lane, _16), _4);
+  pairNOff = udiv(pairNOff, i32_val(fpw[0]));
+  pairNOff = urem(pairNOff, i32_val(fpw[1]));
+  pairNOff = mul(pairNOff, _4);
+  // scale
+  pairMOff = mul(pairMOff, i32_val(rep[0] / 2));
+  quadMOff = mul(quadMOff, i32_val(rep[0] / 2));
+  pairNOff = mul(pairNOff, i32_val(rep[1] / 2));
+  quadNOff = mul(quadNOff, i32_val(rep[1] / 2));
+  // Quad pair offset
+  Value laneMOff = add(pairMOff, quadMOff);
+  Value laneNOff = add(pairNOff, quadNOff);
+  // A offset
+  Value offsetAM = add(warpMOff, laneMOff);
+  Value offsetAK = and_(lane, _3);
+  // B offset
+  Value offsetBN = add(warpNOff, laneNOff);
+  Value offsetBK = and_(lane, _3);
+  // i indices
+  Value offsetCM = add(and_(lane, _1), offsetAM);
+  if (isARow) {
+    offsetAM = add(offsetAM, urem(threadId, _4));
+    offsetAK = i32_val(0);
+  }
+  if (!isBRow) {
+    offsetBN = add(offsetBN, urem(threadId, _4));
+    offsetBK = i32_val(0);
+  }
+
+  return std::make_tuple(offsetAM, offsetAK, offsetBN, offsetBK);
+}
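A worked example of computeOffsets, assuming the loadA call for the lit-test $a (shared order [1, 0], so isARow = true) with wpt = {2, 2}, fpw = {2, 2, 1}, rep = {2, 0, 1}, spw = {16, 0, 1}, and threadId = 5:

    // lane = 5 % 32 = 5; warp = 5 / 32 = 0, so warpMOff = warpNOff = 0
    // quadMOff = ((5 & 16) / 4) * fpw[0] = 0
    // pairMOff = (((5 % 16) / 4) % fpw[0]) * 4 = 4, scaled by rep[0] / 2 = 1
    // offsetAM = warpMOff + quadMOff + pairMOff = 4
    // offsetAK = 5 & 3 = 1; then the isARow branch:
    //   offsetAM += 5 % 4  => 5, and offsetAK = 0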
+
+DotOpMmaV1ConversionHelper::ValueTable
+DotOpMmaV1ConversionHelper::extractLoadedOperand(
+    Value llStruct, int n0, int n1, ConversionPatternRewriter &rewriter) const {
+  ValueTable rcds;
+  SmallVector<Value> elems =
+      ConvertTritonGPUOpToLLVMPatternBase::getElementsFromStruct(
+          llStruct.getLoc(), llStruct, rewriter);
+
+  int offset = 0;
+  for (int i = 0; i < n0; ++i)
+    for (int k = 0; k < n1; k += 4) {
+      rcds[{i, k}] = std::make_pair(elems[offset], elems[offset + 1]);
+      offset += 2;
+    }
+
+  return rcds;
+}
+
 /// ====================== mma codegen end ============================
 
 Value convertSplatLikeOpWithMmaLayout(const MmaEncodingAttr &layout,
@@ -3579,9 +4102,10 @@ Value convertSplatLikeOpWithMmaLayout(const MmaEncodingAttr &layout,
                                       TypeConverter *typeConverter,
                                       ConversionPatternRewriter &rewriter,
                                       Location loc) {
+  auto tensorTy = resType.cast<RankedTensorType>();
+  auto shape = tensorTy.getShape();
   if (layout.getVersion() == 2) {
-    auto tensorTy = resType.cast<RankedTensorType>();
-    auto [repM, repN] = DotOpConversionHelper::getRepMN(tensorTy);
+    auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(tensorTy);
     size_t fcSize = 4 * repM * repN;
 
     auto structTy = LLVM::LLVMStructType::getLiteral(
@@ -3589,6 +4113,18 @@ Value convertSplatLikeOpWithMmaLayout(const MmaEncodingAttr &layout,
     return getStructFromElements(loc, SmallVector<Value>(fcSize, constVal),
                                  rewriter, structTy);
   }
+  if (layout.getVersion() == 1) {
+    DotOpMmaV1ConversionHelper helper(layout);
+    int repM = helper.getRepM(shape[0]);
+    int repN = helper.getRepN(shape[1]);
+    // According to the v1 mma layout, each thread processes 8 elements.
+    int elems = 8 * repM * repN;
+
+    auto structTy = LLVM::LLVMStructType::getLiteral(
+        rewriter.getContext(), SmallVector<Type>(elems, elemType));
+    return getStructFromElements(loc, SmallVector<Value>(elems, constVal),
+                                 rewriter, structTy);
+  }
 
   assert(false && "Unsupported mma layout found");
 }
@@ -3620,6 +4156,7 @@ public:
   llvm::Optional<Type> convertTritonTensorType(RankedTensorType type) {
     auto ctx = type.getContext();
     Attribute layout = type.getEncoding();
+    auto shape = type.getShape();
     if (layout &&
         (layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
          layout.isa<MmaEncodingAttr>())) {
@@ -3632,12 +4169,21 @@ public:
       return LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);
     } else if (auto mmaLayout = layout.dyn_cast_or_null<MmaEncodingAttr>()) {
+      if (mmaLayout.getVersion() == 2) {
-      auto [repM, repN] = DotOpConversionHelper::getRepMN(type);
+        auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(type);
        size_t fcSize = 4 * repM * repN;
        return LLVM::LLVMStructType::getLiteral(
            ctx, SmallVector<Type>(fcSize, type.getElementType()));
+      }
+
+      if (mmaLayout.getVersion() == 1) {
+        DotOpMmaV1ConversionHelper helper(mmaLayout);
+        int repM = helper.getRepM(shape[0]);
+        int repN = helper.getRepN(shape[1]);
+        int elems = 8 * repM * repN;
+        return LLVM::LLVMStructType::getLiteral(
+            ctx, SmallVector<Type>(elems, type.getElementType()));
+      }
 
       llvm::errs()
           << "Unexpected mma layout detected in TritonToLLVMTypeConverter";
       return llvm::None;
@@ -3645,9 +4191,9 @@ public:
     } else if (auto dot_op_layout =
                    layout.dyn_cast_or_null<DotOperandEncodingAttr>()) {
       auto mmaLayout = dot_op_layout.getParent().cast<MmaEncodingAttr>();
-      auto wpt = mmaLayout.getWarpsPerCTA();
-      Type elemTy = type.getElementType();
+      if (mmaLayout.getVersion() == 2) {
+        auto wpt = mmaLayout.getWarpsPerCTA();
+        Type elemTy = type.getElementType();
 
       if (dot_op_layout.getOpIdx() == 0) { // $a
         int elems =
@@ -3660,8 +4206,22 @@ public:
         int elems =
             MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt);
         Type x2Ty = vec_ty(elemTy, 2);
-        return LLVM::LLVMStructType::getLiteral(
-            ctx, SmallVector<Type>(elems, x2Ty));
+        return struct_ty(SmallVector<Type>(elems, x2Ty));
       }
     }
+
+    if (mmaLayout.getVersion() == 1) {
+      DotOpMmaV1ConversionHelper helper(mmaLayout);
+
+      if (dot_op_layout.getOpIdx() == 0) { // $a
+        int elems = helper.numElemsPerThreadA(type);
+        Type x2Ty = vec_ty(elemTy, 2);
+        return struct_ty(SmallVector<Type>(elems, x2Ty));
+      }
+      if (dot_op_layout.getOpIdx() == 1) { // $b
+        int elems = helper.numElemsPerThreadB(type);
+        Type x2Ty = vec_ty(elemTy, 2);
+        return struct_ty(SmallVector<Type>(elems, x2Ty));
+      }
+    }
+

@@ -270,10 +270,23 @@ unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
 unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
   size_t rank = shape.size();
   assert(rank == 2 && "Unexpected rank of mma layout");
-  assert(getVersion() == 2 && "mmaLayout version = 1 is not implemented yet");
-  unsigned elemsCol = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
-  unsigned elemsRow = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
-  return elemsCol * elemsRow;
+  assert((getVersion() == 1 || getVersion() == 2) &&
+         "Only version 1 and 2 is supported");
+
+  int res = 0;
+  if (getVersion() == 1) {
+    unsigned mmasRow = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]);
+    unsigned mmasCol = ceil<unsigned>(shape[1], 16 * getWarpsPerCTA()[1]);
+    // Each warp-level mma884 performs an m16xn16xk4 mma, thus yielding an
+    // m16xn16 matrix as result.
+    res = mmasRow * mmasCol * (16 * 16 / 32);
+  } else if (getVersion() == 2) {
+    unsigned elemsCol = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
+    unsigned elemsRow = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
+    res = elemsCol * elemsRow;
+  }
+
+  return res;
 }
 
 unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
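For the 128x256 #mma accumulator of the lit test below (warpsPerCTA = [2, 2]), the new version-1 branch of getElemsPerThread gives:

    // mmasRow = ceil(128, 16 * 2) = 4; mmasCol = ceil(256, 16 * 2) = 8
    // res = 4 * 8 * (16 * 16 / 32) = 256 elements per thread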
@@ -738,3 +738,28 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
     return
   }
 }
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
+#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#mma = #triton_gpu.mma<{version = 1, warpsPerCTA = [2, 2]}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
+  %a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
+    // CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16
+    %a_mat = triton_gpu.convert_layout %a : (tensor<128x32xf16, #shared>) -> tensor<128x32xf16, #dot_operand_a>
+    %b_mat = triton_gpu.convert_layout %b : (tensor<32x256xf16, #shared>) -> tensor<32x256xf16, #dot_operand_b>
+
+    %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #mma>
+    // TODO[goostavz]: uncomment the following lines after convert_layout[mma<v1> -> blocked] is ready.
+    // %38 = triton_gpu.convert_layout %28 : (tensor<128x256xf32, #mma>) -> tensor<128x256xf32, #blocked>
+    // %30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>, #blocked>
+    // %36 = tt.broadcast %30 : (tensor<128x1x!tt.ptr<f32>, #blocked>) -> tensor<128x256x!tt.ptr<f32>, #blocked>
+    // tt.store %36, %38 : tensor<128x256xf32, #blocked>
+    return
+  }
+}