[Triton-MLIR][Backend] Fix mma<v2> int8 precision error (#850)
Fix mma.16816 s8 precision error

Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
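Judging from the changes below, the "precision error" is a type-consistency bug rather than rounding: the s8 mma.16816 path accumulates in int32, but the old lowering hard-coded fp16x2/f32 element types when unpacking the ldmatrix and mma result structs. Reading integer bits through a float type corrupts the values outright. A rough standalone sketch of the effect (illustrative Python only, not Triton code):

import struct

# An int32 accumulator value, as produced by mma.m16n8k16 with s8 operands.
acc = 42
# Reading its raw bits as an f32 (what the old extract_val types implied)
# yields a denormal near zero instead of 42.0.
reinterpreted = struct.unpack('<f', struct.pack('<i', acc))[0]
print(reinterpreted)  # ~5.9e-44

The diff therefore derives element types from the actual LLVM struct (or the dot result type) instead of assuming fp16/f32, and adds an isDotHMMA check so only supported dtype combinations take the MMA path.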
@@ -87,7 +87,6 @@ static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc,
 void llPrintf(StringRef msg, ValueRange args,
               ConversionPatternRewriter &rewriter);

-// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive//
 // Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
 #define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
 #define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
@@ -2923,8 +2922,8 @@ public:
     Value sOffset =
         mul(i32_val(matIdx[order[1]] * sMatStride * sMatShape), sTileStride);
     Value sOffsetPtr = gep(shemPtrTy, ptr, sOffset);
-    PTXBuilder builder;

+    PTXBuilder builder;
     // ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a
     // thread.
     auto resArgs = builder.newListOperand(4, "=r");
@@ -2943,12 +2942,13 @@ public:
       return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
     };

-    Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2);
+    // The struct should have exactly the same element types.
+    Type elemType = resV4.getType().cast<LLVM::LLVMStructType>().getBody()[0];

-    return {extract_val(fp16x2Ty, resV4, getIntAttr(0)),
-            extract_val(fp16x2Ty, resV4, getIntAttr(1)),
-            extract_val(fp16x2Ty, resV4, getIntAttr(2)),
-            extract_val(fp16x2Ty, resV4, getIntAttr(3))};
+    return {extract_val(elemType, resV4, getIntAttr(0)),
+            extract_val(elemType, resV4, getIntAttr(1)),
+            extract_val(elemType, resV4, getIntAttr(2)),
+            extract_val(elemType, resV4, getIntAttr(3))};
   } else if (elemBytes == 4 &&
              needTrans) { // Use lds.32 to load tf32 matrices
     Value ptr2 = getPtr(ptrIdx + 1);
@@ -2961,20 +2961,25 @@ public:

     Value elems[4];
     Type elemTy = type::f32Ty(ctx);
+    Type elemPtrTy = ptr_ty(elemTy);
     if (kOrder == 1) {
-      elems[0] = load(gep(elemTy, ptr, sOffsetElemVal));
-      elems[1] = load(gep(elemTy, ptr2, sOffsetElemVal));
-      elems[2] = load(gep(elemTy, ptr, sOffsetArrElemVal));
-      elems[3] = load(gep(elemTy, ptr2, sOffsetArrElemVal));
+      elems[0] = load(gep(elemPtrTy, ptr, i32_val(sOffsetElem)));
+      elems[1] = load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem)));
+      elems[2] =
+          load(gep(elemPtrTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
+      elems[3] =
+          load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
     } else {
-      elems[0] = load(gep(elemTy, ptr, sOffsetElemVal));
-      elems[2] = load(gep(elemTy, ptr2, sOffsetElemVal));
-      elems[1] = load(gep(elemTy, ptr, sOffsetArrElemVal));
-      elems[3] = load(gep(elemTy, ptr2, sOffsetArrElemVal));
+      elems[0] = load(gep(elemPtrTy, ptr, i32_val(sOffsetElem)));
+      elems[2] = load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem)));
+      elems[1] =
+          load(gep(elemPtrTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
+      elems[3] =
+          load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
     }

     return {elems[0], elems[1], elems[2], elems[3]};
-  } else if (elemBytes == 1 && needTrans) {
+  } else if (elemBytes == 1 && needTrans) { // work with int8
     std::array<std::array<Value, 4>, 2> ptrs;
     ptrs[0] = {
         getPtr(ptrIdx),
@@ -3004,15 +3009,16 @@ public:

       Value i8Elems[4][4];
       Type elemTy = type::i8Ty(ctx);
+      Type elemPtrTy = ptr_ty(elemTy);
       if (kOrder == 1) {
         for (int i = 0; i < 2; ++i)
           for (int j = 0; j < 4; ++j)
-            i8Elems[i][j] = load(gep(elemTy, ptrs[i][j], sOffsetElemVal));
+            i8Elems[i][j] = load(gep(elemPtrTy, ptrs[i][j], sOffsetElemVal));

         for (int i = 2; i < 4; ++i)
           for (int j = 0; j < 4; ++j)
             i8Elems[i][j] =
-                load(gep(elemTy, ptrs[i - 2][j], sOffsetArrElemVal));
+                load(gep(elemPtrTy, ptrs[i - 2][j], sOffsetArrElemVal));

         for (int m = 0; m < 4; ++m) {
           for (int e = 0; e < 4; ++e)
@@ -3022,13 +3028,13 @@ public:
         }
       } else { // k first
         for (int j = 0; j < 4; ++j)
-          i8Elems[0][j] = load(gep(elemTy, ptrs[0][j], sOffsetElemVal));
+          i8Elems[0][j] = load(gep(elemPtrTy, ptrs[0][j], sOffsetElemVal));
         for (int j = 0; j < 4; ++j)
-          i8Elems[2][j] = load(gep(elemTy, ptrs[1][j], sOffsetElemVal));
+          i8Elems[2][j] = load(gep(elemPtrTy, ptrs[1][j], sOffsetElemVal));
        for (int j = 0; j < 4; ++j)
-          i8Elems[1][j] = load(gep(elemTy, ptrs[0][j], sOffsetArrElemVal));
+          i8Elems[1][j] = load(gep(elemPtrTy, ptrs[0][j], sOffsetArrElemVal));
        for (int j = 0; j < 4; ++j)
-          i8Elems[3][j] = load(gep(elemTy, ptrs[1][j], sOffsetArrElemVal));
+          i8Elems[3][j] = load(gep(elemPtrTy, ptrs[1][j], sOffsetArrElemVal));

        for (int m = 0; m < 4; ++m) {
          for (int e = 0; e < 4; ++e)
@@ -3112,6 +3118,7 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
    size_t reduceAxis = 1;
    unsigned K = AShape[reduceAxis];
    bool isOuter = K == 1;

    bool isMMA = D.getType()
                     .cast<RankedTensorType>()
                     .getEncoding()
@@ -3123,11 +3130,13 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
                         .getEncoding()
                         .cast<MmaEncodingAttr>();

-    if (!isOuter && isMMA) {
+    bool isHMMA = isDotHMMA(op);
+    if (!isOuter && isMMA && isHMMA) {
      if (mmaLayout.getVersion() == 1)
        return convertMMA884(op, adaptor, rewriter);
      if (mmaLayout.getVersion() == 2)
        return convertMMA16816(op, adaptor, rewriter);

      llvm::report_fatal_error(
          "Unsupported MMA kind found when converting DotOp to LLVM.");
    }
@@ -3140,6 +3149,49 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
        "Unsupported DotOp found when converting TritonGPU to LLVM.");
  }

+  // Tell whether a DotOp supports HMMA.
+  // This is ported from the master branch; the original logic is retained.
+  static bool isDotHMMA(DotOp op) {
+    auto a = op.a();
+    auto b = op.b();
+    auto c = op.c();
+    auto d = op.getResult();
+    auto aTensorTy = a.getType().cast<RankedTensorType>();
+    auto bTensorTy = b.getType().cast<RankedTensorType>();
+    auto cTensorTy = c.getType().cast<RankedTensorType>();
+    auto dTensorTy = d.getType().cast<RankedTensorType>();
+
+    if (!dTensorTy.getEncoding().isa<MmaEncodingAttr>())
+      return false;
+
+    auto mmaLayout = dTensorTy.getEncoding().cast<MmaEncodingAttr>();
+    auto aElemTy = aTensorTy.getElementType();
+    auto bElemTy = bTensorTy.getElementType();
+
+    assert((mmaLayout.getVersion() == 1 || mmaLayout.getVersion() == 2) &&
+           "Unexpected MMA layout version found");
+    // Refer to the mma section for the data types supported by Volta and
+    // Hopper Tensor Cores in
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
+    return (aElemTy.isF16() && bElemTy.isF16()) ||
+           (aElemTy.isBF16() && bElemTy.isBF16()) ||
+           (aElemTy.isF32() && bElemTy.isF32() && op.allowTF32() &&
+            mmaLayout.getVersion() >= 2) ||
+           (aElemTy.isInteger(8) && bElemTy.isInteger(8) &&
+            mmaLayout.getVersion() >= 2);
+  }
+
+  // Tell whether a DotOp supports HMMA by the operand type (either $a or $b).
+  // We cannot get both operand types (in TypeConverter), so we assume the
+  // types of both operands are identical.
+  // TODO[Superjomn]: Find a better way to implement it.
+  static bool isDotHMMA(TensorType operand, bool allowTF32, int mmaVersion) {
+    auto elemTy = operand.getElementType();
+    return elemTy.isF16() || elemTy.isBF16() ||
+           (elemTy.isF32() && allowTF32 && mmaVersion >= 2) ||
+           (elemTy.isInteger(8) && mmaVersion >= 2);
+  }
+
 private:
  // Convert to mma.m16n8k16
  LogicalResult convertMMA16816(triton::DotOp a, OpAdaptor adaptor,
@@ -3651,6 +3703,7 @@ struct MMA16816ConversionHelper {
    std::function<void(int, int)> loadFn;
    auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(aTensorTy);
    auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(aTensorTy);

    int numRepM = getNumRepM(aTensorTy, shape[0]);
    int numRepK = getNumRepK(aTensorTy, shape[1]);
@@ -3766,6 +3819,7 @@ struct MMA16816ConversionHelper {
                                          std::to_string(i)));
      // reuse the output registers
    }

    mma(retArgs, aArgs, bArgs, cArgs);
    Value mmaOut = builder.launch(rewriter, loc, helper.getMmaRetType());
@@ -3773,9 +3827,10 @@ struct MMA16816ConversionHelper {
      return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
    };

+    Type elemTy = mmaOut.getType().cast<LLVM::LLVMStructType>().getBody()[0];
    for (int i = 0; i < 4; ++i)
      fc[m * colsPerThread + 4 * n + i] =
-          extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(i));
+          extract_val(elemTy, mmaOut, getIntAttr(i));
  };

  for (int k = 0; k < numRepK; ++k)
@@ -3783,9 +3838,15 @@ struct MMA16816ConversionHelper {
      for (int n = 0; n < numRepN; ++n)
        callMma(2 * m, n, 2 * k);

+  Type resElemTy = dTensorTy.getElementType();
+
+  for (auto &elem : fc) {
+    elem = bitcast(elem, resElemTy);
+  }
+
  // replace with new packed result
  Type structTy = LLVM::LLVMStructType::getLiteral(
-      ctx, SmallVector<Type>(fc.size(), type::f32Ty(ctx)));
+      ctx, SmallVector<Type>(fc.size(), resElemTy));
  Value res = getStructFromElements(loc, fc, rewriter, structTy);
  rewriter.replaceOp(op, res);
@@ -3821,9 +3882,7 @@ private:
        tensorTy.getShape() /*tileShape*/, instrShape, matShape, perPhase,
        maxPhase, elemBytes, rewriter, typeConverter, loc);
    SmallVector<Value> offs = loader.computeOffsets(warpId, lane);
-
    const int numPtrs = loader.getNumPtr();
-
    SmallVector<Value> ptrs(numPtrs);

    Type smemPtrTy = helper.getShemPtrTy();
@@ -3835,6 +3894,7 @@ private:
    auto [ha0, ha1, ha2, ha3] = loader.loadX4(
        (kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? b : a /*mat1*/, offs,
        ptrs, helper.getMatType(), helper.getShemPtrTy());

    if (!needTrans) {
      ld2(vals, a, b, ha0);
      ld2(vals, a + 1, b, ha1);
@@ -3879,10 +3939,9 @@ private:

  assert(!elems.empty());

-  Type fp16Ty = type::f16Ty(ctx);
-  Type fp16x2Ty = vec_ty(fp16Ty, 2);
+  Type elemTy = elems[0].getType();
  Type structTy = LLVM::LLVMStructType::getLiteral(
-      ctx, SmallVector<Type>(elems.size(), fp16x2Ty));
+      ctx, SmallVector<Type>(elems.size(), elemTy));
  auto result = getStructFromElements(loc, elems, rewriter, structTy);
  return result;
}
@@ -3921,9 +3980,25 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
      dotOperandLayout.getParent().dyn_cast_or_null<MmaEncodingAttr>();
  assert(mmaLayout);

+  bool isOuter{};
+  {
+    int K{};
+    if (dotOperandLayout.getOpIdx() == 0) // $a
+      K = dstTensorTy.getShape()[1];
+    else // $b
+      K = dstTensorTy.getShape()[0];
+    isOuter = K == 1;
+  }
+
+  // TODO[Superjomn]: allowTF32 is not available in ConvertLayoutOp, for it
+  // is an attribute of DotOp.
+  bool allowTF32 = false;
+  bool isHMMA = DotOpConversion::isDotHMMA(dstTensorTy, allowTF32,
+                                           mmaLayout.getVersion());
+
  auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.src(), rewriter);
  Value res;
-  if (mmaLayout.getVersion() == 2) {
+  if (!isOuter && mmaLayout.getVersion() == 2 && isHMMA) { // tensor core v2
    MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
                                       rewriter, getTypeConverter(),
                                       op.getLoc());
@@ -3935,7 +4010,8 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
      // operand $b
      res = mmaHelper.loadB(src, smemObj);
    }
-  } else if (mmaLayout.getVersion() == 1) {
+  } else if (!isOuter && mmaLayout.getVersion() == 1 &&
+             isHMMA) { // tensor core v1
    DotOpMmaV1ConversionHelper helper(mmaLayout);
    if (dotOperandLayout.getOpIdx() == 0) {
      // operand $a
@@ -5076,8 +5152,8 @@ void ConvertTritonGPUToLLVM::initSharedMemory(
  OpBuilder b(mod.getBodyRegion());
  auto loc = mod.getLoc();
  auto elemTy = typeConverter.convertType(b.getIntegerType(8));
-  // Set array size 0 and external linkage indicates that we use dynamic shared
-  // allocation to allow a larger shared memory size for each kernel.
+  // Set array size 0 and external linkage indicates that we use dynamic
+  // shared allocation to allow a larger shared memory size for each kernel.
  auto arrayTy = LLVM::LLVMArrayType::get(elemTy, 0);
  auto global = b.create<LLVM::GlobalOp>(
      loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::External,
@@ -117,10 +117,13 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
                 "BlockedEncodingAttr not implemented");
    }
  } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
-    assert(mmaLayout.getVersion() == 2 &&
-           "mmaLayout version = 1 is not implemented yet");
+    if (mmaLayout.getVersion() == 2)
      return {16 * mmaLayout.getWarpsPerCTA()[0],
              8 * mmaLayout.getWarpsPerCTA()[1]};
+    if (mmaLayout.getVersion() == 1)
+      return {16 * mmaLayout.getWarpsPerCTA()[0],
+              16 * mmaLayout.getWarpsPerCTA()[1]};
+    assert(0 && "Unexpected MMA layout version found");
  } else {
    assert(0 && "Unimplemented usage of getShapePerCTA");
  }
@@ -55,6 +55,33 @@ def test_gemm_no_scf(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
     assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False)


+@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS', [
+    [64, 128, 128, 1],
+    [128, 128, 128, 4],
+    [16, 8, 32, 1],
+    [32, 16, 64, 2],
+    [32, 16, 64, 4],
+])
+def test_gemm_no_scf_int8(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
+    a = torch.randint(-5, 5, (SIZE_M, SIZE_K), device='cuda', dtype=torch.int8)
+    b = torch.randint(-5, 5, (SIZE_K, SIZE_N), device='cuda', dtype=torch.int8)
+    c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.int32)
+
+    grid = lambda META: (1, )
+    matmul_no_scf_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
+                               stride_am=a.stride(0), stride_ak=a.stride(1),
+                               stride_bk=b.stride(0), stride_bn=b.stride(1),
+                               stride_cm=c.stride(0), stride_cn=c.stride(1),
+                               M=SIZE_M, N=SIZE_N, K=SIZE_K,
+                               num_warps=NUM_WARPS)
+
+    aa = a.cpu()
+    bb = b.cpu()
+    golden = torch.matmul(aa.float(), bb.float()).int()
+    torch.set_printoptions(profile="full")
+    torch.testing.assert_close(c.cpu(), golden, check_dtype=False)
+
+
 @triton.jit
 def matmul_kernel(
         a_ptr, b_ptr, c_ptr,
@@ -80,8 +107,6 @@ def matmul_kernel(
     c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
     tl.store(c_ptrs, accumulator)

-# TODO: DotConversion in TritonGPUToLLVM cannot support non-splat C for the moment
-

 def get_variant_golden(a, b):
     SIZE_M = a.shape[0]