[Triton-MLIR][Backend] Fix convert_layout blocked->shared in non-default order (#876)

This PR fixes a correctness problem with TN/NT GEMM when no SCF is involved.
I'll continue to clean up getLinearIndex/getMultiDimIndex in a unified
way, which should help avoid other kinds of order issues.
That work is not fully done yet; this is merged now to sync the code.
Author: goostavz
Date: 2022-11-15 09:02:46 +08:00
Committed by: GitHub
Parent: 1eedaf7bec
Commit: c28cfd821b
2 changed files with 155 additions and 58 deletions
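
For context on the order convention the new helpers rely on: order[0] names the fastest-varying dimension, and the way the diff below scatters indices back implies reorder(x, order)[i] == x[order[i]]. The following is a minimal standalone sketch of the order-aware delinearize/linearize round trip under exactly those assumptions, with std::vector standing in for ArrayRef; it illustrates the idea and is not the lowering code itself.

// Minimal sketch (assumed convention): order[0] is the fastest-varying dim,
// and reorder(x, order)[i] == x[order[i]]. std::vector stands in for ArrayRef.
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <vector>

using Vec = std::vector<unsigned>;

static Vec reorder(const Vec &x, const Vec &order) {
  Vec out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i)
    out[i] = x[order[i]];
  return out;
}

// Delinearize assuming dim 0 varies fastest (order [0, 1, ..., n]).
static Vec delinearizeDim0Fastest(unsigned linear, const Vec &shape) {
  Vec multi(shape.size());
  for (std::size_t i = 0; i < shape.size(); ++i) {
    multi[i] = linear % shape[i];
    linear /= shape[i];
  }
  return multi;
}

static Vec getMultiDimIndex(unsigned linear, const Vec &shape, const Vec &order) {
  Vec reorderedMulti = delinearizeDim0Fastest(linear, reorder(shape, order));
  Vec multi(shape.size());
  for (std::size_t i = 0; i < shape.size(); ++i)
    multi[order[i]] = reorderedMulti[i]; // scatter back to the original dims
  return multi;
}

static unsigned getLinearIndex(const Vec &multi, const Vec &shape, const Vec &order) {
  Vec m = reorder(multi, order), s = reorder(shape, order);
  unsigned linear = 0, stride = 1;
  for (std::size_t i = 0; i < s.size(); ++i) {
    linear += m[i] * stride; // the dim listed first in `order` gets stride 1
    stride *= s[i];
  }
  return linear;
}

int main() {
  Vec shape = {4, 8};
  Vec order = {1, 0}; // the 8-wide axis is contiguous / fastest
  for (unsigned i = 0; i < 32; ++i)
    assert(getLinearIndex(getMultiDimIndex(i, shape, order), shape, order) == i);
  Vec m = getMultiDimIndex(10, shape, order);
  std::printf("10 -> (%u, %u)\n", m[0], m[1]); // prints "10 -> (1, 2)"
  return 0;
}

With shape {4, 8} and order {1, 0} (the 8-wide axis contiguous), linear index 10 maps to multi-dim index (1, 2) and linearizes back to 10, which is the round-trip invariant the blocked->shared conversion below depends on.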


@@ -347,7 +347,8 @@ Value getStructFromElements(Location loc, ValueRange resultVals,
   return llvmStruct;
 }
 
-// Delinearize on compile-time consts, assuming the order is [n, .. 2, 1, 0]
+// TODO[goostavz]: to be deprecated
+// delinearize supposing order is [n, .. , 2, 1, 0]
 template <typename T>
 static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape) {
   // shape: {a, b, c, d} -> accMul: {b*c*d, c*d, d, 1}
@@ -365,7 +366,40 @@ static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape) {
   return multiDimIndex;
 }
 
-// Linearize on compile-time consts, assuming the order is [n, .. 2, 1, 0]
+// delinearize supposing order is [0, 1, .. , n]
+template <typename T>
+static SmallVector<T> getMultiDimIndexImpl(T linearIndex, ArrayRef<T> shape) {
+  // shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
+  size_t rank = shape.size();
+  T accMul = product(shape.drop_back());
+  T linearRemain = linearIndex;
+  SmallVector<T> multiDimIndex(rank);
+  for (int i = rank - 1; i >= 0; --i) {
+    multiDimIndex[i] = linearRemain / accMul;
+    linearRemain = linearRemain % accMul;
+    if (i != 0) {
+      accMul = accMul / shape[i - 1];
+    }
+  }
+  return multiDimIndex;
+}
+
+template <typename T>
+static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape,
+                                       ArrayRef<unsigned> order) {
+  size_t rank = shape.size();
+  assert(rank == order.size());
+  auto reordered = reorder(shape, order);
+  auto reorderedMultiDim = getMultiDimIndexImpl<T>(linearIndex, reordered);
+  SmallVector<T> multiDim(rank);
+  for (unsigned i = 0; i < rank; ++i) {
+    multiDim[order[i]] = reorderedMultiDim[i];
+  }
+  return multiDim;
+}
+
+// TODO[goostavz]: to be deprecated
+// linearize supposing order is [n, .. , 2, 1, 0]
 template <typename T>
 static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
   assert(multiDimIndex.size() == shape.size());
@@ -382,6 +416,30 @@ static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
   return linearIndex;
 }
 
+template <typename T>
+static T getLinearIndexImpl(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
+  assert(multiDimIndex.size() == shape.size());
+  // shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
+  size_t rank = shape.size();
+  T accMul = product(shape.drop_back());
+  T linearIndex = 0;
+  for (int i = rank - 1; i >= 0; --i) {
+    linearIndex += multiDimIndex[i] * accMul;
+    if (i != 0) {
+      accMul = accMul / shape[i - 1];
+    }
+  }
+  return linearIndex;
+}
+
+template <typename T>
+static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape,
+                        ArrayRef<unsigned> order) {
+  assert(shape.size() == order.size());
+  return getLinearIndexImpl<T>(reorder(multiDimIndex, order),
+                               reorder(shape, order));
+}
+
 static Value storeShared(ConversionPatternRewriter &rewriter, Location loc,
                          Value ptr, Value val, Value pred) {
   MLIRContext *ctx = rewriter.getContext();
@@ -632,6 +690,7 @@ public:
     auto sizePerThread = blockedLayout.getSizePerThread();
     auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
     auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
+    auto order = blockedLayout.getOrder();
 
     unsigned rank = shape.size();
     SmallVector<unsigned> shapePerCTA = getShapePerCTA(blockedLayout);
@@ -663,9 +722,9 @@ public:
       unsigned linearNanoTileId = n / totalSizePerThread;
       unsigned linearNanoTileElemId = n % totalSizePerThread;
       SmallVector<unsigned> multiDimNanoTileId =
-          getMultiDimIndex<unsigned>(linearNanoTileId, tilesPerDim);
-      SmallVector<unsigned> multiDimNanoTileElemId =
-          getMultiDimIndex<unsigned>(linearNanoTileElemId, sizePerThread);
+          getMultiDimIndex<unsigned>(linearNanoTileId, tilesPerDim, order);
+      SmallVector<unsigned> multiDimNanoTileElemId = getMultiDimIndex<unsigned>(
+          linearNanoTileElemId, sizePerThread, order);
       for (unsigned k = 0; k < rank; ++k) {
         unsigned reorderedMultiDimId =
             multiDimNanoTileId[k] *
@@ -1881,8 +1940,6 @@ struct PrintfOpConversion
   // currently support pointer, i8, i16, i32, i64, f16, bf16, f32, f64
   std::string getFormatSubstr(Value value) const {
     Type type = value.getType();
-    unsigned width = type.getIntOrFloatBitWidth();
     if (type.isa<LLVM::LLVMPointerType>()) {
       return "%p";
     } else if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) {
@@ -1924,13 +1981,11 @@ struct PrintfOpConversion
   promoteValue(ConversionPatternRewriter &rewriter, Value value) {
     auto *context = rewriter.getContext();
     auto type = value.getType();
-    type.dump();
-    unsigned width = type.getIntOrFloatBitWidth();
     Value newOp = value;
     Type newType = type;
     bool bUnsigned = type.isUnsignedInteger();
-    if (type.isIntOrIndex() && width < 32) {
+    if (type.isIntOrIndex() && type.getIntOrFloatBitWidth() < 32) {
       if (bUnsigned) {
         newType = ui32_ty;
         newOp = rewriter.create<LLVM::ZExtOp>(UnknownLoc::get(context), newType,
@@ -3057,23 +3112,24 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
       }
       unsigned linearIdxInNanoTile = i % srcAccumSizeInThreads;
       auto multiDimIdxInNanoTile = getMultiDimIndex<unsigned>(
-          linearIdxInNanoTile, srcBlockedLayout.getSizePerThread());
+          linearIdxInNanoTile, srcBlockedLayout.getSizePerThread(), inOrd);
       unsigned pos = multiDimIdxInNanoTile[inOrd[0]] % minVec;
       multiDimIdxInNanoTile[inOrd[0]] /= minVec;
       unsigned wordVecIdx =
-          getLinearIndex<unsigned>(multiDimIdxInNanoTile, wordsInEachRep);
+          getLinearIndex<unsigned>(multiDimIdxInNanoTile, wordsInEachRep, inOrd);
       wordVecs[wordVecIdx] =
           insert_element(wordTy, wordVecs[wordVecIdx], inVals[i], idx_val(pos));
       if (i % srcAccumSizeInThreads == srcAccumSizeInThreads - 1) {
         // end of replication, store the vectors into shared memory
         unsigned linearRepIdx = i / srcAccumSizeInThreads;
-        auto multiDimRepIdx = getMultiDimIndex<unsigned>(linearRepIdx, reps);
+        auto multiDimRepIdx =
+            getMultiDimIndex<unsigned>(linearRepIdx, reps, inOrd);
         for (unsigned linearWordIdx = 0; linearWordIdx < numWordsEachRep;
              ++linearWordIdx) {
           // step 1: recover the multidim_index from the index of input_elements
           auto multiDimWordIdx =
-              getMultiDimIndex<unsigned>(linearWordIdx, wordsInEachRep);
+              getMultiDimIndex<unsigned>(linearWordIdx, wordsInEachRep, inOrd);
           SmallVector<Value> multiDimIdx(2);
           auto wordOffset0 = multiDimRepIdx[0] * srcShapePerCTA[0] +
                              multiDimWordIdx[0] * (inOrd[0] == 0 ? minVec : 1);
@@ -3083,12 +3139,12 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
           multiDimIdx[1] = add(multiDimOffsetFirstElem[1], idx_val(wordOffset1));
           // step 2: do swizzling
-          Value remained = urem(multiDimIdx[inOrd[0]], outVecVal);
-          multiDimIdx[inOrd[0]] = udiv(multiDimIdx[inOrd[0]], outVecVal);
-          Value off_1 = mul(multiDimIdx[inOrd[1]], idx_val(srcShape[inOrd[0]]));
-          Value phaseId = udiv(multiDimIdx[inOrd[1]], idx_val(perPhase));
+          Value remained = urem(multiDimIdx[outOrd[0]], outVecVal);
+          multiDimIdx[outOrd[0]] = udiv(multiDimIdx[outOrd[0]], outVecVal);
+          Value off_1 = mul(multiDimIdx[outOrd[1]], idx_val(srcShape[outOrd[0]]));
+          Value phaseId = udiv(multiDimIdx[outOrd[1]], idx_val(perPhase));
           phaseId = urem(phaseId, idx_val(maxPhase));
-          Value off_0 = xor_(multiDimIdx[inOrd[0]], phaseId);
+          Value off_0 = xor_(multiDimIdx[outOrd[0]], phaseId);
           off_0 = mul(off_0, outVecVal);
           remained = udiv(remained, minVecVal);
           off_0 = add(off_0, mul(remained, minVecVal));
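
For readability, here is the swizzling arithmetic from the hunk above transliterated into plain integer math. This is a sketch only: idxFast and idxSlow stand for multiDimIdx[outOrd[0]] and multiDimIdx[outOrd[1]], srcShapeFast for srcShape[outOrd[0]], and the helper name is hypothetical; the actual lowering builds LLVM IR with urem/udiv/xor_/mul/add rather than computing on host integers.

// Sketch of the swizzled shared-memory offsets computed in lowerBlockedToShared.
// Hypothetical helper; the values mirror off_0 / off_1 in the diff above.
struct SwizzledOffsets { unsigned off0, off1; };

SwizzledOffsets swizzle(unsigned idxFast, unsigned idxSlow, unsigned srcShapeFast,
                        unsigned outVec, unsigned minVec,
                        unsigned perPhase, unsigned maxPhase) {
  unsigned remained = idxFast % outVec;           // position inside one out-vector
  idxFast /= outVec;                              // index counted in out-vectors
  unsigned off1 = idxSlow * srcShapeFast;         // slow axis * fast-axis extent
  unsigned phaseId = (idxSlow / perPhase) % maxPhase;
  unsigned off0 = (idxFast ^ phaseId) * outVec;   // xor-swizzle against the phase
  off0 += (remained / minVec) * minVec;           // keep min-vector granularity
  return {off0, off1};
}

The xor with phaseId is the usual bank-conflict-avoidance swizzle; the fix in this hunk is that the swizzle now indexes with the destination layout's outOrd rather than the source's inOrd.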


@@ -30,18 +30,32 @@ def matmul_no_scf_kernel(
 # TODO: num_warps could only be 4 for now
-@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS', [
-    [128, 256, 32, 4],
-    [256, 128, 16, 4],
-    [128, 16, 32, 4],
-    [32, 128, 64, 4],
-    [128, 128, 64, 4],
-    [64, 128, 128, 4],
-    [64, 128, 128, 2],
+@pytest.mark.parametrize('SHAPE,NUM_WARPS,TRANS_A,TRANS_B', [
+    (shape, num_warps, trans_a, trans_b)
+    for shape in [
+        [128, 256, 32],
+        [256, 128, 16],
+        [128, 16, 32],
+        [32, 128, 64],
+        [128, 128, 64],
+        [64, 128, 128],
+    ]
+    for num_warps in [2, 4]
+    for trans_a in [False, True]
+    for trans_b in [False, True]
 ])
-def test_gemm_no_scf(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
-    a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
-    b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
+def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
+    SIZE_M, SIZE_N, SIZE_K = SHAPE
+    if (TRANS_A):
+        a = torch.randn((SIZE_K, SIZE_M), device='cuda', dtype=torch.float16).T
+    else:
+        a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
+    if (TRANS_B):
+        b = torch.randn((SIZE_N, SIZE_K), device='cuda', dtype=torch.float16).T
+    else:
+        b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
     c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)
     grid = lambda META: (1, )
     matmul_no_scf_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
@@ -55,16 +69,32 @@ def test_gemm_no_scf(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
     assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False)
 
 
-@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS', [
-    [64, 128, 128, 1],
-    [128, 128, 128, 4],
-    [16, 8, 32, 1],
-    [32, 16, 64, 2],
-    [32, 16, 64, 4],
+@pytest.mark.parametrize('SHAPE,NUM_WARPS,TRANS_A,TRANS_B', [
+    (shape, num_warps, trans_a, trans_b)
+    for shape in [
+        [64, 128, 128],
+        [128, 128, 128],
+        [16, 8, 32],
+        [32, 16, 64],
+        [32, 16, 64],
+    ]
+    for num_warps in [1, 2, 4]
+    for trans_a in [False, True]
+    for trans_b in [False, True]
 ])
-def test_gemm_no_scf_int8(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
-    a = torch.randint(-5, 5, (SIZE_M, SIZE_K), device='cuda', dtype=torch.int8)
-    b = torch.randint(-5, 5, (SIZE_K, SIZE_N), device='cuda', dtype=torch.int8)
+def test_gemm_no_scf_int8(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
+    SIZE_M, SIZE_N, SIZE_K = SHAPE
+    if (TRANS_A):
+        a = torch.randint(-5, 5, (SIZE_K, SIZE_M), device='cuda', dtype=torch.int8).T
+    else:
+        a = torch.randint(-5, 5, (SIZE_M, SIZE_K), device='cuda', dtype=torch.int8)
+    if (TRANS_B):
+        b = torch.randint(-5, 5, (SIZE_N, SIZE_K), device='cuda', dtype=torch.int8).T
+    else:
+        b = torch.randint(-5, 5, (SIZE_K, SIZE_N), device='cuda', dtype=torch.int8)
     c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.int32)
     grid = lambda META: (1, )
@@ -125,28 +155,39 @@ def get_variant_golden(a, b):
     return c_padded[:SIZE_M, :SIZE_N]
 
 
-@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K', [
+@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [
     # Non-forloop
-    [64, 32, 64, 4, 64, 32, 64],
-    [128, 64, 128, 4, 128, 64, 128],
+    [64, 32, 64, 4, 64, 32, 64, False, False],
+    [128, 64, 128, 4, 128, 64, 128, False, False],
     # K-Forloop
-    [64, 32, 128, 4, 64, 32, 64],
-    [128, 16, 128, 4, 128, 16, 32],
-    [32, 16, 128, 4, 32, 16, 32],
-    [32, 64, 128, 4, 32, 64, 32],
-    [32, 128, 256, 4, 32, 128, 64],
-    [64, 128, 64, 4, 64, 128, 32],
-    [64, 64, 128, 4, 64, 64, 32],
-    [128, 128, 64, 4, 128, 128, 32],
-    [128, 128, 128, 4, 128, 128, 32],
-    [128, 128, 256, 4, 128, 128, 64],
-    [128, 256, 128, 4, 128, 256, 32],
-    [256, 128, 64, 4, 256, 128, 16],
-    [128, 64, 128, 4, 128, 64, 32],
+    [64, 32, 128, 4, 64, 32, 64, False, False],
+    [128, 16, 128, 4, 128, 16, 32, False, False],
+    [32, 16, 128, 4, 32, 16, 32, False, False],
+    [32, 64, 128, 4, 32, 64, 32, False, False],
+    [32, 128, 256, 4, 32, 128, 64, False, False],
+    [64, 128, 64, 4, 64, 128, 32, False, False],
+    [64, 64, 128, 4, 64, 64, 32, False, False],
+    [128, 128, 64, 4, 128, 128, 32, False, False],
+    [128, 128, 128, 4, 128, 128, 32, False, False],
+    [128, 128, 256, 4, 128, 128, 64, False, False],
+    [128, 256, 128, 4, 128, 256, 32, False, False],
+    [256, 128, 64, 4, 256, 128, 16, False, False],
+    [128, 64, 128, 4, 128, 64, 32, False, False],
+    # TODO[goostavz]: fix these cases
+    #[128, 64, 128, 4, 128, 64, 32, True, False],
+    #[128, 64, 128, 4, 128, 64, 32, False, True],
 ])
-def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K):
-    a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
-    b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
+def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
+    if (TRANS_A):
+        a = torch.randn((SIZE_K, SIZE_M), device='cuda', dtype=torch.float16).T
+    else:
+        a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
+    if (TRANS_B):
+        b = torch.randn((SIZE_N, SIZE_K), device='cuda', dtype=torch.float16).T
+    else:
+        b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
     c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)
     grid = lambda META: (1, )
     matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,