[Triton-MLIR][BACKEND] Fix wpt overflow issue in mma v2 (#904)
This PR: 1. Fixes the wpt (warpsPerCTA) overflow issue in mma v2. 2. Refines the transpose logic.
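The core of the wpt fix is visible in the hunks below: each per-axis warp index is reduced modulo the number of MMA-tile repetitions along that axis, so a CTA configured with more warps than tiles wraps around instead of producing out-of-range offsets. The following is a small standalone sketch of that idea using plain integers instead of the LLVM-IR values (urem/udiv) built by the rewriter; the helper names and the 16x8 tile constants are illustrative assumptions, not the pass's actual API.

#include <algorithm>
#include <cstdio>

// Sketch of the warp-id clamping behind the fix: a warp's index along each
// tile axis is wrapped by the number of MMA-tile repetitions, so layouts
// where warpsPerCTA[i] exceeds shape[i] / instrShape[i] no longer index
// past the tile.
struct WarpCoords {
  int warpM;
  int warpN;
};

WarpCoords computeWarpCoords(int warpId, const int warpsPerCTA[2],
                             const int shape[2], const int instrShape[2]) {
  int warpM = warpId % warpsPerCTA[0];              // decompose linear warp id
  int warpN = (warpId / warpsPerCTA[0]) % warpsPerCTA[1];
  int repM = std::max(shape[0] / instrShape[0], 1); // tile repetitions on M
  int repN = std::max(shape[1] / instrShape[1], 1); // tile repetitions on N
  return {warpM % repM, warpN % repN};              // clamp by wrapping
}

int main() {
  int warpsPerCTA[2] = {4, 2};
  int shape[2] = {32, 16};     // small tile: only 2x2 mma tiles of 16x8
  int instrShape[2] = {16, 8};
  for (int w = 0; w < 8; ++w) {
    WarpCoords c = computeWarpCoords(w, warpsPerCTA, shape, instrShape);
    std::printf("warp %d -> (m=%d, n=%d)\n", w, c.warpM, c.warpN);
  }
  return 0;
}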
@@ -2239,16 +2239,16 @@ struct AllocTensorOpConversion
     smemBase = bitcast(smemBase, elemPtrTy);
     auto order = resultTy.getEncoding().cast<SharedEncodingAttr>().getOrder();
     // workaround for 3D tensors
-    // TODO: We need to modify the pipeline pass to give a proper shared encoding to 3D tensors
+    // TODO: We need to modify the pipeline pass to give a proper shared
+    // encoding to 3D tensors
     SmallVector<unsigned> newOrder;
     if (resultTy.getShape().size() == 3)
       newOrder = {1 + order[0], 1 + order[1], 0};
     else
       newOrder = SmallVector<unsigned>(order.begin(), order.end());

-    auto smemObj =
-        SharedMemoryObject(smemBase, resultTy.getShape(), newOrder, loc, rewriter);
+    auto smemObj = SharedMemoryObject(smemBase, resultTy.getShape(), newOrder,
+                                      loc, rewriter);
     auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
     rewriter.replaceOp(op, retVal);
     return success();
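The 3D workaround in the hunk above lifts a 2D shared-encoding order to 3D: both original axes shift up by one and the new leading dimension (the pipeline stage) is appended as the slowest-varying axis. Below is a minimal sketch of that mapping, assuming, as in Triton's layout attributes, that order[0] names the most-contiguous axis; the function name is illustrative.

#include <cassert>
#include <cstdio>
#include <vector>

// Lift the order of a 2D shared layout to the corresponding 3D tensor whose
// new axis 0 (the pipeline stage) varies slowest in memory.
std::vector<unsigned> lift2DOrderTo3D(const std::vector<unsigned> &order) {
  assert(order.size() == 2 && "expects a 2D order");
  // Original axes shift up by one; the new axis 0 is appended last, i.e. it
  // is the least-contiguous dimension.
  return {1 + order[0], 1 + order[1], 0u};
}

int main() {
  std::vector<unsigned> order = {1, 0}; // a row-major 2D tile
  auto newOrder = lift2DOrderTo3D(order);
  std::printf("newOrder = {%u, %u, %u}\n", newOrder[0], newOrder[1],
              newOrder[2]);
  return 0;
}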
@@ -2882,6 +2882,8 @@ private:
     SmallVector<Value> multiDimWarpId(2);
     multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
     multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
+    multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16));
+    multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 8));
     Value four = idx_val(4);
     Value mmaGrpId = udiv(laneId, four);
     Value mmaGrpIdP8 = add(mmaGrpId, idx_val(8));
@@ -3661,7 +3663,7 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {

     // Here we assume the DotOp's operands always comes from shared memory.
     auto AShape = A.getType().cast<RankedTensorType>().getShape();
-    size_t reduceAxis = 1;
+    size_t reduceAxis = op.transA() ? 0 : 1;
     unsigned K = AShape[reduceAxis];
     bool isOuter = K == 1;

@@ -4124,8 +4126,9 @@ private:
 struct MMA16816ConversionHelper {
   MmaEncodingAttr mmaLayout;
   ArrayRef<unsigned int> wpt;
+  SmallVector<unsigned int> properWpt;

-  Value thread, lane, warp, warpMN, warpN, warpM;
+  Value thread, lane, warp;

   DotOpMmaV2ConversionHelper helper;
   ConversionPatternRewriter &rewriter;
@@ -4135,23 +4138,34 @@ struct MMA16816ConversionHelper {

   using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;

-  MMA16816ConversionHelper(MmaEncodingAttr mmaLayout, Value thread,
-                           ConversionPatternRewriter &rewriter,
+  // dotOperand: type of either one operand of dotOp.
+  MMA16816ConversionHelper(Type dotOperand, MmaEncodingAttr mmaLayout,
+                           Value thread, ConversionPatternRewriter &rewriter,
                            TypeConverter *typeConverter, Location loc)
       : mmaLayout(mmaLayout), thread(thread), helper(mmaLayout),
         rewriter(rewriter), typeConverter(typeConverter), loc(loc),
-        ctx(mmaLayout.getContext()) {
-    wpt = mmaLayout.getWarpsPerCTA();
+        ctx(mmaLayout.getContext()), wpt(mmaLayout.getWarpsPerCTA()) {
+    helper.deduceMmaType(dotOperand);

     Value _32 = i32_val(32);
     lane = urem(thread, _32);
     warp = udiv(thread, _32);
-    warpMN = udiv(warp, i32_val(wpt[0]));
-    warpM = urem(warp, i32_val(wpt[0]));
-    warpN = urem(warpMN, i32_val(wpt[1]));
   }

-  // Get the mmaInstrShape from either $a or $b.
+  // Get a warpId for M axis.
+  Value getWarpM(int M) const {
+    auto matShape = helper.getMmaMatShape();
+    return urem(urem(warp, i32_val(wpt[0])), i32_val(M / matShape[0]));
+  }
+
+  // Get a warpId for N axis.
+  Value getWarpN(int N) const {
+    auto matShape = helper.getMmaMatShape();
+    Value warpMN = udiv(warp, i32_val(wpt[0]));
+    return urem(urem(warpMN, i32_val(wpt[1])), i32_val(N / matShape[1]));
+  }
+
+  // Get the mmaInstrShape deducing either from $a or $b.
   std::tuple<int, int, int> getMmaInstrShape(Type operand) const {
     helper.deduceMmaType(operand);
     auto mmaInstrShape = helper.getMmaInstrShape();
@@ -4161,6 +4175,7 @@ struct MMA16816ConversionHelper {
     return std::make_tuple(mmaInstrM, mmaInstrN, mmaInstrK);
   }

+  // Get the mmaMatShape deducing either from $a or $b.
   std::tuple<int, int, int> getMmaMatShape(Type operand) const {
     helper.deduceMmaType(operand);
     auto matShape = helper.getMmaMatShape();
@@ -4210,28 +4225,28 @@
   }

   // Get number of elements per thread for $a operand.
-  static size_t getANumElemsPerThread(RankedTensorType operand,
-                                      ArrayRef<unsigned> wpt) {
+  static size_t getANumElemsPerThread(RankedTensorType operand, int wpt) {
     auto shape = operand.getShape();
-    int repM = getNumRepM(operand, shape[0], wpt[0]);
+    int repM = getNumRepM(operand, shape[0], wpt);
     int repK = getNumRepK_(operand, shape[1]);
     return 4 * repM * repK;
   }

   // Get number of elements per thread for $b operand.
-  static size_t getBNumElemsPerThread(RankedTensorType operand,
-                                      ArrayRef<unsigned> wpt) {
+  static size_t getBNumElemsPerThread(RankedTensorType operand, int wpt) {
     auto shape = operand.getShape();
     int repK = getNumRepK_(operand, shape[0]);
-    int repN = getNumRepN(operand, shape[1], wpt[1]);
+    int repN = getNumRepN(operand, shape[1], wpt);
     return 4 * std::max(repN / 2, 1) * repK;
   }

   // Loading $a from smem to registers, returns a LLVM::Struct.
   Value loadA(Value tensor, const SharedMemoryObject &smemObj) const {
     auto aTensorTy = tensor.getType().cast<RankedTensorType>();
-    auto shape = aTensorTy.getShape();
     auto layout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();

+    SmallVector<int64_t> shape(aTensorTy.getShape().begin(),
+                               aTensorTy.getShape().end());
     ValueTable ha;
     std::function<void(int, int)> loadFn;
     auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(aTensorTy);
@@ -4241,6 +4256,7 @@
     int numRepK = getNumRepK(aTensorTy, shape[1]);

     if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
+      Value warpM = getWarpM(shape[0]);
       // load from smem
       loadFn = getLoadMatrixFn(
           tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
@@ -4268,12 +4284,17 @@
   Value loadB(Value tensor, const SharedMemoryObject &smemObj) {
     ValueTable hb;
     auto tensorTy = tensor.getType().cast<RankedTensorType>();
-    auto shape = tensorTy.getShape();
     auto layout = tensorTy.getEncoding().cast<SharedEncodingAttr>();

+    SmallVector<int64_t> shape(tensorTy.getShape().begin(),
+                               tensorTy.getShape().end());
+
     auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(tensorTy);
     auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(tensorTy);
     int numRepK = getNumRepK(tensorTy, shape[0]);
     int numRepN = getNumRepN(tensorTy, shape[1]);

+    Value warpN = getWarpN(shape[1]);
     auto loadFn = getLoadMatrixFn(
         tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
         0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/,
@@ -4319,7 +4340,11 @@
     auto aTensorTy = a.getType().cast<RankedTensorType>();
     auto dTensorTy = d.getType().cast<RankedTensorType>();

-    auto aShape = aTensorTy.getShape();
+    SmallVector<int64_t> aShape(aTensorTy.getShape().begin(),
+                                aTensorTy.getShape().end());
+    if (op.transA())
+      std::swap(aShape[0], aShape[1]);
+
     auto dShape = dTensorTy.getShape();

     // shape / shape_per_cta
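The transpose refinement spans two hunks: the @@ -3661 hunk picks $a's reduction (K) axis from op.transA(), and the hunk above copies $a's shape into a mutable vector and swaps it so the rest of the lowering always sees (M, K). A standalone sketch of that bookkeeping with plain integers; the struct and function names are illustrative, not part of the pass.

#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <utility>

// Describe the $a operand: which stored axis is reduced (K) and what the
// logical (M, K) extents are once a transposed tile is viewed as M x K.
struct AOperandInfo {
  int M;
  int K;
  size_t reduceAxis; // axis of the *stored* shape that holds K
};

AOperandInfo describeA(std::array<int64_t, 2> aShape, bool transA) {
  size_t reduceAxis = transA ? 0 : 1; // transposed $a stores K first
  if (transA)
    std::swap(aShape[0], aShape[1]); // view the stored K x M tile as M x K
  return {static_cast<int>(aShape[0]), static_cast<int>(aShape[1]),
          reduceAxis};
}

int main() {
  AOperandInfo plain = describeA({128, 64}, /*transA=*/false);
  AOperandInfo trans = describeA({64, 128}, /*transA=*/true);
  std::printf("plain: M=%d K=%d reduceAxis=%zu\n", plain.M, plain.K,
              plain.reduceAxis);
  std::printf("trans: M=%d K=%d reduceAxis=%zu\n", trans.M, trans.K,
              trans.reduceAxis);
  return 0;
}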
@@ -4602,9 +4627,9 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
   Value res;

   if (!isOuter && mmaLayout.getVersion() == 2 && isHMMA) { // tensor core v2
-    MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
-                                       rewriter, getTypeConverter(),
-                                       op.getLoc());
+    MMA16816ConversionHelper mmaHelper(src.getType(), mmaLayout,
+                                       getThreadId(rewriter, loc), rewriter,
+                                       getTypeConverter(), op.getLoc());

     if (dotOperandLayout.getOpIdx() == 0) {
       // operand $a
@@ -4695,12 +4720,15 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adaptor,
                        .cast<RankedTensorType>()
                        .getEncoding()
                        .cast<MmaEncodingAttr>();
-  MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
-                                     rewriter, getTypeConverter(), loc);

   Value A = op.a();
   Value B = op.b();
   Value C = op.c();

+  MMA16816ConversionHelper mmaHelper(A.getType(), mmaLayout,
+                                     getThreadId(rewriter, loc), rewriter,
+                                     getTypeConverter(), loc);
+
   auto ATensorTy = A.getType().cast<RankedTensorType>();
   auto BTensorTy = B.getType().cast<RankedTensorType>();

@@ -5532,13 +5560,13 @@ public:
     if (mmaLayout.getVersion() == 2) {
       if (dotOpLayout.getOpIdx() == 0) { // $a
         int elems =
-            MMA16816ConversionHelper::getANumElemsPerThread(type, wpt);
+            MMA16816ConversionHelper::getANumElemsPerThread(type, wpt[0]);
         return LLVM::LLVMStructType::getLiteral(
             ctx, SmallVector<Type>(elems, vecTy));
       }
       if (dotOpLayout.getOpIdx() == 1) { // $b
         int elems =
-            MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt);
+            MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt[1]);
         return struct_ty(SmallVector<Type>(elems, vecTy));
       }
     }
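The call-site change above matches the new getANumElemsPerThread/getBNumElemsPerThread signatures: the type converter now passes a single per-axis warp count (wpt[0] for $a, wpt[1] for $b) instead of the whole warpsPerCTA array. The sketch below reproduces the element-count arithmetic with plain integers; the rep helpers are approximations of getNumRepM/getNumRepN/getNumRepK_ under the assumption of fp16 operands and an m16n8k16 mma tile, while the 4 * ... expressions mirror the helper bodies shown in the @@ -4210 hunk.

#include <algorithm>
#include <cstddef>
#include <cstdio>

// Assumed approximations of the rep helpers (fp16, m16n8k16 tile).
static int numRepM(int shapeM, int wptM) {
  return std::max(shapeM / (16 * wptM), 1); // instr M tile of 16 (assumed)
}
static int numRepN(int shapeN, int wptN) {
  return std::max(shapeN / (8 * wptN), 1); // instr N tile of 8 (assumed)
}
static int numRepK(int shapeK) {
  return std::max(shapeK / 16, 1); // instr K tile of 16 (assumed)
}

// $a: takes one warp count (warps along M), not the whole array.
static size_t getANumElemsPerThread(int shapeM, int shapeK, int wptM) {
  return 4 * numRepM(shapeM, wptM) * numRepK(shapeK);
}

// $b: takes one warp count (warps along N).
static size_t getBNumElemsPerThread(int shapeK, int shapeN, int wptN) {
  int repN = numRepN(shapeN, wptN);
  return 4 * std::max(repN / 2, 1) * numRepK(shapeK);
}

int main() {
  // e.g. a 128x64 (MxK) $a tile and a 64x128 (KxN) $b tile with wpt = {4, 2}
  std::printf("a elems/thread: %zu\n", getANumElemsPerThread(128, 64, 4));
  std::printf("b elems/thread: %zu\n", getBNumElemsPerThread(64, 128, 2));
  return 0;
}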
@@ -6159,10 +6187,9 @@ private:
     if (srcBlocked && dstDotOp) {
       auto tmpType = RankedTensorType::get(
           dstType.getShape(), dstType.getElementType(),
-          triton::gpu::SharedEncodingAttr::get(mod.getContext(), dstDotOp,
-                                               srcType.getShape(),
-                                               getOrder(srcBlocked),
-                                               srcType.getElementType()));
+          triton::gpu::SharedEncodingAttr::get(
+              mod.getContext(), dstDotOp, srcType.getShape(),
+              getOrder(srcBlocked), srcType.getElementType()));
       auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
           cvtOp.getLoc(), tmpType, cvtOp.getOperand());
       auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(