[Triton-MLIR][Backend] Fix convert_layout blocked->shared in non-default order (#876)
This PR fixes TN/NT GEMM correctness when no SCF is involved. I'll continue to clean up getLinearIndex/getMultiDimIndex in a uniform way, which should help avoid various kinds of order issues. This is not fully done yet; merging now just to sync the code.
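For readers unfamiliar with layout orders: order[0] names the fastest-varying (contiguous) dimension, so a row-major [M, N] operand has order [1, 0] and its transpose has order [0, 1]. A minimal Python sketch of the order-aware delinearization this PR adds on the C++ side (illustrative only, not code from the PR):

    def get_multi_dim_index(linear_index, shape, order):
        # Mirrors the new getMultiDimIndex(linearIndex, shape, order):
        # peel off dimensions starting from the fastest-varying one.
        multi_dim = [0] * len(shape)
        remain = linear_index
        for dim in order:  # order[0] varies fastest
            multi_dim[dim] = remain % shape[dim]
            remain //= shape[dim]
        return multi_dim

    # shape [2, 4]: the same linear index 5 lands on different elements
    # depending on the order, which is exactly the TN/NT discrepancy.
    assert get_multi_dim_index(5, [2, 4], [1, 0]) == [1, 1]
    assert get_multi_dim_index(5, [2, 4], [0, 1]) == [1, 2]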
@@ -347,7 +347,8 @@ Value getStructFromElements(Location loc, ValueRange resultVals,
   return llvmStruct;
 }

-// Delinearize on compile-time consts, assuming the order is [n, .. 2, 1, 0]
+// TODO[goostavz]: to be deprecated
+// delinearize supposing order is [n, .. , 2, 1, 0]
 template <typename T>
 static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape) {
   // shape: {a, b, c, d} -> accMul: {b*c*d, c*d, d, 1}
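Reading the accMul comment concretely: with the default order [n, .., 2, 1, 0] the last dimension varies fastest, so shape {2, 3, 4} gives multipliers {12, 4, 1} and linear index 17 delinearizes to [1, 1, 1]. A tiny Python sketch of that arithmetic (illustrative, not part of the change):

    # Default-order delinearization: shape {2, 3, 4} -> accMul {12, 4, 1}.
    shape = [2, 3, 4]
    acc = [12, 4, 1]
    linear = 17
    idx = []
    for m in acc:
        idx.append(linear // m)
        linear %= m
    assert idx == [1, 1, 1]  # 1*12 + 1*4 + 1*1 == 17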
@@ -365,7 +366,40 @@ static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape) {
   return multiDimIndex;
 }

-// Linearize on compile-time consts, assuming the order is [n, .. 2, 1, 0]
+// delinearize supposing order is [0, 1, .. , n]
+template <typename T>
+static SmallVector<T> getMultiDimIndexImpl(T linearIndex, ArrayRef<T> shape) {
+  // shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
+  size_t rank = shape.size();
+  T accMul = product(shape.drop_back());
+  T linearRemain = linearIndex;
+  SmallVector<T> multiDimIndex(rank);
+  for (int i = rank - 1; i >= 0; --i) {
+    multiDimIndex[i] = linearRemain / accMul;
+    linearRemain = linearRemain % accMul;
+    if (i != 0) {
+      accMul = accMul / shape[i - 1];
+    }
+  }
+  return multiDimIndex;
+}
+
+template <typename T>
+static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape,
+                                       ArrayRef<unsigned> order) {
+  size_t rank = shape.size();
+  assert(rank == order.size());
+  auto reordered = reorder(shape, order);
+  auto reorderedMultiDim = getMultiDimIndexImpl<T>(linearIndex, reordered);
+  SmallVector<T> multiDim(rank);
+  for (unsigned i = 0; i < rank; ++i) {
+    multiDim[order[i]] = reorderedMultiDim[i];
+  }
+  return multiDim;
+}
+
+// TODO[goostavz]: to be deprecated
+// linearize supposing order is [n, .. , 2, 1, 0]
 template <typename T>
 static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
   assert(multiDimIndex.size() == shape.size());
@@ -382,6 +416,30 @@ static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
   return linearIndex;
 }

+template <typename T>
+static T getLinearIndexImpl(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
+  assert(multiDimIndex.size() == shape.size());
+  // shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
+  size_t rank = shape.size();
+  T accMul = product(shape.drop_back());
+  T linearIndex = 0;
+  for (int i = rank - 1; i >= 0; --i) {
+    linearIndex += multiDimIndex[i] * accMul;
+    if (i != 0) {
+      accMul = accMul / shape[i - 1];
+    }
+  }
+  return linearIndex;
+}
+
+template <typename T>
+static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape,
+                        ArrayRef<unsigned> order) {
+  assert(shape.size() == order.size());
+  return getLinearIndexImpl<T>(reorder(multiDimIndex, order),
+                               reorder(shape, order));
+}
+
 static Value storeShared(ConversionPatternRewriter &rewriter, Location loc,
                          Value ptr, Value val, Value pred) {
   MLIRContext *ctx = rewriter.getContext();
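The order-aware pair is a bijection: getLinearIndex(getMultiDimIndex(i, shape, order), shape, order) == i for any order. A small Python property check of that contract (a sketch of the same arithmetic, not the C++ itself):

    from itertools import permutations

    def delinearize(linear, shape, order):
        multi = [0] * len(shape)
        for dim in order:  # order[0] varies fastest
            multi[dim] = linear % shape[dim]
            linear //= shape[dim]
        return multi

    def linearize(multi, shape, order):
        linear, stride = 0, 1
        for dim in order:
            linear += multi[dim] * stride
            stride *= shape[dim]
        return linear

    shape = [2, 3, 4]
    for order in permutations(range(len(shape))):
        for i in range(2 * 3 * 4):
            assert linearize(delinearize(i, shape, order), shape, order) == i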
@@ -632,6 +690,7 @@ public:
     auto sizePerThread = blockedLayout.getSizePerThread();
     auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
     auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
+    auto order = blockedLayout.getOrder();

     unsigned rank = shape.size();
     SmallVector<unsigned> shapePerCTA = getShapePerCTA(blockedLayout);
@@ -663,9 +722,9 @@ public:
       unsigned linearNanoTileId = n / totalSizePerThread;
       unsigned linearNanoTileElemId = n % totalSizePerThread;
       SmallVector<unsigned> multiDimNanoTileId =
-          getMultiDimIndex<unsigned>(linearNanoTileId, tilesPerDim);
-      SmallVector<unsigned> multiDimNanoTileElemId =
-          getMultiDimIndex<unsigned>(linearNanoTileElemId, sizePerThread);
+          getMultiDimIndex<unsigned>(linearNanoTileId, tilesPerDim, order);
+      SmallVector<unsigned> multiDimNanoTileElemId = getMultiDimIndex<unsigned>(
+          linearNanoTileElemId, sizePerThread, order);
       for (unsigned k = 0; k < rank; ++k) {
         unsigned reorderedMultiDimId =
             multiDimNanoTileId[k] *
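Why order has to be threaded through here: consecutive per-thread element ids walk the layout's fastest dimension first, so a layout with order [0, 1] must not be decomposed as if dimension 1 were fastest, which is what the two-argument overload silently assumed. A sketch with a hypothetical sizePerThread of [2, 4] (same arithmetic as the delinearize sketch above):

    def delin(n, shape, order):
        out = [0] * len(shape)
        for d in order:  # order[0] varies fastest
            out[d] = n % shape[d]
            n //= shape[d]
        return out

    size_per_thread = [2, 4]
    # order [1, 0]: ids walk dim 1 first -> [0,0], [0,1], [0,2], [0,3]
    # order [0, 1]: ids walk dim 0 first -> [0,0], [1,0], [0,1], [1,1]
    print([delin(n, size_per_thread, [1, 0]) for n in range(4)])
    print([delin(n, size_per_thread, [0, 1]) for n in range(4)])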
@@ -1881,8 +1940,6 @@ struct PrintfOpConversion
   // currently support pointer, i8, i16, i32, i64, f16, bf16, f32, f64
   std::string getFormatSubstr(Value value) const {
     Type type = value.getType();
-    unsigned width = type.getIntOrFloatBitWidth();
-
     if (type.isa<LLVM::LLVMPointerType>()) {
       return "%p";
     } else if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) {
@@ -1924,13 +1981,11 @@ struct PrintfOpConversion
   promoteValue(ConversionPatternRewriter &rewriter, Value value) {
     auto *context = rewriter.getContext();
     auto type = value.getType();
-    type.dump();
-    unsigned width = type.getIntOrFloatBitWidth();
     Value newOp = value;
     Type newType = type;

     bool bUnsigned = type.isUnsignedInteger();
-    if (type.isIntOrIndex() && width < 32) {
+    if (type.isIntOrIndex() && type.getIntOrFloatBitWidth() < 32) {
       if (bUnsigned) {
         newType = ui32_ty;
         newOp = rewriter.create<LLVM::ZExtOp>(UnknownLoc::get(context), newType,
@@ -3057,23 +3112,24 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
     }
     unsigned linearIdxInNanoTile = i % srcAccumSizeInThreads;
     auto multiDimIdxInNanoTile = getMultiDimIndex<unsigned>(
-        linearIdxInNanoTile, srcBlockedLayout.getSizePerThread());
+        linearIdxInNanoTile, srcBlockedLayout.getSizePerThread(), inOrd);
     unsigned pos = multiDimIdxInNanoTile[inOrd[0]] % minVec;
     multiDimIdxInNanoTile[inOrd[0]] /= minVec;
     unsigned wordVecIdx =
-        getLinearIndex<unsigned>(multiDimIdxInNanoTile, wordsInEachRep);
+        getLinearIndex<unsigned>(multiDimIdxInNanoTile, wordsInEachRep, inOrd);
     wordVecs[wordVecIdx] =
         insert_element(wordTy, wordVecs[wordVecIdx], inVals[i], idx_val(pos));

     if (i % srcAccumSizeInThreads == srcAccumSizeInThreads - 1) {
       // end of replication, store the vectors into shared memory
       unsigned linearRepIdx = i / srcAccumSizeInThreads;
-      auto multiDimRepIdx = getMultiDimIndex<unsigned>(linearRepIdx, reps);
+      auto multiDimRepIdx =
+          getMultiDimIndex<unsigned>(linearRepIdx, reps, inOrd);
       for (unsigned linearWordIdx = 0; linearWordIdx < numWordsEachRep;
            ++linearWordIdx) {
         // step 1: recover the multidim_index from the index of input_elements
         auto multiDimWordIdx =
-            getMultiDimIndex<unsigned>(linearWordIdx, wordsInEachRep);
+            getMultiDimIndex<unsigned>(linearWordIdx, wordsInEachRep, inOrd);
         SmallVector<Value> multiDimIdx(2);
         auto wordOffset0 = multiDimRepIdx[0] * srcShapePerCTA[0] +
                            multiDimWordIdx[0] * (inOrd[0] == 0 ? minVec : 1);
@@ -3083,12 +3139,12 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
         multiDimIdx[1] = add(multiDimOffsetFirstElem[1], idx_val(wordOffset1));

         // step 2: do swizzling
-        Value remained = urem(multiDimIdx[inOrd[0]], outVecVal);
-        multiDimIdx[inOrd[0]] = udiv(multiDimIdx[inOrd[0]], outVecVal);
-        Value off_1 = mul(multiDimIdx[inOrd[1]], idx_val(srcShape[inOrd[0]]));
-        Value phaseId = udiv(multiDimIdx[inOrd[1]], idx_val(perPhase));
+        Value remained = urem(multiDimIdx[outOrd[0]], outVecVal);
+        multiDimIdx[outOrd[0]] = udiv(multiDimIdx[outOrd[0]], outVecVal);
+        Value off_1 = mul(multiDimIdx[outOrd[1]], idx_val(srcShape[outOrd[0]]));
+        Value phaseId = udiv(multiDimIdx[outOrd[1]], idx_val(perPhase));
         phaseId = urem(phaseId, idx_val(maxPhase));
-        Value off_0 = xor_(multiDimIdx[inOrd[0]], phaseId);
+        Value off_0 = xor_(multiDimIdx[outOrd[0]], phaseId);
         off_0 = mul(off_0, outVecVal);
         remained = udiv(remained, minVecVal);
         off_0 = add(off_0, mul(remained, minVecVal));
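A scalar model of the swizzling arithmetic above may help. Here idx_fast stands for multiDimIdx[outOrd[0]] (the contiguous dimension) and idx_slow for multiDimIdx[outOrd[1]]; both names are mine, and how off_0 and off_1 are finally combined lies outside this hunk, so the sketch returns them separately:

    def swizzle(idx_fast, idx_slow, shape_fast,
                out_vec, min_vec, per_phase, max_phase):
        # Mirrors "step 2: do swizzling" in the hunk above.
        remained = idx_fast % out_vec
        vec = idx_fast // out_vec
        off_1 = idx_slow * shape_fast
        phase = (idx_slow // per_phase) % max_phase
        off_0 = (vec ^ phase) * out_vec + (remained // min_vec) * min_vec
        return off_0, off_1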
@@ -30,18 +30,32 @@ def matmul_no_scf_kernel(
 # TODO: num_warps could only be 4 for now


-@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS', [
-    [128, 256, 32, 4],
-    [256, 128, 16, 4],
-    [128, 16, 32, 4],
-    [32, 128, 64, 4],
-    [128, 128, 64, 4],
-    [64, 128, 128, 4],
-    [64, 128, 128, 2],
+@pytest.mark.parametrize('SHAPE,NUM_WARPS,TRANS_A,TRANS_B', [
+    (shape, num_warps, trans_a, trans_b)
+    for shape in [
+        [128, 256, 32],
+        [256, 128, 16],
+        [128, 16, 32],
+        [32, 128, 64],
+        [128, 128, 64],
+        [64, 128, 128],
+    ]
+    for num_warps in [2, 4]
+    for trans_a in [False, True]
+    for trans_b in [False, True]
 ])
-def test_gemm_no_scf(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
-    a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
-    b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
+def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
+    SIZE_M, SIZE_N, SIZE_K = SHAPE
+    if (TRANS_A):
+        a = torch.randn((SIZE_K, SIZE_M), device='cuda', dtype=torch.float16).T
+    else:
+        a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
+
+    if (TRANS_B):
+        b = torch.randn((SIZE_N, SIZE_K), device='cuda', dtype=torch.float16).T
+    else:
+        b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
+
     c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)
     grid = lambda META: (1, )
     matmul_no_scf_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
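A note on what TRANS_A/TRANS_B actually exercise: .T returns a strided view rather than copying, so the kernel receives a column-major operand, i.e. the non-default order that the backend fix above targets. For instance:

    import torch

    # .T swaps the strides of the view; no data movement happens.
    a = torch.randn((32, 16), device='cuda', dtype=torch.float16).T
    print(a.shape, a.stride(), a.is_contiguous())
    # torch.Size([16, 32]) (1, 16) False  -> a column-major view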
@@ -55,16 +69,32 @@ def test_gemm_no_scf(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
     assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False)


-@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS', [
-    [64, 128, 128, 1],
-    [128, 128, 128, 4],
-    [16, 8, 32, 1],
-    [32, 16, 64, 2],
-    [32, 16, 64, 4],
+@pytest.mark.parametrize('SHAPE,NUM_WARPS,TRANS_A,TRANS_B', [
+    (shape, num_warps, trans_a, trans_b)
+    for shape in [
+        [64, 128, 128],
+        [128, 128, 128],
+        [16, 8, 32],
+        [32, 16, 64],
+        [32, 16, 64],
+    ]
+    for num_warps in [1, 2, 4]
+    for trans_a in [False, True]
+    for trans_b in [False, True]
 ])
-def test_gemm_no_scf_int8(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
-    a = torch.randint(-5, 5, (SIZE_M, SIZE_K), device='cuda', dtype=torch.int8)
-    b = torch.randint(-5, 5, (SIZE_K, SIZE_N), device='cuda', dtype=torch.int8)
+def test_gemm_no_scf_int8(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
+    SIZE_M, SIZE_N, SIZE_K = SHAPE
+
+    if (TRANS_A):
+        a = torch.randint(-5, 5, (SIZE_K, SIZE_M), device='cuda', dtype=torch.int8).T
+    else:
+        a = torch.randint(-5, 5, (SIZE_M, SIZE_K), device='cuda', dtype=torch.int8)
+
+    if (TRANS_B):
+        b = torch.randint(-5, 5, (SIZE_N, SIZE_K), device='cuda', dtype=torch.int8).T
+    else:
+        b = torch.randint(-5, 5, (SIZE_K, SIZE_N), device='cuda', dtype=torch.int8)
+
     c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.int32)

     grid = lambda META: (1, )
@@ -125,28 +155,39 @@ def get_variant_golden(a, b):
     return c_padded[:SIZE_M, :SIZE_N]


-@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K', [
+@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [
     # Non-forloop
-    [64, 32, 64, 4, 64, 32, 64],
-    [128, 64, 128, 4, 128, 64, 128],
+    [64, 32, 64, 4, 64, 32, 64, False, False],
+    [128, 64, 128, 4, 128, 64, 128, False, False],
     # K-Forloop
-    [64, 32, 128, 4, 64, 32, 64],
-    [128, 16, 128, 4, 128, 16, 32],
-    [32, 16, 128, 4, 32, 16, 32],
-    [32, 64, 128, 4, 32, 64, 32],
-    [32, 128, 256, 4, 32, 128, 64],
-    [64, 128, 64, 4, 64, 128, 32],
-    [64, 64, 128, 4, 64, 64, 32],
-    [128, 128, 64, 4, 128, 128, 32],
-    [128, 128, 128, 4, 128, 128, 32],
-    [128, 128, 256, 4, 128, 128, 64],
-    [128, 256, 128, 4, 128, 256, 32],
-    [256, 128, 64, 4, 256, 128, 16],
-    [128, 64, 128, 4, 128, 64, 32],
+    [64, 32, 128, 4, 64, 32, 64, False, False],
+    [128, 16, 128, 4, 128, 16, 32, False, False],
+    [32, 16, 128, 4, 32, 16, 32, False, False],
+    [32, 64, 128, 4, 32, 64, 32, False, False],
+    [32, 128, 256, 4, 32, 128, 64, False, False],
+    [64, 128, 64, 4, 64, 128, 32, False, False],
+    [64, 64, 128, 4, 64, 64, 32, False, False],
+    [128, 128, 64, 4, 128, 128, 32, False, False],
+    [128, 128, 128, 4, 128, 128, 32, False, False],
+    [128, 128, 256, 4, 128, 128, 64, False, False],
+    [128, 256, 128, 4, 128, 256, 32, False, False],
+    [256, 128, 64, 4, 256, 128, 16, False, False],
+    [128, 64, 128, 4, 128, 64, 32, False, False],
+    # TODO[goostavz]: fix these cases
+    #[128, 64, 128, 4, 128, 64, 32, True, False],
+    #[128, 64, 128, 4, 128, 64, 32, False, True],
 ])
-def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K):
-    a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
-    b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
+def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
+    if (TRANS_A):
+        a = torch.randn((SIZE_K, SIZE_M), device='cuda', dtype=torch.float16).T
+    else:
+        a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
+
+    if (TRANS_B):
+        b = torch.randn((SIZE_N, SIZE_K), device='cuda', dtype=torch.float16).T
+    else:
+        b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
+
     c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)
     grid = lambda META: (1, )
     matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,