diff --git a/runtime/src/accounts_db.rs b/runtime/src/accounts_db.rs
index 458fd1178b..b08b418284 100644
--- a/runtime/src/accounts_db.rs
+++ b/runtime/src/accounts_db.rs
@@ -3624,6 +3624,7 @@ impl AccountsDB {
     fn compute_merkle_root_from_slices<'a, F>(
         total_hashes: usize,
         fanout: usize,
+        max_levels_per_pass: Option<usize>,
         get_hashes: F,
     ) -> Hash
     where
@@ -3635,7 +3636,17 @@ impl AccountsDB {

         let mut time = Measure::start("time");

-        let chunks = Self::div_ceil(total_hashes, fanout);
+        const THREE_LEVEL_OPTIMIZATION: usize = 3; // this '3' is dependent on the code structure below where we manually unroll
+        let target = fanout.pow(THREE_LEVEL_OPTIMIZATION as u32);
+
+        // Only use the 3-level optimization if we have at least 4 levels of data.
+        // Otherwise, we'll be serializing a parallel operation.
+        let threshold = target * fanout;
+        let three_level = max_levels_per_pass.unwrap_or(usize::MAX) >= THREE_LEVEL_OPTIMIZATION
+            && total_hashes >= threshold;
+        let num_hashes_per_chunk = if three_level { target } else { fanout };
+
+        let chunks = Self::div_ceil(total_hashes, num_hashes_per_chunk);

         // initial fetch - could return entire slice
         let data: &[Hash] = get_hashes(0);
@@ -3644,24 +3655,54 @@ impl AccountsDB {
         let result: Vec<_> = (0..chunks)
             .into_par_iter()
             .map(|i| {
-                let start_index = i * fanout;
-                let end_index = std::cmp::min(start_index + fanout, total_hashes);
+                let start_index = i * num_hashes_per_chunk;
+                let end_index = std::cmp::min(start_index + num_hashes_per_chunk, total_hashes);

                 let mut hasher = Hasher::default();
                 let mut data_index = start_index;
                 let mut data = data;
                 let mut data_len = data_len;

-                for i in start_index..end_index {
-                    if data_index >= data_len {
-                        // fetch next slice
-                        data = get_hashes(i);
-                        data_len = data.len();
-                        data_index = 0;
+                if !three_level {
+                    // 1 group of fanout
+                    // The result of this loop is a single hash value from fanout input hashes.
+                    for i in start_index..end_index {
+                        if data_index >= data_len {
+                            // fetch next slice
+                            data = get_hashes(i);
+                            data_len = data.len();
+                            data_index = 0;
+                        }
+                        hasher.hash(data[data_index].as_ref());
+                        data_index += 1;
+                    }
+                } else {
+                    // hash 3 levels of fanout simultaneously.
+                    // The result of this loop is a single hash value from fanout^3 input hashes.
+                    let mut i = start_index;
+                    while i < end_index {
+                        let mut hasher_j = Hasher::default();
+                        for _j in 0..fanout {
+                            let mut hasher_k = Hasher::default();
+                            let end = std::cmp::min(end_index - i, fanout);
+                            for _k in 0..end {
+                                if data_index >= data_len {
+                                    // fetch next slice
+                                    data = get_hashes(i);
+                                    data_len = data.len();
+                                    data_index = 0;
+                                }
+                                hasher_k.hash(data[data_index].as_ref());
+                                data_index += 1;
+                                i += 1;
+                            }
+                            hasher_j.hash(hasher_k.result().as_ref());
+                            if i >= end_index {
+                                break;
+                            }
+                        }
+                        hasher.hash(hasher_j.result().as_ref());
                     }
-
-                hasher.hash(data[data_index].as_ref());
-                data_index += 1;
                 }

                 hasher.result()
@@ -3823,10 +3864,12 @@ impl AccountsDB {
         let hash_total = cumulative_offsets.total_count;
         let total_lamports = *total_lamports.lock().unwrap();
         let mut hash_time = Measure::start("hash");
-        let accumulated_hash =
-            Self::compute_merkle_root_from_slices(hash_total, MERKLE_FANOUT, |start: usize| {
-                cumulative_offsets.get_slice(&hashes, start)
-            });
+        let accumulated_hash = Self::compute_merkle_root_from_slices(
+            hash_total,
+            MERKLE_FANOUT,
+            None,
+            |start: usize| cumulative_offsets.get_slice(&hashes, start),
+        );
         hash_time.stop();
         datapoint_info!(
             "update_accounts_hash",
@@ -4118,7 +4161,8 @@ impl AccountsDB {
         let offsets = CumulativeOffsets::from_raw_2d(&hashes);

         let get_slice = |start: usize| -> &[Hash] { offsets.get_slice_2d(&hashes, start) };
-        let hash = Self::compute_merkle_root_from_slices(offsets.total_count, fanout, get_slice);
+        let hash =
+            Self::compute_merkle_root_from_slices(offsets.total_count, fanout, None, get_slice);
         hash_time.stop();
         stats.hash_time_total_us += hash_time.as_us();
         stats.hash_total = offsets.total_count;
@@ -6459,55 +6503,44 @@ pub mod tests {

     fn test_hashing_larger(hashes: Vec<(Pubkey, Hash)>, fanout: usize) -> Hash {
         let result = AccountsDB::compute_merkle_root(hashes.clone(), fanout);
-        if hashes.len() >= fanout * fanout * fanout {
-            let reduced: Vec<_> = hashes.iter().map(|x| x.1).collect();
-            let result2 =
-                AccountsDB::compute_merkle_root_from_slices(hashes.len(), fanout, |start| {
-                    &reduced[start..]
-                });
-            assert_eq!(result, result2);
-
-            let reduced2: Vec<_> = hashes.iter().map(|x| vec![x.1]).collect();
-            let result2 = AccountsDB::flatten_hashes_and_hash(
-                vec![reduced2],
-                fanout,
-                &mut HashStats::default(),
-            );
-            assert_eq!(result, result2);
-
-            for left in 0..reduced.len() {
-                for right in left + 1..reduced.len() {
-                    let src = vec![
-                        vec![reduced[0..left].to_vec(), reduced[left..right].to_vec()],
-                        vec![reduced[right..].to_vec()],
-                    ];
-                    let result2 =
-                        AccountsDB::flatten_hashes_and_hash(src, fanout, &mut HashStats::default());
-                    assert_eq!(result, result2);
-                }
-            }
-        }
+        let reduced: Vec<_> = hashes.iter().map(|x| x.1).collect();
+        let result2 = test_hashing(reduced, fanout);
+        assert_eq!(result, result2, "len: {}", hashes.len());
         result
     }

     fn test_hashing(hashes: Vec<Hash>, fanout: usize) -> Hash {
         let temp: Vec<_> = hashes.iter().map(|h| (Pubkey::default(), *h)).collect();
         let result = AccountsDB::compute_merkle_root(temp, fanout);
-        if hashes.len() >= fanout * fanout * fanout {
-            let reduced: Vec<_> = hashes.clone();
-            let result2 =
-                AccountsDB::compute_merkle_root_from_slices(hashes.len(), fanout, |start| {
-                    &reduced[start..]
-                });
-            assert_eq!(result, result2, "len: {}", hashes.len());
+        let reduced: Vec<_> = hashes.clone();
+        let result2 =
+            AccountsDB::compute_merkle_root_from_slices(hashes.len(), fanout, None, |start| {
+                &reduced[start..]
+            });
+        assert_eq!(result, result2, "len: {}", hashes.len());

-            let reduced2: Vec<_> = hashes.iter().map(|x| vec![*x]).collect();
-            let result2 = AccountsDB::flatten_hashes_and_hash(
-                vec![reduced2],
-                fanout,
-                &mut HashStats::default(),
-            );
-            assert_eq!(result, result2, "len: {}", hashes.len());
+        let result2 =
+            AccountsDB::compute_merkle_root_from_slices(hashes.len(), fanout, Some(1), |start| {
+                &reduced[start..]
+            });
+        assert_eq!(result, result2, "len: {}", hashes.len());
+
+        let reduced2: Vec<_> = hashes.iter().map(|x| vec![*x]).collect();
+        let result2 =
+            AccountsDB::flatten_hashes_and_hash(vec![reduced2], fanout, &mut HashStats::default());
+        assert_eq!(result, result2, "len: {}", hashes.len());
+
+        let max = std::cmp::min(reduced.len(), fanout * 2);
+        for left in 0..max {
+            for right in left + 1..max {
+                let src = vec![
+                    vec![reduced[0..left].to_vec(), reduced[left..right].to_vec()],
+                    vec![reduced[right..].to_vec()],
+                ];
+                let result2 =
+                    AccountsDB::flatten_hashes_and_hash(src, fanout, &mut HashStats::default());
+                assert_eq!(result, result2);
+            }
         }
         result
     }
@@ -6516,12 +6549,29 @@ pub mod tests {
     fn test_accountsdb_compute_merkle_root_large() {
         solana_logger::setup();

-        let mut num = 100;
-        for _pass in 0..2 {
-            num *= 10;
-            let hashes: Vec<_> = (0..num).into_iter().map(|_| Hash::new_unique()).collect();
+        // handle fanout^x -1, +0, +1 for a few 'x's
+        const FANOUT: usize = 3;
+        let mut hash_counts: Vec<_> = (1..6)
+            .map(|x| {
+                let mark = FANOUT.pow(x);
+                vec![mark - 1, mark, mark + 1]
+            })
+            .flatten()
+            .collect();

-            test_hashing(hashes, MERKLE_FANOUT);
+        // saturate the test space from threshold to threshold + target
+        // this hits right before we use the 3-level optimization and then every possible partial last chunk
+        let target = FANOUT.pow(3);
+        let threshold = target * FANOUT;
+        hash_counts.extend(threshold - 1..=threshold + target);
+
+        for hash_count in hash_counts {
+            let hashes: Vec<_> = (0..hash_count)
+                .into_iter()
+                .map(|_| Hash::new_unique())
+                .collect();
+
+            test_hashing(hashes, FANOUT);
        }
    }
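
Illustrative note (not part of the patch): the unrolled hasher_k/hasher_j/hasher loops are equivalent to running three ordinary single-level reduction passes, because every chunk boundary is aligned to fanout^3, so even a partial last chunk wraps its tail exactly as three global passes would. The standalone sketch below demonstrates that equivalence using the same solana_sdk Hash/Hasher types the patch uses; the function names, the main driver, and the solana-sdk bin-crate setup are assumptions invented for illustration, not code from the patch.

// Minimal sketch, assuming a bin crate with solana-sdk as a dependency.
// All function names here are hypothetical.
use solana_sdk::hash::{Hash, Hasher};

// Hash one group of up to `fanout` child hashes into a single parent hash.
fn hash_group(group: &[Hash]) -> Hash {
    let mut hasher = Hasher::default();
    for h in group {
        hasher.hash(h.as_ref());
    }
    hasher.result()
}

// Reference reduction: collapse one level per pass until one root remains.
fn merkle_root_by_levels(mut hashes: Vec<Hash>, fanout: usize) -> Hash {
    while hashes.len() > 1 {
        hashes = hashes.chunks(fanout).map(hash_group).collect();
    }
    hashes[0] // caller guarantees a non-empty input
}

// First pass collapses three levels at once, as the patch does, but only when
// at least fanout^4 leaves exist (the patch's `threshold`); below that, a
// chunk shallower than three levels would be wrapped too many times.
fn merkle_root_three_level_first_pass(hashes: Vec<Hash>, fanout: usize) -> Hash {
    let target = fanout.pow(3); // leaves consumed per chunk
    if hashes.len() < target * fanout {
        return merkle_root_by_levels(hashes, fanout);
    }
    // Chunk boundaries are aligned to fanout^3, so reducing each chunk
    // independently yields the same subtree roots as three global passes,
    // including for a partial last chunk.
    let roots: Vec<Hash> = hashes
        .chunks(target)
        .map(|chunk| {
            let level1: Vec<Hash> = chunk.chunks(fanout).map(hash_group).collect();
            let level2: Vec<Hash> = level1.chunks(fanout).map(hash_group).collect();
            hash_group(&level2)
        })
        .collect();
    merkle_root_by_levels(roots, fanout)
}

fn main() {
    let fanout = 3usize;
    // Same saturation idea as the new test: threshold - 1 through
    // threshold + target covers every possible partial last chunk.
    let (target, threshold) = (fanout.pow(3), fanout.pow(4));
    for len in threshold - 1..=threshold + target {
        let hashes: Vec<Hash> = (0..len).map(|_| Hash::new_unique()).collect();
        assert_eq!(
            merkle_root_by_levels(hashes.clone(), fanout),
            merkle_root_three_level_first_pass(hashes, fanout)
        );
    }
    println!("three-level first pass matches the level-by-level reduction");
}

The trade-off the threshold encodes: fusing three levels lets one parallel map do the work of three passes, with fewer synchronization points and intermediate allocations, while requiring at least fanout^4 leaves guarantees the map still has at least fanout chunks, so the pass is not serialized onto a single worker.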