compute merkle root on chunks of fanout^3 (#15344)
* compute merkle root on chunks of fanout^3
* improve test_accountsdb_compute_merkle_root_large
This commit is contained in:
		
				
					committed by
					
						
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							ba02452d75
						
					
				
				
					commit
					8367740ff9
				
			@@ -3624,6 +3624,7 @@ impl AccountsDB {
 | 
			
		||||
    fn compute_merkle_root_from_slices<'a, F>(
 | 
			
		||||
        total_hashes: usize,
 | 
			
		||||
        fanout: usize,
 | 
			
		||||
        max_levels_per_pass: Option<usize>,
 | 
			
		||||
        get_hashes: F,
 | 
			
		||||
    ) -> Hash
 | 
			
		||||
    where
 | 
			
		||||
@@ -3635,7 +3636,17 @@ impl AccountsDB {
 | 
			
		||||
 | 
			
		||||
        let mut time = Measure::start("time");
 | 
			
		||||
 | 
			
		||||
        let chunks = Self::div_ceil(total_hashes, fanout);
 | 
			
		||||
        const THREE_LEVEL_OPTIMIZATION: usize = 3; // this '3' is dependent on the code structure below where we manually unroll
 | 
			
		||||
        let target = fanout.pow(THREE_LEVEL_OPTIMIZATION as u32);
 | 
			
		||||
 | 
			
		||||
        // Only use the 3 level optimization if we have at least 4 levels of data.
 | 
			
		||||
        // Otherwise, we'll be serializing a parallel operation.
 | 
			
		||||
        let threshold = target * fanout;
 | 
			
		||||
        let three_level = max_levels_per_pass.unwrap_or(usize::MAX) >= THREE_LEVEL_OPTIMIZATION
 | 
			
		||||
            && total_hashes >= threshold;
 | 
			
		||||
        let num_hashes_per_chunk = if three_level { target } else { fanout };
 | 
			
		||||
 | 
			
		||||
        let chunks = Self::div_ceil(total_hashes, num_hashes_per_chunk);
 | 
			
		||||
 | 
			
		||||
        // initial fetch - could return entire slice
 | 
			
		||||
        let data: &[Hash] = get_hashes(0);
 | 
			
		||||
@@ -3644,14 +3655,17 @@ impl AccountsDB {
 | 
			
		||||
        let result: Vec<_> = (0..chunks)
 | 
			
		||||
            .into_par_iter()
 | 
			
		||||
            .map(|i| {
 | 
			
		||||
                let start_index = i * fanout;
 | 
			
		||||
                let end_index = std::cmp::min(start_index + fanout, total_hashes);
 | 
			
		||||
                let start_index = i * num_hashes_per_chunk;
 | 
			
		||||
                let end_index = std::cmp::min(start_index + num_hashes_per_chunk, total_hashes);
 | 
			
		||||
 | 
			
		||||
                let mut hasher = Hasher::default();
 | 
			
		||||
                let mut data_index = start_index;
 | 
			
		||||
                let mut data = data;
 | 
			
		||||
                let mut data_len = data_len;
 | 
			
		||||
 | 
			
		||||
                if !three_level {
 | 
			
		||||
                    // 1 group of fanout
 | 
			
		||||
                    // The result of this loop is a single hash value from fanout input hashes.
 | 
			
		||||
                    for i in start_index..end_index {
 | 
			
		||||
                        if data_index >= data_len {
 | 
			
		||||
                            // fetch next slice
 | 
			
		||||
@@ -3659,10 +3673,37 @@ impl AccountsDB {
 | 
			
		||||
                            data_len = data.len();
 | 
			
		||||
                            data_index = 0;
 | 
			
		||||
                        }
 | 
			
		||||
 | 
			
		||||
                        hasher.hash(data[data_index].as_ref());
 | 
			
		||||
                        data_index += 1;
 | 
			
		||||
                    }
 | 
			
		||||
                } else {
 | 
			
		||||
                    // hash 3 levels of fanout simultaneously.
 | 
			
		||||
                    // The result of this loop is a single hash value from fanout^3 input hashes.
 | 
			
		||||
                    let mut i = start_index;
 | 
			
		||||
                    while i < end_index {
 | 
			
		||||
                        let mut hasher_j = Hasher::default();
 | 
			
		||||
                        for _j in 0..fanout {
 | 
			
		||||
                            let mut hasher_k = Hasher::default();
 | 
			
		||||
                            let end = std::cmp::min(end_index - i, fanout);
 | 
			
		||||
                            for _k in 0..end {
 | 
			
		||||
                                if data_index >= data_len {
 | 
			
		||||
                                    // fetch next slice
 | 
			
		||||
                                    data = get_hashes(i);
 | 
			
		||||
                                    data_len = data.len();
 | 
			
		||||
                                    data_index = 0;
 | 
			
		||||
                                }
 | 
			
		||||
                                hasher_k.hash(data[data_index].as_ref());
 | 
			
		||||
                                data_index += 1;
 | 
			
		||||
                                i += 1;
 | 
			
		||||
                            }
 | 
			
		||||
                            hasher_j.hash(hasher_k.result().as_ref());
 | 
			
		||||
                            if i >= end_index {
 | 
			
		||||
                                break;
 | 
			
		||||
                            }
 | 
			
		||||
                        }
 | 
			
		||||
                        hasher.hash(hasher_j.result().as_ref());
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                hasher.result()
 | 
			
		||||
            })
 | 
			
		||||
@@ -3823,10 +3864,12 @@ impl AccountsDB {
 | 
			
		||||
        let hash_total = cumulative_offsets.total_count;
 | 
			
		||||
        let total_lamports = *total_lamports.lock().unwrap();
 | 
			
		||||
        let mut hash_time = Measure::start("hash");
 | 
			
		||||
        let accumulated_hash =
 | 
			
		||||
            Self::compute_merkle_root_from_slices(hash_total, MERKLE_FANOUT, |start: usize| {
 | 
			
		||||
                cumulative_offsets.get_slice(&hashes, start)
 | 
			
		||||
            });
 | 
			
		||||
        let accumulated_hash = Self::compute_merkle_root_from_slices(
 | 
			
		||||
            hash_total,
 | 
			
		||||
            MERKLE_FANOUT,
 | 
			
		||||
            None,
 | 
			
		||||
            |start: usize| cumulative_offsets.get_slice(&hashes, start),
 | 
			
		||||
        );
 | 
			
		||||
        hash_time.stop();
 | 
			
		||||
        datapoint_info!(
 | 
			
		||||
            "update_accounts_hash",
 | 
			
		||||
@@ -4118,7 +4161,8 @@ impl AccountsDB {
 | 
			
		||||
        let offsets = CumulativeOffsets::from_raw_2d(&hashes);
 | 
			
		||||
 | 
			
		||||
        let get_slice = |start: usize| -> &[Hash] { offsets.get_slice_2d(&hashes, start) };
 | 
			
		||||
        let hash = Self::compute_merkle_root_from_slices(offsets.total_count, fanout, get_slice);
 | 
			
		||||
        let hash =
 | 
			
		||||
            Self::compute_merkle_root_from_slices(offsets.total_count, fanout, None, get_slice);
 | 
			
		||||
        hash_time.stop();
 | 
			
		||||
        stats.hash_time_total_us += hash_time.as_us();
 | 
			
		||||
        stats.hash_total = offsets.total_count;
 | 
			
		||||
@@ -6459,24 +6503,36 @@ pub mod tests {
 | 
			
		||||
 | 
			
		||||
    fn test_hashing_larger(hashes: Vec<(Pubkey, Hash)>, fanout: usize) -> Hash {
 | 
			
		||||
        let result = AccountsDB::compute_merkle_root(hashes.clone(), fanout);
 | 
			
		||||
        if hashes.len() >= fanout * fanout * fanout {
 | 
			
		||||
        let reduced: Vec<_> = hashes.iter().map(|x| x.1).collect();
 | 
			
		||||
        let result2 = test_hashing(reduced, fanout);
 | 
			
		||||
        assert_eq!(result, result2, "len: {}", hashes.len());
 | 
			
		||||
        result
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn test_hashing(hashes: Vec<Hash>, fanout: usize) -> Hash {
 | 
			
		||||
        let temp: Vec<_> = hashes.iter().map(|h| (Pubkey::default(), *h)).collect();
 | 
			
		||||
        let result = AccountsDB::compute_merkle_root(temp, fanout);
 | 
			
		||||
        let reduced: Vec<_> = hashes.clone();
 | 
			
		||||
        let result2 =
 | 
			
		||||
                AccountsDB::compute_merkle_root_from_slices(hashes.len(), fanout, |start| {
 | 
			
		||||
            AccountsDB::compute_merkle_root_from_slices(hashes.len(), fanout, None, |start| {
 | 
			
		||||
                &reduced[start..]
 | 
			
		||||
            });
 | 
			
		||||
            assert_eq!(result, result2);
 | 
			
		||||
        assert_eq!(result, result2, "len: {}", hashes.len());
 | 
			
		||||
 | 
			
		||||
            let reduced2: Vec<_> = hashes.iter().map(|x| vec![x.1]).collect();
 | 
			
		||||
            let result2 = AccountsDB::flatten_hashes_and_hash(
 | 
			
		||||
                vec![reduced2],
 | 
			
		||||
                fanout,
 | 
			
		||||
                &mut HashStats::default(),
 | 
			
		||||
            );
 | 
			
		||||
            assert_eq!(result, result2);
 | 
			
		||||
        let result2 =
 | 
			
		||||
            AccountsDB::compute_merkle_root_from_slices(hashes.len(), fanout, Some(1), |start| {
 | 
			
		||||
                &reduced[start..]
 | 
			
		||||
            });
 | 
			
		||||
        assert_eq!(result, result2, "len: {}", hashes.len());
 | 
			
		||||
 | 
			
		||||
            for left in 0..reduced.len() {
 | 
			
		||||
                for right in left + 1..reduced.len() {
 | 
			
		||||
        let reduced2: Vec<_> = hashes.iter().map(|x| vec![*x]).collect();
 | 
			
		||||
        let result2 =
 | 
			
		||||
            AccountsDB::flatten_hashes_and_hash(vec![reduced2], fanout, &mut HashStats::default());
 | 
			
		||||
        assert_eq!(result, result2, "len: {}", hashes.len());
 | 
			
		||||
 | 
			
		||||
        let max = std::cmp::min(reduced.len(), fanout * 2);
 | 
			
		||||
        for left in 0..max {
 | 
			
		||||
            for right in left + 1..max {
 | 
			
		||||
                let src = vec![
 | 
			
		||||
                    vec![reduced[0..left].to_vec(), reduced[left..right].to_vec()],
 | 
			
		||||
                    vec![reduced[right..].to_vec()],
 | 
			
		||||
@@ -6486,29 +6542,6 @@ pub mod tests {
 | 
			
		||||
                assert_eq!(result, result2);
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        }
 | 
			
		||||
        result
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn test_hashing(hashes: Vec<Hash>, fanout: usize) -> Hash {
 | 
			
		||||
        let temp: Vec<_> = hashes.iter().map(|h| (Pubkey::default(), *h)).collect();
 | 
			
		||||
        let result = AccountsDB::compute_merkle_root(temp, fanout);
 | 
			
		||||
        if hashes.len() >= fanout * fanout * fanout {
 | 
			
		||||
            let reduced: Vec<_> = hashes.clone();
 | 
			
		||||
            let result2 =
 | 
			
		||||
                AccountsDB::compute_merkle_root_from_slices(hashes.len(), fanout, |start| {
 | 
			
		||||
                    &reduced[start..]
 | 
			
		||||
                });
 | 
			
		||||
            assert_eq!(result, result2, "len: {}", hashes.len());
 | 
			
		||||
 | 
			
		||||
            let reduced2: Vec<_> = hashes.iter().map(|x| vec![*x]).collect();
 | 
			
		||||
            let result2 = AccountsDB::flatten_hashes_and_hash(
 | 
			
		||||
                vec![reduced2],
 | 
			
		||||
                fanout,
 | 
			
		||||
                &mut HashStats::default(),
 | 
			
		||||
            );
 | 
			
		||||
            assert_eq!(result, result2, "len: {}", hashes.len());
 | 
			
		||||
        }
 | 
			
		||||
        result
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@@ -6516,12 +6549,29 @@ pub mod tests {
 | 
			
		||||
    fn test_accountsdb_compute_merkle_root_large() {
 | 
			
		||||
        solana_logger::setup();
 | 
			
		||||
 | 
			
		||||
        let mut num = 100;
 | 
			
		||||
        for _pass in 0..2 {
 | 
			
		||||
            num *= 10;
 | 
			
		||||
            let hashes: Vec<_> = (0..num).into_iter().map(|_| Hash::new_unique()).collect();
 | 
			
		||||
        // handle fanout^x -1, +0, +1 for a few 'x's
 | 
			
		||||
        const FANOUT: usize = 3;
 | 
			
		||||
        let mut hash_counts: Vec<_> = (1..6)
 | 
			
		||||
            .map(|x| {
 | 
			
		||||
                let mark = FANOUT.pow(x);
 | 
			
		||||
                vec![mark - 1, mark, mark + 1]
 | 
			
		||||
            })
 | 
			
		||||
            .flatten()
 | 
			
		||||
            .collect();
 | 
			
		||||
 | 
			
		||||
            test_hashing(hashes, MERKLE_FANOUT);
 | 
			
		||||
        // saturate the test space for threshold to threshold + target
 | 
			
		||||
        // this hits right before we use the 3 deep optimization and all the way through all possible partial last chunks
 | 
			
		||||
        let target = FANOUT.pow(3);
 | 
			
		||||
        let threshold = target * FANOUT;
 | 
			
		||||
        hash_counts.extend(threshold - 1..=threshold + target);
 | 
			
		||||
 | 
			
		||||
        for hash_count in hash_counts {
 | 
			
		||||
            let hashes: Vec<_> = (0..hash_count)
 | 
			
		||||
                .into_iter()
 | 
			
		||||
                .map(|_| Hash::new_unique())
 | 
			
		||||
                .collect();
 | 
			
		||||
 | 
			
		||||
            test_hashing(hashes, FANOUT);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user