[Triton-MLIR][BACKEND] Fix wpt overflow issue in mma v2 (#904)
This PR (1) fixes the wpt overflow issue in mma v2 and (2) refines the transpose logic.
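Background on issue 1: in mma v2 the warps-per-CTA layout (`wpt`) is derived from `num_warps` alone, so on a small tile the per-axis warp index can exceed the number of m16n8 MMA tiles along that axis, and warps end up addressing lanes outside the tile. The fix clamps each warp index with `urem` against the per-axis tile count (`shape[0] / 16` along M, `shape[1] / 8` along N), both in `MMA16816ConversionHelper` (the new `getWarpM`/`getWarpN`) and in the standalone path in the second hunk below.

A minimal scalar sketch of the failure mode and the clamp, assuming a 16x16 tile with `num_warps = 16` laid out as `wpt = {4, 4}` (illustrative values chosen to match the new regression test; this is plain C++, not the actual lowering code):

```cpp
#include <cstdio>

int main() {
  // Illustrative values: 16x16 output tile, num_warps = 16 assumed as 4x4.
  const int wpt[2] = {4, 4};      // warps per CTA along M and N (assumed)
  const int shape[2] = {16, 16};  // tile shape
  const int mmaTile[2] = {16, 8}; // m16n8 mma v2 tile extents

  for (int warp = 0; warp < wpt[0] * wpt[1]; ++warp) {
    // Warp ids are M-major: warp = warpM + warpN * wpt[0].
    int warpM = warp % wpt[0];            // 0..3, but only 1 tile along M
    int warpN = (warp / wpt[0]) % wpt[1]; // 0..3, but only 2 tiles along N
    // The fix: clamp against the number of mma tiles along each axis,
    // mirroring urem(warpId, idx_val(shape[axis] / mmaTileSize)).
    int clampedM = warpM % (shape[0] / mmaTile[0]);
    int clampedN = warpN % (shape[1] / mmaTile[1]);
    printf("warp %2d: warpM %d -> %d, warpN %d -> %d\n", warp, warpM,
           clampedM, warpN, clampedN);
  }
  return 0;
}
```

Issue 2 (transpose): `transA` now swaps which axis of $a's shape is treated as K (`reduceAxis = op.transA() ? 0 : 1`), and `aShape` is swapped before the per-warp repetition counts are computed.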
@@ -2239,16 +2239,16 @@ struct AllocTensorOpConversion
     smemBase = bitcast(smemBase, elemPtrTy);
     auto order = resultTy.getEncoding().cast<SharedEncodingAttr>().getOrder();
     // workaround for 3D tensors
-    // TODO: We need to modify the pipeline pass to give a proper shared encoding to 3D tensors
+    // TODO: We need to modify the pipeline pass to give a proper shared
+    // encoding to 3D tensors
     SmallVector<unsigned> newOrder;
     if (resultTy.getShape().size() == 3)
       newOrder = {1 + order[0], 1 + order[1], 0};
     else
       newOrder = SmallVector<unsigned>(order.begin(), order.end());

-    auto smemObj =
-        SharedMemoryObject(smemBase, resultTy.getShape(), newOrder, loc, rewriter);
+    auto smemObj = SharedMemoryObject(smemBase, resultTy.getShape(), newOrder,
+                                      loc, rewriter);
     auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
     rewriter.replaceOp(op, retVal);
     return success();
@@ -2882,6 +2882,8 @@ private:
     SmallVector<Value> multiDimWarpId(2);
     multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
     multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
+    multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16));
+    multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 8));
     Value four = idx_val(4);
     Value mmaGrpId = udiv(laneId, four);
     Value mmaGrpIdP8 = add(mmaGrpId, idx_val(8));
@@ -3661,7 +3663,7 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {

     // Here we assume the DotOp's operands always comes from shared memory.
     auto AShape = A.getType().cast<RankedTensorType>().getShape();
-    size_t reduceAxis = 1;
+    size_t reduceAxis = op.transA() ? 0 : 1;
     unsigned K = AShape[reduceAxis];
     bool isOuter = K == 1;

@@ -4124,8 +4126,9 @@ private:
 struct MMA16816ConversionHelper {
   MmaEncodingAttr mmaLayout;
   ArrayRef<unsigned int> wpt;
+  SmallVector<unsigned int> properWpt;

-  Value thread, lane, warp, warpMN, warpN, warpM;
+  Value thread, lane, warp;

   DotOpMmaV2ConversionHelper helper;
   ConversionPatternRewriter &rewriter;
@@ -4135,23 +4138,34 @@ struct MMA16816ConversionHelper {

   using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;

-  MMA16816ConversionHelper(MmaEncodingAttr mmaLayout, Value thread,
-                           ConversionPatternRewriter &rewriter,
+  // dotOperand: type of either one operand of dotOp.
+  MMA16816ConversionHelper(Type dotOperand, MmaEncodingAttr mmaLayout,
+                           Value thread, ConversionPatternRewriter &rewriter,
                            TypeConverter *typeConverter, Location loc)
       : mmaLayout(mmaLayout), thread(thread), helper(mmaLayout),
         rewriter(rewriter), typeConverter(typeConverter), loc(loc),
-        ctx(mmaLayout.getContext()) {
-    wpt = mmaLayout.getWarpsPerCTA();
+        ctx(mmaLayout.getContext()), wpt(mmaLayout.getWarpsPerCTA()) {
+    helper.deduceMmaType(dotOperand);

     Value _32 = i32_val(32);
     lane = urem(thread, _32);
     warp = udiv(thread, _32);
-    warpMN = udiv(warp, i32_val(wpt[0]));
-    warpM = urem(warp, i32_val(wpt[0]));
-    warpN = urem(warpMN, i32_val(wpt[1]));
   }

-  // Get the mmaInstrShape from either $a or $b.
+  // Get a warpId for M axis.
+  Value getWarpM(int M) const {
+    auto matShape = helper.getMmaMatShape();
+    return urem(urem(warp, i32_val(wpt[0])), i32_val(M / matShape[0]));
+  }
+
+  // Get a warpId for N axis.
+  Value getWarpN(int N) const {
+    auto matShape = helper.getMmaMatShape();
+    Value warpMN = udiv(warp, i32_val(wpt[0]));
+    return urem(urem(warpMN, i32_val(wpt[1])), i32_val(N / matShape[1]));
+  }
+
+  // Get the mmaInstrShape deducing either from $a or $b.
   std::tuple<int, int, int> getMmaInstrShape(Type operand) const {
     helper.deduceMmaType(operand);
     auto mmaInstrShape = helper.getMmaInstrShape();
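Note on the helper refactor above: `warpM`/`warpN` can no longer be computed once in the constructor, because the clamp modulus (`M / matShape[0]`, `N / matShape[1]`) depends on the operand's shape, which the constructor does not see. They become `getWarpM(M)`/`getWarpN(N)`, and `loadA`/`loadB` query them against their own tensor's shape; see the `getWarpM(shape[0])` and `getWarpN(shape[1])` calls in the hunks below.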
@@ -4161,6 +4175,7 @@ struct MMA16816ConversionHelper {
     return std::make_tuple(mmaInstrM, mmaInstrN, mmaInstrK);
   }

+  // Get the mmaMatShape deducing either from $a or $b.
   std::tuple<int, int, int> getMmaMatShape(Type operand) const {
     helper.deduceMmaType(operand);
     auto matShape = helper.getMmaMatShape();
@@ -4210,28 +4225,28 @@ struct MMA16816ConversionHelper {
   }

   // Get number of elements per thread for $a operand.
-  static size_t getANumElemsPerThread(RankedTensorType operand,
-                                      ArrayRef<unsigned> wpt) {
+  static size_t getANumElemsPerThread(RankedTensorType operand, int wpt) {
     auto shape = operand.getShape();
-    int repM = getNumRepM(operand, shape[0], wpt[0]);
+    int repM = getNumRepM(operand, shape[0], wpt);
     int repK = getNumRepK_(operand, shape[1]);
     return 4 * repM * repK;
   }

   // Get number of elements per thread for $b operand.
-  static size_t getBNumElemsPerThread(RankedTensorType operand,
-                                      ArrayRef<unsigned> wpt) {
+  static size_t getBNumElemsPerThread(RankedTensorType operand, int wpt) {
     auto shape = operand.getShape();
     int repK = getNumRepK_(operand, shape[0]);
-    int repN = getNumRepN(operand, shape[1], wpt[1]);
+    int repN = getNumRepN(operand, shape[1], wpt);
     return 4 * std::max(repN / 2, 1) * repK;
   }

   // Loading $a from smem to registers, returns a LLVM::Struct.
   Value loadA(Value tensor, const SharedMemoryObject &smemObj) const {
     auto aTensorTy = tensor.getType().cast<RankedTensorType>();
-    auto shape = aTensorTy.getShape();
+    auto layout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();

+    SmallVector<int64_t> shape(aTensorTy.getShape().begin(),
+                               aTensorTy.getShape().end());
     ValueTable ha;
     std::function<void(int, int)> loadFn;
     auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(aTensorTy);
@@ -4241,6 +4256,7 @@ struct MMA16816ConversionHelper {
     int numRepK = getNumRepK(aTensorTy, shape[1]);

     if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
+      Value warpM = getWarpM(shape[0]);
       // load from smem
       loadFn = getLoadMatrixFn(
           tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
@@ -4268,12 +4284,17 @@ struct MMA16816ConversionHelper {
   Value loadB(Value tensor, const SharedMemoryObject &smemObj) {
     ValueTable hb;
     auto tensorTy = tensor.getType().cast<RankedTensorType>();
-    auto shape = tensorTy.getShape();
+    auto layout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
+
+    SmallVector<int64_t> shape(tensorTy.getShape().begin(),
+                               tensorTy.getShape().end());
+
     auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(tensorTy);
     auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(tensorTy);
     int numRepK = getNumRepK(tensorTy, shape[0]);
     int numRepN = getNumRepN(tensorTy, shape[1]);

+    Value warpN = getWarpN(shape[1]);
     auto loadFn = getLoadMatrixFn(
         tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
         0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/,
@@ -4319,7 +4340,11 @@ struct MMA16816ConversionHelper {
     auto aTensorTy = a.getType().cast<RankedTensorType>();
     auto dTensorTy = d.getType().cast<RankedTensorType>();

-    auto aShape = aTensorTy.getShape();
+    SmallVector<int64_t> aShape(aTensorTy.getShape().begin(),
+                                aTensorTy.getShape().end());
+    if (op.transA())
+      std::swap(aShape[0], aShape[1]);
+
     auto dShape = dTensorTy.getShape();

     // shape / shape_per_cta
@@ -4602,9 +4627,9 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
   Value res;

   if (!isOuter && mmaLayout.getVersion() == 2 && isHMMA) { // tensor core v2
-    MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
-                                       rewriter, getTypeConverter(),
-                                       op.getLoc());
+    MMA16816ConversionHelper mmaHelper(src.getType(), mmaLayout,
+                                       getThreadId(rewriter, loc), rewriter,
+                                       getTypeConverter(), op.getLoc());

     if (dotOperandLayout.getOpIdx() == 0) {
       // operand $a
@@ -4695,12 +4720,15 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adaptor,
                        .cast<RankedTensorType>()
                        .getEncoding()
                        .cast<MmaEncodingAttr>();
-  MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
-                                     rewriter, getTypeConverter(), loc);

   Value A = op.a();
   Value B = op.b();
   Value C = op.c();

+  MMA16816ConversionHelper mmaHelper(A.getType(), mmaLayout,
+                                     getThreadId(rewriter, loc), rewriter,
+                                     getTypeConverter(), loc);
+
   auto ATensorTy = A.getType().cast<RankedTensorType>();
   auto BTensorTy = B.getType().cast<RankedTensorType>();

@@ -5532,13 +5560,13 @@ public:
       if (mmaLayout.getVersion() == 2) {
        if (dotOpLayout.getOpIdx() == 0) { // $a
          int elems =
-             MMA16816ConversionHelper::getANumElemsPerThread(type, wpt);
+             MMA16816ConversionHelper::getANumElemsPerThread(type, wpt[0]);
          return LLVM::LLVMStructType::getLiteral(
              ctx, SmallVector<Type>(elems, vecTy));
        }
        if (dotOpLayout.getOpIdx() == 1) { // $b
          int elems =
-             MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt);
+             MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt[1]);
          return struct_ty(SmallVector<Type>(elems, vecTy));
        }
      }
@@ -6159,10 +6187,9 @@ private:
     if (srcBlocked && dstDotOp) {
       auto tmpType = RankedTensorType::get(
           dstType.getShape(), dstType.getElementType(),
-          triton::gpu::SharedEncodingAttr::get(mod.getContext(), dstDotOp,
-                                               srcType.getShape(),
-                                               getOrder(srcBlocked),
-                                               srcType.getElementType()));
+          triton::gpu::SharedEncodingAttr::get(
+              mod.getContext(), dstDotOp, srcType.getShape(),
+              getOrder(srcBlocked), srcType.getElementType()));
       auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
           cvtOp.getLoc(), tmpType, cvtOp.getOperand());
       auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
@@ -27,8 +27,6 @@ def matmul_no_scf_kernel(
     c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
     tl.store(c_ptrs, c)

-# TODO: num_warps could only be 4 for now
-

@pytest.mark.parametrize('SHAPE,NUM_WARPS,TRANS_A,TRANS_B', [
    (shape, num_warps, trans_a, trans_b)
@@ -172,6 +170,7 @@ def get_proper_err(a, b, golden):
    # Non-forloop
    [64, 32, 64, 4, 64, 32, 64, False, False],
    [128, 64, 128, 4, 128, 64, 128, False, False],
+   [16, 16, 16, 16, 16, 16, 16, False, False],  # wpt overflow issue
    # K-Forloop
    [64, 32, 128, 4, 64, 32, 64, False, False],
    [128, 16, 128, 4, 128, 16, 32, False, False],
@@ -186,6 +185,7 @@ def get_proper_err(a, b, golden):
    [128, 256, 128, 4, 128, 256, 32, False, False],
    [256, 128, 64, 4, 256, 128, 16, False, False],
    [128, 64, 128, 4, 128, 64, 32, False, False],
+   # [16, 16, 64, 4, 16, 16, 16, False, False],  # TODO failed due to pipeline pass
    # trans
    [128, 64, 128, 4, 128, 64, 32, True, False],
    [128, 64, 128, 4, 128, 64, 32, False, True],
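Test notes: the new `[16, 16, 16, 16, 16, 16, 16, False, False]` case runs a single 16x16x16 tile with `num_warps = 16` (the fourth field, by analogy with the other rows), which is the configuration where `wpt` previously overflowed the tile count. The `[16, 16, 64, 4, 16, 16, 16, False, False]` K-loop case is left commented out because it still fails in the pipeline pass, independent of this fix.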