[Backend] Hacky fix of missing barrier in ConvertLayout blocked->shared (#803)
Barriers should be inserted by a separate (membar) pass, but that pass still seems to have some bugs, so for now this commit inserts the missing barriers directly in the backend.
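
The conversion stages data through shared memory: threads first store their values in the source layout, then read back values written by other threads in the destination layout, so a barrier is needed between the store phase and the load phase (and another one before the staging buffer is reused). Below is a loosely analogous, CPU-only sketch using Python threads; shared_buf, worker, and the thread count are illustrative names, not part of the Triton code.

    # Hedged illustration: a shared staging buffer written and then read by
    # several threads needs a barrier between the write phase and the read phase.
    import threading

    NUM_THREADS = 4
    shared_buf = [None] * NUM_THREADS      # stands in for the shared-memory staging buffer
    result = [None] * NUM_THREADS
    sync = threading.Barrier(NUM_THREADS)  # stands in for gpu.barrier / bar.sync

    def worker(tid):
        shared_buf[tid] = tid * 10         # "store" phase: write in the source layout
        # Without this wait, a fast thread could read a slot whose producer has
        # not written yet -- the race the missing barrier allowed.
        sync.wait()
        result[tid] = shared_buf[(tid + 1) % NUM_THREADS]  # "load" phase: read in the new layout

    threads = [threading.Thread(target=worker, args=(t,)) for t in range(NUM_THREADS)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print(result)  # [10, 20, 30, 0]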
@@ -91,7 +91,7 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
 #define store(val, ptr) rewriter.create<LLVM::StoreOp>(loc, val, ptr)
 #define select(...) rewriter.create<LLVM::SelectOp>(loc, __VA_ARGS__)
 #define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
-#define barrier rewriter.create<mlir::gpu::BarrierOp>(loc)
+#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
 #define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
 #define i32_ty rewriter.getIntegerType(32)
 #define vec_ty(type, num) VectorType::get(num, type)
@@ -1877,7 +1877,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
 
   for (unsigned repId = 0; repId < accumNumReplicates; ++repId) {
     auto multiDimRepId = getMultiDimIndex<unsigned>(repId, numReplicates);
-    barrier;
+    barrier();
     if (srcLayout.isa<BlockedEncodingAttr>() ||
         srcLayout.isa<MmaEncodingAttr>()) {
       processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
@@ -1887,7 +1887,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
       assert(0 && "ConvertLayout with input layout not implemented");
       return failure();
     }
-    barrier;
+    barrier();
     if (dstLayout.isa<BlockedEncodingAttr>() ||
         dstLayout.isa<MmaEncodingAttr>()) {
       processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep,
@@ -1963,6 +1963,11 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
   smemBase = bitcast(elemPtrTy, smemBase);
   unsigned numWordsEachRep = product<unsigned>(wordsInEachRep);
   SmallVector<Value> wordVecs(numWordsEachRep);
+  // TODO: We should get fewer barriers if this is handled by the membar
+  // pass instead of the backend, since the latter can only handle it in
+  // the most conservative way. Keep it here for now and revisit in the
+  // future if necessary.
+  barrier();
   for (unsigned i = 0; i < numElems; ++i) {
     if (i % srcAccumSizeInThreads == 0) {
       // start of a replication
@@ -2016,8 +2021,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
       }
     }
   }
-  barrier;
-
+  // TODO: double-check whether the barrier is necessary here
+  barrier();
   rewriter.replaceOp(op, smemBase);
   return success();
 }
@@ -3057,11 +3061,6 @@ struct MMA16816ConversionHelper {
       for (unsigned n = 0; n < numRepN; ++n)
         callMma(2 * m, n, 2 * k);
 
-    // NOTE, the barrier here is a temporary trick making the gemm with a
-    // k-forloop pass the precision test, or it will fail.
-    // TODO[Superjomn]: Fix with a more general and performance-friendly way.
-    barrier;
-
     // replace with new packed result
     Type structTy = LLVM::LLVMStructType::getLiteral(
         ctx, SmallVector<Type>(fc.size(), type::f32Ty(ctx)));
|
@@ -83,6 +83,23 @@ def matmul_kernel(
 # TODO: DotConversion in TritonGPUToLLVM cannot support non-splat C for the moment
 
 
+def get_variant_golden(a, b):
+    SIZE_M = a.shape[0]
+    SIZE_K = a.shape[1]
+    SIZE_N = b.shape[1]
+    assert a.shape[1] == b.shape[0]
+    zero_M_K = torch.zeros((SIZE_M, SIZE_K)).cuda()
+    zero_3M_K = torch.zeros((3 * SIZE_M, SIZE_K)).cuda()
+    zero_K_N = torch.zeros((SIZE_K, SIZE_N)).cuda()
+    zero_3K_N = torch.zeros((3 * SIZE_K, SIZE_N)).cuda()
+    a_padded = torch.cat((a, zero_M_K, zero_M_K), 0)
+    a_padded = torch.cat((a_padded, zero_3M_K, zero_3M_K), 1)
+    b_padded = torch.cat((b, zero_K_N, zero_K_N), 0)
+    b_padded = torch.cat((b_padded, zero_3K_N, zero_3K_N), 1)
+    c_padded = torch.matmul(a_padded, b_padded)
+    return c_padded[:SIZE_M, :SIZE_N]
+
+
 @pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K', [
     # Non-forloop
     [64, 32, 64, 4, 64, 32, 64],
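
A note on get_variant_golden: padding a and b with zero blocks leaves the top-left block of the product mathematically unchanged, but the larger problem size generally makes the library pick a different kernel and reduction order, so the result differs from the plain torch.matmul only by floating-point rounding. That discrepancy is a natural yardstick for the test tolerance. A minimal CPU-only sketch of the same idea, assuming an explicit split-K reduction instead of padding (names are illustrative, not from the test file):

    import torch

    torch.manual_seed(0)
    M, K, N = 128, 256, 128
    a = torch.randn(M, K, dtype=torch.float32)
    b = torch.randn(K, N, dtype=torch.float32)

    golden = a @ b
    # Same product, different reduction order: two partial K-sums added at the
    # end. Mathematically identical, but float32 rounding differs -- the kind
    # of noise the padded "variant golden" is meant to measure.
    variant = a[:, :K // 2] @ b[:K // 2, :] + a[:, K // 2:] @ b[K // 2:, :]

    diff = golden - variant
    abs_err = diff.abs().max().item()             # typically small but nonzero
    rel_err = (diff / golden).abs().max().item()  # can be inflated where golden is near zero
    print(abs_err, rel_err)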
@@ -94,8 +111,8 @@ def matmul_kernel(
     [32, 64, 128, 4, 32, 64, 32],
     [32, 128, 256, 4, 32, 128, 64],
     [64, 128, 64, 4, 64, 128, 32],
-    [128, 128, 64, 4, 128, 128, 32],
     [64, 64, 128, 4, 64, 64, 32],
+    [128, 128, 64, 4, 128, 128, 32],
     [128, 128, 128, 4, 128, 128, 32],
     [128, 128, 256, 4, 128, 128, 64],
     [128, 256, 128, 4, 128, 256, 32],
@@ -115,5 +132,15 @@ def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLO
                       BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
                       num_warps=NUM_WARPS)
     golden = torch.matmul(a, b)
+
+    # It is not easy to pick a proper error threshold for different sizes, so the
+    # gemm is recomputed with zero-padded inputs to obtain a variant of the golden
+    # result. The error between golden and golden_variant then provides a
+    # reference for selecting suitable rtol / atol values.
+    golden_variant = get_variant_golden(a, b)
+    golden_diff = golden - golden_variant
+    golden_abs_err = torch.max(torch.abs(golden_diff)).item()
+    golden_rel_err = torch.max(torch.abs(golden_diff / golden)).item()
+
     torch.set_printoptions(profile="full")
-    assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False)
+    assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err), check_dtype=False)
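
For reference, the measured noise feeds into the final check as a tolerance of 1.5x the observed golden-vs-variant error, floored at 1e-4 so it never collapses to zero when the two results happen to agree exactly. A hypothetical standalone helper (check_against_golden is not part of the test file) showing the same computation:

    import torch
    from torch.testing import assert_close

    def check_against_golden(c, golden, golden_variant):
        golden_diff = golden - golden_variant
        golden_abs_err = torch.max(torch.abs(golden_diff)).item()
        golden_rel_err = torch.max(torch.abs(golden_diff / golden)).item()
        # 1.5x safety margin on the measured noise, with a 1e-4 floor
        assert_close(c, golden,
                     rtol=max(1e-4, 1.5 * golden_rel_err),
                     atol=max(1e-4, 1.5 * golden_abs_err),
                     check_dtype=False)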
|