[Triton-MLIR][BACKEND] Fix wpt overflow issue in mma v2 (#904)

This PR:

1. Fixes the wpt (warps-per-CTA) overflow issue in mma v2 (a standalone sketch of the bug follows the commit metadata below).
2. Refines the transpose logic in the DotOp lowering.
Yan Chunwei authored on 2022-11-23 11:27:15 +08:00 (committed by GitHub)
parent 07786dc932, commit 037f9efa95
2 changed files with 64 additions and 37 deletions
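The overflow in item 1 shows up when a kernel launches more warps along a tensor dimension than there are mma tiles to cover, e.g. the new 16x16x16 test case run with 16 warps: the raw warp index then walks past the tensor, and the fix wraps it back onto the tiles that actually exist (the urem clamps and the new getWarpM/getWarpN below). A minimal standalone sketch of that arithmetic; the m16n8 tile size comes from the diff, while the {4, 4} warp split is an assumed example:

// Sketch of the wpt overflow and its clamp. Hypothetical values: 16 warps
// split as wpt = {4, 4}; only the wrap-around expressions mirror the diff.
#include <cstdio>

int main() {
  const int wpt[2] = {4, 4};          // warps per CTA along M, N (assumed)
  const int shapeM = 16, shapeN = 16; // the new test shape
  const int instrM = 16, instrN = 8;  // mma v2 m16n8 tile
  const int tilesM = shapeM / instrM; // 1 tile along M
  const int tilesN = shapeN / instrN; // 2 tiles along N

  for (int warp = 0; warp < wpt[0] * wpt[1]; ++warp) {
    int warpM = warp % wpt[0];            // pre-fix: ranges over 0..3
    int warpN = (warp / wpt[0]) % wpt[1]; // pre-fix: ranges over 0..3
    // The fix wraps each index onto an existing tile, so "extra" warps
    // recompute a valid tile instead of addressing out of bounds.
    printf("warp %2d: (%d, %d) -> (%d, %d)\n", warp, warpM, warpN,
           warpM % tilesM, warpN % tilesN);
  }
  return 0;
}

With one tile along M and two along N, all 16 warps collapse onto (0, 0) and (0, 1), which is exactly what the urem(multiDimWarpId[i], idx_val(shape[i] / tile)) pair in the @@ -2882 hunk below enforces.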

Changed file 1 of 2: the TritonGPU-to-LLVM conversion (C++).

@@ -2239,16 +2239,16 @@ struct AllocTensorOpConversion
     smemBase = bitcast(smemBase, elemPtrTy);
     auto order = resultTy.getEncoding().cast<SharedEncodingAttr>().getOrder();
     // workaround for 3D tensors
-    // TODO: We need to modify the pipeline pass to give a proper shared encoding to 3D tensors
+    // TODO: We need to modify the pipeline pass to give a proper shared
+    // encoding to 3D tensors
     SmallVector<unsigned> newOrder;
     if (resultTy.getShape().size() == 3)
       newOrder = {1 + order[0], 1 + order[1], 0};
     else
       newOrder = SmallVector<unsigned>(order.begin(), order.end());
-    auto smemObj =
-        SharedMemoryObject(smemBase, resultTy.getShape(), newOrder, loc, rewriter);
+    auto smemObj = SharedMemoryObject(smemBase, resultTy.getShape(), newOrder,
+                                      loc, rewriter);
     auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
     rewriter.replaceOp(op, retVal);
     return success();
@@ -2882,6 +2882,8 @@ private:
     SmallVector<Value> multiDimWarpId(2);
     multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
     multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
+    multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16));
+    multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 8));
     Value four = idx_val(4);
     Value mmaGrpId = udiv(laneId, four);
     Value mmaGrpIdP8 = add(mmaGrpId, idx_val(8));
@@ -3661,7 +3663,7 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
     // Here we assume the DotOp's operands always comes from shared memory.
     auto AShape = A.getType().cast<RankedTensorType>().getShape();
-    size_t reduceAxis = 1;
+    size_t reduceAxis = op.transA() ? 0 : 1;
     unsigned K = AShape[reduceAxis];
     bool isOuter = K == 1;
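Item 2, the transpose refinement, threads op.transA() through the places that assumed a row-major $a: the reduction-axis pick above and the aShape swap in the @@ -4319 hunk below. A small standalone sketch of the axis selection, with made-up shapes:

#include <array>
#include <cstdio>

// If $a is logically M x K, a transposed $a is stored as K x M, so the K
// (reduction) extent moves from axis 1 to axis 0 -- the same choice as
// `size_t reduceAxis = op.transA() ? 0 : 1;` in the hunk above.
static int kExtent(std::array<int, 2> aShape, bool transA) {
  const size_t reduceAxis = transA ? 0 : 1;
  return aShape[reduceAxis];
}

int main() {
  printf("K (no trans) = %d\n", kExtent({64, 32}, false)); // K = 32
  printf("K (transA)   = %d\n", kExtent({32, 64}, true));  // K = 32
  return 0;
}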
@@ -4124,8 +4126,9 @@ private:
 struct MMA16816ConversionHelper {
   MmaEncodingAttr mmaLayout;
   ArrayRef<unsigned int> wpt;
+  SmallVector<unsigned int> properWpt;
 
-  Value thread, lane, warp, warpMN, warpN, warpM;
+  Value thread, lane, warp;
 
   DotOpMmaV2ConversionHelper helper;
   ConversionPatternRewriter &rewriter;
@@ -4135,23 +4138,34 @@ struct MMA16816ConversionHelper {
   using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;
 
-  MMA16816ConversionHelper(MmaEncodingAttr mmaLayout, Value thread,
-                           ConversionPatternRewriter &rewriter,
+  // dotOperand: type of either one operand of dotOp.
+  MMA16816ConversionHelper(Type dotOperand, MmaEncodingAttr mmaLayout,
+                           Value thread, ConversionPatternRewriter &rewriter,
                            TypeConverter *typeConverter, Location loc)
       : mmaLayout(mmaLayout), thread(thread), helper(mmaLayout),
         rewriter(rewriter), typeConverter(typeConverter), loc(loc),
-        ctx(mmaLayout.getContext()) {
-    wpt = mmaLayout.getWarpsPerCTA();
+        ctx(mmaLayout.getContext()), wpt(mmaLayout.getWarpsPerCTA()) {
+    helper.deduceMmaType(dotOperand);
 
     Value _32 = i32_val(32);
     lane = urem(thread, _32);
     warp = udiv(thread, _32);
-    warpMN = udiv(warp, i32_val(wpt[0]));
-    warpM = urem(warp, i32_val(wpt[0]));
-    warpN = urem(warpMN, i32_val(wpt[1]));
   }
 
-  // Get the mmaInstrShape from either $a or $b.
+  // Get a warpId for M axis.
+  Value getWarpM(int M) const {
+    auto matShape = helper.getMmaMatShape();
+    return urem(urem(warp, i32_val(wpt[0])), i32_val(M / matShape[0]));
+  }
+
+  // Get a warpId for N axis.
+  Value getWarpN(int N) const {
+    auto matShape = helper.getMmaMatShape();
+    Value warpMN = udiv(warp, i32_val(wpt[0]));
+    return urem(urem(warpMN, i32_val(wpt[1])), i32_val(N / matShape[1]));
+  }
+
+  // Get the mmaInstrShape deducing either from $a or $b.
   std::tuple<int, int, int> getMmaInstrShape(Type operand) const {
     helper.deduceMmaType(operand);
     auto mmaInstrShape = helper.getMmaInstrShape();
@@ -4161,6 +4175,7 @@ struct MMA16816ConversionHelper {
     return std::make_tuple(mmaInstrM, mmaInstrN, mmaInstrK);
   }
 
+  // Get the mmaMatShape deducing either from $a or $b.
   std::tuple<int, int, int> getMmaMatShape(Type operand) const {
     helper.deduceMmaType(operand);
     auto matShape = helper.getMmaMatShape();
@@ -4210,28 +4225,28 @@ struct MMA16816ConversionHelper {
   }
 
   // Get number of elements per thread for $a operand.
-  static size_t getANumElemsPerThread(RankedTensorType operand,
-                                      ArrayRef<unsigned> wpt) {
+  static size_t getANumElemsPerThread(RankedTensorType operand, int wpt) {
     auto shape = operand.getShape();
-    int repM = getNumRepM(operand, shape[0], wpt[0]);
+    int repM = getNumRepM(operand, shape[0], wpt);
     int repK = getNumRepK_(operand, shape[1]);
     return 4 * repM * repK;
   }
 
   // Get number of elements per thread for $b operand.
-  static size_t getBNumElemsPerThread(RankedTensorType operand,
-                                      ArrayRef<unsigned> wpt) {
+  static size_t getBNumElemsPerThread(RankedTensorType operand, int wpt) {
     auto shape = operand.getShape();
     int repK = getNumRepK_(operand, shape[0]);
-    int repN = getNumRepN(operand, shape[1], wpt[1]);
+    int repN = getNumRepN(operand, shape[1], wpt);
     return 4 * std::max(repN / 2, 1) * repK;
   }
 
   // Loading $a from smem to registers, returns a LLVM::Struct.
   Value loadA(Value tensor, const SharedMemoryObject &smemObj) const {
     auto aTensorTy = tensor.getType().cast<RankedTensorType>();
-    auto shape = aTensorTy.getShape();
+    auto layout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
+    SmallVector<int64_t> shape(aTensorTy.getShape().begin(),
+                               aTensorTy.getShape().end());
 
     ValueTable ha;
     std::function<void(int, int)> loadFn;
     auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(aTensorTy);
@@ -4241,6 +4256,7 @@ struct MMA16816ConversionHelper {
     int numRepK = getNumRepK(aTensorTy, shape[1]);
 
     if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
+      Value warpM = getWarpM(shape[0]);
       // load from smem
       loadFn = getLoadMatrixFn(
           tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
@@ -4268,12 +4284,17 @@ struct MMA16816ConversionHelper {
   Value loadB(Value tensor, const SharedMemoryObject &smemObj) {
     ValueTable hb;
     auto tensorTy = tensor.getType().cast<RankedTensorType>();
-    auto shape = tensorTy.getShape();
+    auto layout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
+    SmallVector<int64_t> shape(tensorTy.getShape().begin(),
+                               tensorTy.getShape().end());
 
     auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(tensorTy);
     auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(tensorTy);
     int numRepK = getNumRepK(tensorTy, shape[0]);
     int numRepN = getNumRepN(tensorTy, shape[1]);
 
+    Value warpN = getWarpN(shape[1]);
     auto loadFn = getLoadMatrixFn(
         tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
         0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/,
@@ -4319,7 +4340,11 @@ struct MMA16816ConversionHelper {
     auto aTensorTy = a.getType().cast<RankedTensorType>();
     auto dTensorTy = d.getType().cast<RankedTensorType>();
 
-    auto aShape = aTensorTy.getShape();
+    SmallVector<int64_t> aShape(aTensorTy.getShape().begin(),
+                                aTensorTy.getShape().end());
+    if (op.transA())
+      std::swap(aShape[0], aShape[1]);
+
     auto dShape = dTensorTy.getShape();
 
     // shape / shape_per_cta
@@ -4602,9 +4627,9 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
   Value res;
   if (!isOuter && mmaLayout.getVersion() == 2 && isHMMA) { // tensor core v2
-    MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
-                                       rewriter, getTypeConverter(),
-                                       op.getLoc());
+    MMA16816ConversionHelper mmaHelper(src.getType(), mmaLayout,
+                                       getThreadId(rewriter, loc), rewriter,
+                                       getTypeConverter(), op.getLoc());
 
     if (dotOperandLayout.getOpIdx() == 0) {
       // operand $a
@@ -4695,12 +4720,15 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adaptor,
                        .cast<RankedTensorType>()
                        .getEncoding()
                        .cast<MmaEncodingAttr>();
-  MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
-                                     rewriter, getTypeConverter(), loc);
 
   Value A = op.a();
   Value B = op.b();
   Value C = op.c();
 
+  MMA16816ConversionHelper mmaHelper(A.getType(), mmaLayout,
+                                     getThreadId(rewriter, loc), rewriter,
+                                     getTypeConverter(), loc);
+
   auto ATensorTy = A.getType().cast<RankedTensorType>();
   auto BTensorTy = B.getType().cast<RankedTensorType>();
@@ -5532,13 +5560,13 @@ public:
     if (mmaLayout.getVersion() == 2) {
       if (dotOpLayout.getOpIdx() == 0) { // $a
         int elems =
-            MMA16816ConversionHelper::getANumElemsPerThread(type, wpt);
+            MMA16816ConversionHelper::getANumElemsPerThread(type, wpt[0]);
         return LLVM::LLVMStructType::getLiteral(
             ctx, SmallVector<Type>(elems, vecTy));
       }
       if (dotOpLayout.getOpIdx() == 1) { // $b
         int elems =
-            MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt);
+            MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt[1]);
         return struct_ty(SmallVector<Type>(elems, vecTy));
       }
     }
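The @@ -4210 hunk changed getANumElemsPerThread/getBNumElemsPerThread from taking the whole ArrayRef<unsigned> wpt to taking only the scalar warp count for their own axis, which the call sites above now pass as wpt[0] and wpt[1]. A back-of-the-envelope model of that counting for the new 16x16x16 case; the rep formula max(extent / (warps * instr_extent), 1) is an assumption about the getNumRep* helpers, and only the final 4 * ... expressions mirror the diff:

#include <algorithm>
#include <cstdio>

// Hypothetical model of the per-thread element counts for the 16x16x16
// test case; wptM/wptN are an assumed split of 16 warps.
int main() {
  const int M = 16, N = 16, K = 16;
  const int wptM = 4, wptN = 4;              // assumed warp split
  const int instrM = 16, instrN = 8, instrK = 16;

  const int repM = std::max(M / (wptM * instrM), 1); // clamped to 1
  const int repN = std::max(N / (wptN * instrN), 1); // clamped to 1
  const int repK = std::max(K / instrK, 1);          // 1

  printf("a: %d elems/thread\n", 4 * repM * repK);                  // 4
  printf("b: %d elems/thread\n", 4 * std::max(repN / 2, 1) * repK); // 4
  return 0;
}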
@@ -6159,10 +6187,9 @@ private:
     if (srcBlocked && dstDotOp) {
       auto tmpType = RankedTensorType::get(
           dstType.getShape(), dstType.getElementType(),
-          triton::gpu::SharedEncodingAttr::get(mod.getContext(), dstDotOp,
-                                               srcType.getShape(),
-                                               getOrder(srcBlocked),
-                                               srcType.getElementType()));
+          triton::gpu::SharedEncodingAttr::get(
+              mod.getContext(), dstDotOp, srcType.getShape(),
+              getOrder(srcBlocked), srcType.getElementType()));
       auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
           cvtOp.getLoc(), tmpType, cvtOp.getOperand());
       auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(

Changed file 2 of 2: the GEMM tests (Python).

@@ -27,8 +27,6 @@ def matmul_no_scf_kernel(
     c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
     tl.store(c_ptrs, c)
 
 
-# TODO: num_warps could only be 4 for now
-
 @pytest.mark.parametrize('SHAPE,NUM_WARPS,TRANS_A,TRANS_B', [
     (shape, num_warps, trans_a, trans_b)
@@ -172,6 +170,7 @@ def get_proper_err(a, b, golden):
     # Non-forloop
     [64, 32, 64, 4, 64, 32, 64, False, False],
     [128, 64, 128, 4, 128, 64, 128, False, False],
+    [16, 16, 16, 16, 16, 16, 16, False, False],  # wpt overflow issue
     # K-Forloop
     [64, 32, 128, 4, 64, 32, 64, False, False],
     [128, 16, 128, 4, 128, 16, 32, False, False],
@@ -186,6 +185,7 @@ def get_proper_err(a, b, golden):
     [128, 256, 128, 4, 128, 256, 32, False, False],
     [256, 128, 64, 4, 256, 128, 16, False, False],
     [128, 64, 128, 4, 128, 64, 32, False, False],
+    # [16, 16, 64, 4, 16, 16, 16, False, False],  # TODO failed due to pipeline pass
     # trans
     [128, 64, 128, 4, 128, 64, 32, True, False],
     [128, 64, 128, 4, 128, 64, 32, False, True],