[Triton-MLIR][Backend] Fix mma<v2> int8 precision error (#850)

Fix mma.16816 s8 precision error

Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
Author: Yan Chunwei
Date: 2022-11-09 12:23:43 +08:00 (committed by GitHub)
Parent: e517b58d59
Commit: de5b84c476
4 changed files with 146 additions and 42 deletions
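In the old lowering, the values unpacked from ldmatrix results and from the mma.16816 accumulator struct were always extracted with hard-coded fp16x2/f32 element types, and the packed dot result was rebuilt as a struct of f32. For an s8 x s8 -> s32 dot that mismatches the real i32 accumulator type and corrupts the result. The patch reads the element type from the LLVM struct itself, bitcasts the accumulated values to the dot's result element type, gates MMA lowering on a new isDotHMMA() predicate (which admits int8 on MMA v2), teaches getShapePerCTA about MMA v1, and adds an int8 GEMM test. Below is a minimal, hypothetical sketch (not code from this commit) of the kind of corruption a hard-coded f32 view of an i32 accumulator produces; the precise failure mode inside the backend may differ:

import struct

# An int8 dot product accumulates into an i32 register, e.g. the value 300.
# Reading those 32 bits back as if they were an f32 (which is what a
# hard-coded float element type amounts to) yields a meaningless denormal.
acc = 300
reinterpreted = struct.unpack('<f', struct.pack('<i', acc))[0]
print(acc, reinterpreted)  # 300 vs ~4.2e-43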


@@ -87,7 +87,6 @@ static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc,
void llPrintf(StringRef msg, ValueRange args,
ConversionPatternRewriter &rewriter);
// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive//
// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
#define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
#define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
@@ -2923,8 +2922,8 @@ public:
Value sOffset =
mul(i32_val(matIdx[order[1]] * sMatStride * sMatShape), sTileStride);
Value sOffsetPtr = gep(shemPtrTy, ptr, sOffset);
PTXBuilder builder;
PTXBuilder builder;
// ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a
// thread.
auto resArgs = builder.newListOperand(4, "=r");
@@ -2943,12 +2942,13 @@ public:
return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
};
Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2);
// The struct should have exactly the same element types.
Type elemType = resV4.getType().cast<LLVM::LLVMStructType>().getBody()[0];
return {extract_val(fp16x2Ty, resV4, getIntAttr(0)),
extract_val(fp16x2Ty, resV4, getIntAttr(1)),
extract_val(fp16x2Ty, resV4, getIntAttr(2)),
extract_val(fp16x2Ty, resV4, getIntAttr(3))};
return {extract_val(elemType, resV4, getIntAttr(0)),
extract_val(elemType, resV4, getIntAttr(1)),
extract_val(elemType, resV4, getIntAttr(2)),
extract_val(elemType, resV4, getIntAttr(3))};
} else if (elemBytes == 4 &&
needTrans) { // Use lds.32 to load tf32 matrices
Value ptr2 = getPtr(ptrIdx + 1);
@@ -2961,20 +2961,25 @@ public:
Value elems[4];
Type elemTy = type::f32Ty(ctx);
Type elemPtrTy = ptr_ty(elemTy);
if (kOrder == 1) {
elems[0] = load(gep(elemTy, ptr, sOffsetElemVal));
elems[1] = load(gep(elemTy, ptr2, sOffsetElemVal));
elems[2] = load(gep(elemTy, ptr, sOffsetArrElemVal));
elems[3] = load(gep(elemTy, ptr2, sOffsetArrElemVal));
elems[0] = load(gep(elemPtrTy, ptr, i32_val(sOffsetElem)));
elems[1] = load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem)));
elems[2] =
load(gep(elemPtrTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
elems[3] =
load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
} else {
elems[0] = load(gep(elemTy, ptr, sOffsetElemVal));
elems[2] = load(gep(elemTy, ptr2, sOffsetElemVal));
elems[1] = load(gep(elemTy, ptr, sOffsetArrElemVal));
elems[3] = load(gep(elemTy, ptr2, sOffsetArrElemVal));
elems[0] = load(gep(elemPtrTy, ptr, i32_val(sOffsetElem)));
elems[2] = load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem)));
elems[1] =
load(gep(elemPtrTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
elems[3] =
load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
}
return {elems[0], elems[1], elems[2], elems[3]};
} else if (elemBytes == 1 && needTrans) {
} else if (elemBytes == 1 && needTrans) { // work with int8
std::array<std::array<Value, 4>, 2> ptrs;
ptrs[0] = {
getPtr(ptrIdx),
@@ -3004,15 +3009,16 @@ public:
Value i8Elems[4][4];
Type elemTy = type::i8Ty(ctx);
Type elemPtrTy = ptr_ty(elemTy);
if (kOrder == 1) {
for (int i = 0; i < 2; ++i)
for (int j = 0; j < 4; ++j)
i8Elems[i][j] = load(gep(elemTy, ptrs[i][j], sOffsetElemVal));
i8Elems[i][j] = load(gep(elemPtrTy, ptrs[i][j], sOffsetElemVal));
for (int i = 2; i < 4; ++i)
for (int j = 0; j < 4; ++j)
i8Elems[i][j] =
load(gep(elemTy, ptrs[i - 2][j], sOffsetArrElemVal));
load(gep(elemPtrTy, ptrs[i - 2][j], sOffsetArrElemVal));
for (int m = 0; m < 4; ++m) {
for (int e = 0; e < 4; ++e)
@@ -3022,13 +3028,13 @@ public:
}
} else { // k first
for (int j = 0; j < 4; ++j)
i8Elems[0][j] = load(gep(elemTy, ptrs[0][j], sOffsetElemVal));
i8Elems[0][j] = load(gep(elemPtrTy, ptrs[0][j], sOffsetElemVal));
for (int j = 0; j < 4; ++j)
i8Elems[2][j] = load(gep(elemTy, ptrs[1][j], sOffsetElemVal));
i8Elems[2][j] = load(gep(elemPtrTy, ptrs[1][j], sOffsetElemVal));
for (int j = 0; j < 4; ++j)
i8Elems[1][j] = load(gep(elemTy, ptrs[0][j], sOffsetArrElemVal));
i8Elems[1][j] = load(gep(elemPtrTy, ptrs[0][j], sOffsetArrElemVal));
for (int j = 0; j < 4; ++j)
i8Elems[3][j] = load(gep(elemTy, ptrs[1][j], sOffsetArrElemVal));
i8Elems[3][j] = load(gep(elemPtrTy, ptrs[1][j], sOffsetArrElemVal));
for (int m = 0; m < 4; ++m) {
for (int e = 0; e < 4; ++e)
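For reference, each 32-bit operand register that mma.m16n8k16 consumes for s8 inputs carries four packed int8 values; the i8 load path above gathers those bytes per thread before handing them to the instruction. A rough sketch of the packing follows (the byte order and the helper itself are assumptions for illustration, not code from this commit):

def pack_s8x4(vals):
    # Pack four signed bytes into one 32-bit word, low byte first (assumed order).
    assert len(vals) == 4
    word = 0
    for i, v in enumerate(vals):
        word |= (v & 0xFF) << (8 * i)
    return word

print(hex(pack_s8x4([1, -1, 2, -2])))  # 0xfe02ff01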
@@ -3112,6 +3118,7 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
size_t reduceAxis = 1;
unsigned K = AShape[reduceAxis];
bool isOuter = K == 1;
bool isMMA = D.getType()
.cast<RankedTensorType>()
.getEncoding()
@@ -3123,11 +3130,13 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
.getEncoding()
.cast<MmaEncodingAttr>();
if (!isOuter && isMMA) {
bool isHMMA = isDotHMMA(op);
if (!isOuter && isMMA && isHMMA) {
if (mmaLayout.getVersion() == 1)
return convertMMA884(op, adaptor, rewriter);
if (mmaLayout.getVersion() == 2)
return convertMMA16816(op, adaptor, rewriter);
llvm::report_fatal_error(
"Unsupported MMA kind found when converting DotOp to LLVM.");
}
@@ -3140,6 +3149,49 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
"Unsupported DotOp found when converting TritonGPU to LLVM.");
}
// Tell whether a DotOp supports HMMA.
// This is ported from the master branch; the original logic is retained.
static bool isDotHMMA(DotOp op) {
auto a = op.a();
auto b = op.b();
auto c = op.c();
auto d = op.getResult();
auto aTensorTy = a.getType().cast<RankedTensorType>();
auto bTensorTy = b.getType().cast<RankedTensorType>();
auto cTensorTy = c.getType().cast<RankedTensorType>();
auto dTensorTy = d.getType().cast<RankedTensorType>();
if (!dTensorTy.getEncoding().isa<MmaEncodingAttr>())
return false;
auto mmaLayout = dTensorTy.getEncoding().cast<MmaEncodingAttr>();
auto aElemTy = aTensorTy.getElementType();
auto bElemTy = bTensorTy.getElementType();
assert((mmaLayout.getVersion() == 1 || mmaLayout.getVersion() == 2) &&
"Unexpected MMA layout version found");
// Refer to mma section for the data type supported by Volta and Hopper
// Tensor Core in
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
return (aElemTy.isF16() && bElemTy.isF16()) ||
(aElemTy.isBF16() && bElemTy.isBF16()) ||
(aElemTy.isF32() && bElemTy.isF32() && op.allowTF32() &&
mmaLayout.getVersion() >= 2) ||
(aElemTy.isInteger(8) && bElemTy.isInteger(8) &&
mmaLayout.getVersion() >= 2);
}
// Tell whether a DotOp supports HMMA by the operand type (either $a or $b).
// We cannot get both operand types (in TypeConverter), so we assume the
// types of both operands are identical.
// TODO[Superjomn]: Find a better way to implement it.
static bool isDotHMMA(TensorType operand, bool allowTF32, int mmaVersion) {
auto elemTy = operand.getElementType();
return elemTy.isF16() || elemTy.isBF16() ||
(elemTy.isF32() && allowTF32 && mmaVersion >= 2) ||
(elemTy.isInteger(8) && mmaVersion >= 2);
}
private:
// Convert to mma.m16n8k16
LogicalResult convertMMA16816(triton::DotOp a, OpAdaptor adaptor,
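For readability, the eligibility rule encoded by the isDotHMMA helpers above can be restated as follows (a paraphrase only, not part of the patch):

def is_dot_hmma(elem_ty, allow_tf32, mma_version):
    # Mirrors the operand-type variant of isDotHMMA added in this commit.
    if elem_ty in ('f16', 'bf16'):
        return True
    if elem_ty == 'f32':
        return allow_tf32 and mma_version >= 2
    if elem_ty == 'i8':
        return mma_version >= 2
    return False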
@@ -3651,6 +3703,7 @@ struct MMA16816ConversionHelper {
std::function<void(int, int)> loadFn;
auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(aTensorTy);
auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(aTensorTy);
int numRepM = getNumRepM(aTensorTy, shape[0]);
int numRepK = getNumRepK(aTensorTy, shape[1]);
@@ -3766,6 +3819,7 @@ struct MMA16816ConversionHelper {
std::to_string(i)));
// reuse the output registers
}
mma(retArgs, aArgs, bArgs, cArgs);
Value mmaOut = builder.launch(rewriter, loc, helper.getMmaRetType());
@@ -3773,9 +3827,10 @@ struct MMA16816ConversionHelper {
return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
};
Type elemTy = mmaOut.getType().cast<LLVM::LLVMStructType>().getBody()[0];
for (int i = 0; i < 4; ++i)
fc[m * colsPerThread + 4 * n + i] =
extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(i));
extract_val(elemTy, mmaOut, getIntAttr(i));
};
for (int k = 0; k < numRepK; ++k)
@@ -3783,9 +3838,15 @@ struct MMA16816ConversionHelper {
for (int n = 0; n < numRepN; ++n)
callMma(2 * m, n, 2 * k);
Type resElemTy = dTensorTy.getElementType();
for (auto &elem : fc) {
elem = bitcast(elem, resElemTy);
}
// replace with new packed result
Type structTy = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(fc.size(), type::f32Ty(ctx)));
ctx, SmallVector<Type>(fc.size(), resElemTy));
Value res = getStructFromElements(loc, fc, rewriter, structTy);
rewriter.replaceOp(op, res);
@@ -3821,9 +3882,7 @@ private:
tensorTy.getShape() /*tileShape*/, instrShape, matShape, perPhase,
maxPhase, elemBytes, rewriter, typeConverter, loc);
SmallVector<Value> offs = loader.computeOffsets(warpId, lane);
const int numPtrs = loader.getNumPtr();
SmallVector<Value> ptrs(numPtrs);
Type smemPtrTy = helper.getShemPtrTy();
@@ -3835,6 +3894,7 @@ private:
auto [ha0, ha1, ha2, ha3] = loader.loadX4(
(kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? b : a /*mat1*/, offs,
ptrs, helper.getMatType(), helper.getShemPtrTy());
if (!needTrans) {
ld2(vals, a, b, ha0);
ld2(vals, a + 1, b, ha1);
@@ -3879,10 +3939,9 @@ private:
assert(!elems.empty());
Type fp16Ty = type::f16Ty(ctx);
Type fp16x2Ty = vec_ty(fp16Ty, 2);
Type elemTy = elems[0].getType();
Type structTy = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(elems.size(), fp16x2Ty));
ctx, SmallVector<Type>(elems.size(), elemTy));
auto result = getStructFromElements(loc, elems, rewriter, structTy);
return result;
}
@@ -3921,9 +3980,25 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
dotOperandLayout.getParent().dyn_cast_or_null<MmaEncodingAttr>();
assert(mmaLayout);
bool isOuter{};
{
int K{};
if (dotOperandLayout.getOpIdx() == 0) // $a
K = dstTensorTy.getShape()[1];
else // $b
K = dstTensorTy.getShape()[0];
isOuter = K == 1;
}
// TODO[Superjomn]: allowTF32 is not available in ConvertLayoutOp because it
// is an attribute of DotOp.
bool allowTF32 = false;
bool isHMMA = DotOpConversion::isDotHMMA(dstTensorTy, allowTF32,
mmaLayout.getVersion());
auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.src(), rewriter);
Value res;
if (mmaLayout.getVersion() == 2) {
if (!isOuter && mmaLayout.getVersion() == 2 && isHMMA) { // tensor core v2
MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
rewriter, getTypeConverter(),
op.getLoc());
@@ -3935,7 +4010,8 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
// operand $b
res = mmaHelper.loadB(src, smemObj);
}
} else if (mmaLayout.getVersion() == 1) {
} else if (!isOuter && mmaLayout.getVersion() == 1 &&
isHMMA) { // tensor core v1
DotOpMmaV1ConversionHelper helper(mmaLayout);
if (dotOperandLayout.getOpIdx() == 0) {
// operand $a
@@ -5076,8 +5152,8 @@ void ConvertTritonGPUToLLVM::initSharedMemory(
OpBuilder b(mod.getBodyRegion());
auto loc = mod.getLoc();
auto elemTy = typeConverter.convertType(b.getIntegerType(8));
// Set array size 0 and external linkage indicates that we use dynamic shared
// allocation to allow a larger shared memory size for each kernel.
// Set array size 0 and external linkage indicates that we use dynamic
// shared allocation to allow a larger shared memory size for each kernel.
auto arrayTy = LLVM::LLVMArrayType::get(elemTy, 0);
auto global = b.create<LLVM::GlobalOp>(
loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::External,


@@ -117,10 +117,13 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
"BlockedEncodingAttr not implemented");
}
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
assert(mmaLayout.getVersion() == 2 &&
"mmaLayout version = 1 is not implemented yet");
if (mmaLayout.getVersion() == 2)
return {16 * mmaLayout.getWarpsPerCTA()[0],
8 * mmaLayout.getWarpsPerCTA()[1]};
if (mmaLayout.getVersion() == 1)
return {16 * mmaLayout.getWarpsPerCTA()[0],
16 * mmaLayout.getWarpsPerCTA()[1]};
assert(0 && "Unexpected MMA layout version found");
} else {
assert(0 && "Unimplemented usage of getShapePerCTA");
}
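The change above extends getShapePerCTA to cover MMA v1 instead of asserting; the resulting per-CTA shapes can be summarized like this (a paraphrase of the C++ above):

def shape_per_cta(mma_version, warps_per_cta):
    if mma_version == 2:   # mma.16816
        return [16 * warps_per_cta[0], 8 * warps_per_cta[1]]
    if mma_version == 1:   # mma.884
        return [16 * warps_per_cta[0], 16 * warps_per_cta[1]]
    raise AssertionError('Unexpected MMA layout version found')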


@@ -55,6 +55,33 @@ def test_gemm_no_scf(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False)
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS', [
    [64, 128, 128, 1],
    [128, 128, 128, 4],
    [16, 8, 32, 1],
    [32, 16, 64, 2],
    [32, 16, 64, 4],
])
def test_gemm_no_scf_int8(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
    a = torch.randint(-5, 5, (SIZE_M, SIZE_K), device='cuda', dtype=torch.int8)
    b = torch.randint(-5, 5, (SIZE_K, SIZE_N), device='cuda', dtype=torch.int8)
    c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.int32)
    grid = lambda META: (1, )
    matmul_no_scf_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
                               stride_am=a.stride(0), stride_ak=a.stride(1),
                               stride_bk=b.stride(0), stride_bn=b.stride(1),
                               stride_cm=c.stride(0), stride_cn=c.stride(1),
                               M=SIZE_M, N=SIZE_N, K=SIZE_K,
                               num_warps=NUM_WARPS)
    aa = a.cpu()
    bb = b.cpu()
    golden = torch.matmul(aa.float(), bb.float()).int()
    torch.set_printoptions(profile="full")
    torch.testing.assert_close(c.cpu(), golden, check_dtype=False)
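A note on the golden reference in the new test: with operands drawn from [-5, 5) and K at most 128, every accumulator magnitude stays far below 2**24, so computing the reference in float32 and casting back to int is exact, and the assert effectively checks bit-exact int32 results. A quick sanity check of that bound:

# Largest possible |accumulator| for the new int8 test cases vs. the float32
# exact-integer range.
print(5 * 5 * 128, 2 ** 24)  # 3200 16777216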
@triton.jit
def matmul_kernel(
a_ptr, b_ptr, c_ptr,
@@ -80,8 +107,6 @@ def matmul_kernel(
c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
tl.store(c_ptrs, accumulator)
# TODO: DotConversion in TritonGPUToLLVM cannot support non-splat C for the moment
def get_variant_golden(a, b):
SIZE_M = a.shape[0]