From a53dd611c0dcf4ade61ee443f1f9cba337c4cdec Mon Sep 17 00:00:00 2001 From: "mergify[bot]" <37929162+mergify[bot]@users.noreply.github.com> Date: Fri, 31 Dec 2021 16:47:45 +0000 Subject: [PATCH] uses enum for shred type (backport #21333) (#22147) * uses enum for shred type Current code is using u8 which does not have any type-safety and can contain invalid values: https://github.com/solana-labs/solana/blob/66fa062f1/ledger/src/shred.rs#L167 Checks for invalid shred-types are scattered through the code: https://github.com/solana-labs/solana/blob/66fa062f1/ledger/src/blockstore.rs#L849-L851 https://github.com/solana-labs/solana/blob/66fa062f1/ledger/src/shred.rs#L346-L348 The commit uses enum for shred type with #[repr(u8)]. Backward compatibility is maintained by implementing Serialize and Deserialize compatible with u8, and adding a test to assert that. (cherry picked from commit 57057f8d39b119e216648cd1c52fceb397f1903f) # Conflicts: # core/src/retransmit_stage.rs # gossip/src/cluster_info.rs # ledger/Cargo.toml # ledger/src/blockstore.rs # ledger/src/shred.rs * changes Blockstore::is_shred_duplicate arg type to ShredType (cherry picked from commit 48dfdfb4d55d8cc36aa38e5cdff2de3d18ac9385) # Conflicts: # ledger/src/blockstore.rs * removes backport merge conflicts Co-authored-by: behzad nouri --- Cargo.lock | 2 + core/src/retransmit_stage.rs | 10 +- core/src/window_service.rs | 13 ++- gossip/src/cluster_info.rs | 2 +- gossip/src/duplicate_shred.rs | 17 +--- ledger/Cargo.toml | 2 + ledger/src/blockstore.rs | 145 ++++++++++++++-------------- ledger/src/shred.rs | 172 ++++++++++++++++++++++------------ 8 files changed, 204 insertions(+), 159 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c247d6db9a..8bea4dd003 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4996,6 +4996,8 @@ dependencies = [ "libc", "log 0.4.14", "matches", + "num-derive", + "num-traits", "num_cpus", "prost", "rand 0.7.3", diff --git a/core/src/retransmit_stage.rs b/core/src/retransmit_stage.rs index 1ba0033658..5d91dc051b 100644 --- a/core/src/retransmit_stage.rs +++ b/core/src/retransmit_stage.rs @@ -18,7 +18,9 @@ use { solana_client::rpc_response::SlotUpdate, solana_gossip::cluster_info::{ClusterInfo, DATA_PLANE_FANOUT}, solana_ledger::{ - blockstore::Blockstore, leader_schedule_cache::LeaderScheduleCache, shred::Shred, + blockstore::Blockstore, + leader_schedule_cache::LeaderScheduleCache, + shred::{Shred, ShredType}, }, solana_measure::measure::Measure, solana_perf::packet::PacketBatch, @@ -137,14 +139,14 @@ impl RetransmitStats { } } -// Map of shred (slot, index, is_data) => list of hash values seen for that key. -type ShredFilter = LruCache<(Slot, u32, bool), Vec>; +// Map of shred (slot, index, type) => list of hash values seen for that key. +type ShredFilter = LruCache<(Slot, u32, ShredType), Vec>; type ShredFilterAndHasher = (ShredFilter, PacketHasher); // Returns true if shred is already received and should skip retransmit. fn should_skip_retransmit(shred: &Shred, shreds_received: &Mutex) -> bool { - let key = (shred.slot(), shred.index(), shred.is_data()); + let key = (shred.slot(), shred.index(), shred.shred_type()); let mut shreds_received = shreds_received.lock().unwrap(); let (cache, hasher) = shreds_received.deref_mut(); match cache.get_mut(&key) { diff --git a/core/src/window_service.rs b/core/src/window_service.rs index cb7f35dd79..5308055646 100644 --- a/core/src/window_service.rs +++ b/core/src/window_service.rs @@ -18,7 +18,7 @@ use { solana_ledger::{ blockstore::{self, Blockstore, BlockstoreInsertionMetrics, MAX_DATA_SHREDS_PER_SLOT}, leader_schedule_cache::LeaderScheduleCache, - shred::{Nonce, Shred}, + shred::{Nonce, Shred, ShredType}, }, solana_measure::measure::Measure, solana_metrics::{inc_new_counter_debug, inc_new_counter_error}, @@ -161,12 +161,11 @@ impl ReceiveWindowStats { } fn verify_shred_slot(shred: &Shred, root: u64) -> bool { - if shred.is_data() { + match shred.shred_type() { // Only data shreds have parent information - blockstore::verify_shred_slots(shred.slot(), shred.parent(), root) - } else { + ShredType::Data => blockstore::verify_shred_slots(shred.slot(), shred.parent(), root), // Filter out outdated coding shreds - shred.slot() >= root + ShredType::Code => shred.slot() >= root, } } @@ -218,8 +217,8 @@ fn run_check_duplicate( if let Some(existing_shred_payload) = blockstore.is_shred_duplicate( shred_slot, shred.index(), - &shred.payload, - shred.is_data(), + shred.payload.clone(), + shred.shred_type(), ) { cluster_info.push_duplicate_shred(&shred, &existing_shred_payload)?; blockstore.store_duplicate_slot( diff --git a/gossip/src/cluster_info.rs b/gossip/src/cluster_info.rs index b191d6fe15..2a92aa5587 100644 --- a/gossip/src/cluster_info.rs +++ b/gossip/src/cluster_info.rs @@ -257,7 +257,7 @@ pub fn make_accounts_hashes_message( pub(crate) type Ping = ping_pong::Ping<[u8; GOSSIP_PING_TOKEN_SIZE]>; // TODO These messages should go through the gpu pipeline for spam filtering -#[frozen_abi(digest = "D2ebKKmm6EQ8JJjYc3xUpzpBTJguqgEzShhj9fiUcP6F")] +#[frozen_abi(digest = "7cgH6JHdpxMSuPs6LEZzV5ShLXQMcZftb95s5PZKR5qB")] #[derive(Serialize, Deserialize, Debug, AbiEnumVisitor, AbiExample)] #[allow(clippy::large_enum_variant)] pub(crate) enum Protocol { diff --git a/gossip/src/duplicate_shred.rs b/gossip/src/duplicate_shred.rs index d430f45ecf..15d51c116d 100644 --- a/gossip/src/duplicate_shred.rs +++ b/gossip/src/duplicate_shred.rs @@ -89,7 +89,7 @@ fn check_shreds( // TODO: Should also allow two coding shreds with different indices but // same fec-set-index and mismatching erasure-config. Err(Error::ShredIndexMismatch) - } else if shred1.common_header.shred_type != shred2.common_header.shred_type { + } else if shred1.shred_type() != shred2.shred_type() { Err(Error::ShredTypeMismatch) } else if shred1.payload == shred2.payload { Err(Error::InvalidDuplicateShreds) @@ -119,11 +119,7 @@ pub fn from_duplicate_slot_proof( let shred1 = Shred::new_from_serialized_shred(proof.shred1.clone())?; let shred2 = Shred::new_from_serialized_shred(proof.shred2.clone())?; check_shreds(leader_schedule, &shred1, &shred2)?; - let (slot, shred_index, shred_type) = ( - shred1.slot(), - shred1.index(), - shred1.common_header.shred_type, - ); + let (slot, shred_index, shred_type) = (shred1.slot(), shred1.index(), shred1.shred_type()); let data = bincode::serialize(proof)?; let chunk_size = if DUPLICATE_SHRED_HEADER_SIZE < max_size { max_size - DUPLICATE_SHRED_HEADER_SIZE @@ -161,8 +157,7 @@ pub(crate) fn from_shred( } let other_shred = Shred::new_from_serialized_shred(other_payload.clone())?; check_shreds(leader_schedule, &shred, &other_shred)?; - let (slot, shred_index, shred_type) = - (shred.slot(), shred.index(), shred.common_header.shred_type); + let (slot, shred_index, shred_type) = (shred.slot(), shred.index(), shred.shred_type()); let proof = DuplicateSlotProof { shred1: shred.payload, shred2: other_payload, @@ -262,9 +257,7 @@ pub fn into_shreds( Err(Error::SlotMismatch) } else if shred1.index() != shred_index || shred2.index() != shred_index { Err(Error::ShredIndexMismatch) - } else if shred1.common_header.shred_type != shred_type - || shred2.common_header.shred_type != shred_type - { + } else if shred1.shred_type() != shred_type || shred2.shred_type() != shred_type { Err(Error::ShredTypeMismatch) } else if shred1.payload == shred2.payload { Err(Error::InvalidDuplicateShreds) @@ -306,7 +299,7 @@ pub(crate) mod tests { wallclock: u64::MAX, slot: Slot::MAX, shred_index: u32::MAX, - shred_type: ShredType(u8::MAX), + shred_type: ShredType::Data, num_chunks: u8::MAX, chunk_index: u8::MAX, chunk: Vec::default(), diff --git a/ledger/Cargo.toml b/ledger/Cargo.toml index 68fe8ab67c..2d3610c8ad 100644 --- a/ledger/Cargo.toml +++ b/ledger/Cargo.toml @@ -25,6 +25,8 @@ itertools = "0.9.0" lazy_static = "1.4.0" libc = "0.2.81" log = { version = "0.4.11" } +num-derive = "0.3" +num-traits = "0.2" num_cpus = "1.13.0" prost = "0.8.0" rand = "0.7.0" diff --git a/ledger/src/blockstore.rs b/ledger/src/blockstore.rs index 038f91966f..fffbab9330 100644 --- a/ledger/src/blockstore.rs +++ b/ledger/src/blockstore.rs @@ -14,7 +14,10 @@ use { erasure::ErasureConfig, leader_schedule_cache::LeaderScheduleCache, next_slots_iterator::NextSlotsIterator, - shred::{Result as ShredResult, Shred, Shredder, MAX_DATA_SHREDS_PER_FEC_BLOCK}, + shred::{ + Result as ShredResult, Shred, ShredType, Shredder, MAX_DATA_SHREDS_PER_FEC_BLOCK, + SHRED_PAYLOAD_SIZE, + }, }, bincode::deserialize, log::*, @@ -827,51 +830,54 @@ impl Blockstore { let mut newly_completed_data_sets: Vec = vec![]; let mut inserted_indices = Vec::new(); for (i, (shred, is_repaired)) in shreds.into_iter().zip(is_repaired).enumerate() { - if shred.is_data() { - let shred_source = if is_repaired { - ShredSource::Repaired - } else { - ShredSource::Turbine - }; - match self.check_insert_data_shred( - shred, - &mut erasure_metas, - &mut index_working_set, - &mut slot_meta_working_set, - &mut write_batch, - &mut just_inserted_data_shreds, - &mut index_meta_time, - is_trusted, - handle_duplicate, - leader_schedule, - shred_source, - ) { - Err(InsertDataShredError::Exists) => metrics.num_data_shreds_exists += 1, - Err(InsertDataShredError::InvalidShred) => metrics.num_data_shreds_invalid += 1, - Err(InsertDataShredError::BlockstoreError(_)) => { - metrics.num_data_shreds_blockstore_error += 1; - } - Ok(completed_data_sets) => { - newly_completed_data_sets.extend(completed_data_sets); - inserted_indices.push(i); - metrics.num_inserted += 1; - } - }; - } else if shred.is_code() { - self.check_cache_coding_shred( - shred, - &mut erasure_metas, - &mut index_working_set, - &mut just_inserted_coding_shreds, - &mut index_meta_time, - handle_duplicate, - is_trusted, - is_repaired, - metrics, - ); - } else { - panic!("There should be no other case"); - } + match shred.shred_type() { + ShredType::Data => { + let shred_source = if is_repaired { + ShredSource::Repaired + } else { + ShredSource::Turbine + }; + match self.check_insert_data_shred( + shred, + &mut erasure_metas, + &mut index_working_set, + &mut slot_meta_working_set, + &mut write_batch, + &mut just_inserted_data_shreds, + &mut index_meta_time, + is_trusted, + handle_duplicate, + leader_schedule, + shred_source, + ) { + Err(InsertDataShredError::Exists) => metrics.num_data_shreds_exists += 1, + Err(InsertDataShredError::InvalidShred) => { + metrics.num_data_shreds_invalid += 1 + } + Err(InsertDataShredError::BlockstoreError(_)) => { + metrics.num_data_shreds_blockstore_error += 1; + } + Ok(completed_data_sets) => { + newly_completed_data_sets.extend(completed_data_sets); + inserted_indices.push(i); + metrics.num_inserted += 1; + } + }; + } + ShredType::Code => { + self.check_cache_coding_shred( + shred, + &mut erasure_metas, + &mut index_working_set, + &mut just_inserted_coding_shreds, + &mut index_meta_time, + handle_duplicate, + is_trusted, + is_repaired, + metrics, + ); + } + }; } start.stop(); @@ -1345,7 +1351,6 @@ impl Blockstore { leader_schedule: Option<&LeaderScheduleCache>, shred_source: ShredSource, ) -> bool { - use crate::shred::SHRED_PAYLOAD_SIZE; let shred_index = u64::from(shred.index()); let slot = shred.slot(); let last_in_slot = if shred.last_in_slot() { @@ -1574,7 +1579,6 @@ impl Blockstore { } pub fn get_data_shred(&self, slot: Slot, index: u64) -> Result>> { - use crate::shred::SHRED_PAYLOAD_SIZE; self.data_shred_cf.get_bytes((slot, index)).map(|data| { data.map(|mut d| { // Only data_header.size bytes stored in the blockstore so @@ -3032,31 +3036,18 @@ impl Blockstore { &self, slot: u64, index: u32, - new_shred_raw: &[u8], - is_data: bool, + mut payload: Vec, + shred_type: ShredType, ) -> Option> { - let res = if is_data { - self.get_data_shred(slot, index as u64) - .expect("fetch from DuplicateSlots column family failed") - } else { - self.get_coding_shred(slot, index as u64) - .expect("fetch from DuplicateSlots column family failed") - }; - - let mut payload = new_shred_raw.to_vec(); - payload.resize( - std::cmp::max(new_shred_raw.len(), crate::shred::SHRED_PAYLOAD_SIZE), - 0, - ); + let existing_shred = match shred_type { + ShredType::Data => self.get_data_shred(slot, index as u64), + ShredType::Code => self.get_coding_shred(slot, index as u64), + } + .expect("fetch from DuplicateSlots column family failed")?; + let size = payload.len().max(SHRED_PAYLOAD_SIZE); + payload.resize(size, 0u8); let new_shred = Shred::new_from_serialized_shred(payload).unwrap(); - res.map(|existing_shred| { - if existing_shred != new_shred.payload { - Some(existing_shred) - } else { - None - } - }) - .unwrap_or(None) + (existing_shred != new_shred.payload).then(|| existing_shred) } pub fn has_duplicate_shreds_in_slot(&self, slot: Slot) -> bool { @@ -8228,8 +8219,8 @@ pub mod tests { blockstore.is_shred_duplicate( slot, 0, - &duplicate_shred.payload, - duplicate_shred.is_data() + duplicate_shred.payload.clone(), + duplicate_shred.shred_type() ), Some(shred.payload.to_vec()) ); @@ -8237,8 +8228,8 @@ pub mod tests { .is_shred_duplicate( slot, 0, - &non_duplicate_shred.payload, - duplicate_shred.is_data() + non_duplicate_shred.payload, + duplicate_shred.shred_type() ) .is_none()); @@ -8726,8 +8717,8 @@ pub mod tests { .is_shred_duplicate( slot, even_smaller_last_shred_duplicate.index(), - &even_smaller_last_shred_duplicate.payload, - true + even_smaller_last_shred_duplicate.payload.clone(), + ShredType::Data, ) .is_some()); blockstore diff --git a/ledger/src/shred.rs b/ledger/src/shred.rs index 39419bf569..f758a79907 100644 --- a/ledger/src/shred.rs +++ b/ledger/src/shred.rs @@ -57,12 +57,14 @@ use { }, bincode::config::Options, core::cell::RefCell, + num_derive::FromPrimitive, + num_traits::FromPrimitive, rayon::{ iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}, slice::ParallelSlice, ThreadPool, }, - serde::{Deserialize, Serialize}, + serde::{Deserialize, Deserializer, Serialize, Serializer}, solana_measure::measure::Measure, solana_perf::packet::{limited_deserialize, Packet}, solana_rayon_threadlimit::get_thread_count, @@ -144,10 +146,6 @@ thread_local!(static PAR_THREAD_POOL: RefCell = RefCell::new(rayon:: .build() .unwrap())); -/// The constants that define if a shred is data or coding -pub const DATA_SHRED: u8 = 0b1010_0101; -pub const CODING_SHRED: u8 = 0b0101_1010; - pub const MAX_DATA_SHREDS_PER_FEC_BLOCK: u32 = 32; pub const SHRED_TICK_REFERENCE_MASK: u8 = 0b0011_1111; @@ -176,11 +174,36 @@ pub enum ShredError { pub type Result = std::result::Result; -#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, AbiExample, Deserialize, Serialize)] -pub struct ShredType(pub u8); +#[repr(u8)] +#[derive(Copy, Clone, Debug, Eq, FromPrimitive, Hash, PartialEq, AbiEnumVisitor, AbiExample)] +pub enum ShredType { + Data = 0b1010_0101, + Code = 0b0101_1010, +} + impl Default for ShredType { fn default() -> Self { - ShredType(DATA_SHRED) + ShredType::Data + } +} + +impl Serialize for ShredType { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: Serializer, + { + (*self as u8).serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for ShredType { + fn deserialize(deserializer: D) -> std::result::Result + where + D: Deserializer<'de>, + { + let shred_type = u8::deserialize(deserializer)?; + Self::from_u8(shred_type) + .ok_or_else(|| serde::de::Error::custom(ShredError::InvalidShredType)) } } @@ -331,32 +354,33 @@ impl Shred { // so that erasure generation/recovery works correctly // But only the data_header.size is stored in blockstore. payload.resize(SHRED_PAYLOAD_SIZE, 0); - let shred = if common_header.shred_type == ShredType(CODING_SHRED) { - let coding_header: CodingShredHeader = - Self::deserialize_obj(&mut start, SIZE_OF_CODING_SHRED_HEADER, &payload)?; - Self { - common_header, - data_header: DataShredHeader::default(), - coding_header, - payload, + let shred = match common_header.shred_type { + ShredType::Code => { + let coding_header: CodingShredHeader = + Self::deserialize_obj(&mut start, SIZE_OF_CODING_SHRED_HEADER, &payload)?; + Self { + common_header, + data_header: DataShredHeader::default(), + coding_header, + payload, + } } - } else if common_header.shred_type == ShredType(DATA_SHRED) { - let data_header: DataShredHeader = - Self::deserialize_obj(&mut start, SIZE_OF_DATA_SHRED_HEADER, &payload)?; - if u64::from(data_header.parent_offset) > common_header.slot { - return Err(ShredError::InvalidParentOffset { - slot, - parent_offset: data_header.parent_offset, - }); + ShredType::Data => { + let data_header: DataShredHeader = + Self::deserialize_obj(&mut start, SIZE_OF_DATA_SHRED_HEADER, &payload)?; + if u64::from(data_header.parent_offset) > common_header.slot { + return Err(ShredError::InvalidParentOffset { + slot, + parent_offset: data_header.parent_offset, + }); + } + Self { + common_header, + data_header, + coding_header: CodingShredHeader::default(), + payload, + } } - Self { - common_header, - data_header, - coding_header: CodingShredHeader::default(), - payload, - } - } else { - return Err(ShredError::InvalidShredType); }; Ok(shred) @@ -397,23 +421,22 @@ impl Shred { &common_header, ) .expect("Failed to write header into shred buffer"); - if common_header.shred_type == ShredType(DATA_SHRED) { - Self::serialize_obj_into( + match common_header.shred_type { + ShredType::Data => Self::serialize_obj_into( &mut start, SIZE_OF_DATA_SHRED_HEADER, &mut payload, &data_header, ) - .expect("Failed to write data header into shred buffer"); - } else if common_header.shred_type == ShredType(CODING_SHRED) { - Self::serialize_obj_into( + .expect("Failed to write data header into shred buffer"), + ShredType::Code => Self::serialize_obj_into( &mut start, SIZE_OF_CODING_SHRED_HEADER, &mut payload, &coding_header, ) - .expect("Failed to write data header into shred buffer"); - } + .expect("Failed to write coding header into shred buffer"), + }; Shred { common_header, data_header, @@ -434,6 +457,7 @@ impl Shred { self.common_header.slot } + // TODO: This should return Option pub fn parent(&self) -> Slot { if self.is_data() { self.common_header.slot - u64::from(self.data_header.parent_offset) @@ -491,11 +515,16 @@ impl Shred { } } + #[inline] + pub fn shred_type(&self) -> ShredType { + self.common_header.shred_type + } + pub fn is_data(&self) -> bool { - self.common_header.shred_type == ShredType(DATA_SHRED) + self.shred_type() == ShredType::Data } pub fn is_code(&self) -> bool { - self.common_header.shred_type == ShredType(CODING_SHRED) + self.shred_type() == ShredType::Code } pub fn last_in_slot(&self) -> bool { @@ -769,7 +798,7 @@ impl Shredder { version: u16, ) -> (ShredCommonHeader, CodingShredHeader) { let header = ShredCommonHeader { - shred_type: ShredType(CODING_SHRED), + shred_type: ShredType::Code, index, slot, version, @@ -992,7 +1021,7 @@ pub struct ShredFetchStats { pub fn get_shred_slot_index_type( p: &Packet, stats: &mut ShredFetchStats, -) -> Option<(Slot, u32, bool)> { +) -> Option<(Slot, u32, ShredType)> { let index_start = OFFSET_OF_SHRED_INDEX; let index_end = index_start + SIZE_OF_SHRED_INDEX; let slot_start = OFFSET_OF_SHRED_SLOT; @@ -1031,14 +1060,14 @@ pub fn get_shred_slot_index_type( } } - let shred_type = p.data[OFFSET_OF_SHRED_TYPE]; - if shred_type == DATA_SHRED || shred_type == CODING_SHRED { - return Some((slot, index, shred_type == DATA_SHRED)); - } else { - stats.bad_shred_type += 1; - } - - None + let shred_type = match ShredType::from_u8(p.data[OFFSET_OF_SHRED_TYPE]) { + None => { + stats.bad_shred_type += 1; + return None; + } + Some(shred_type) => shred_type, + }; + Some((slot, index, shred_type)) } pub fn max_ticks_per_n_shreds(num_shreds: u64, shred_data_size: Option) -> u64 { @@ -1195,7 +1224,7 @@ pub mod tests { let mut data_shred_indexes = HashSet::new(); let mut coding_shred_indexes = HashSet::new(); for shred in data_shreds.iter() { - assert_eq!(shred.common_header.shred_type, ShredType(DATA_SHRED)); + assert_eq!(shred.shred_type(), ShredType::Data); let index = shred.common_header.index; let is_last = index as u64 == num_expected_data_shreds - 1; verify_test_data_shred( @@ -1214,7 +1243,7 @@ pub mod tests { for shred in coding_shreds.iter() { let index = shred.common_header.index; - assert_eq!(shred.common_header.shred_type, ShredType(CODING_SHRED)); + assert_eq!(shred.shred_type(), ShredType::Code); verify_test_code_shred(shred, index, slot, &keypair.pubkey(), true); assert!(!coding_shred_indexes.contains(&index)); coding_shred_indexes.insert(index); @@ -1816,7 +1845,7 @@ pub mod tests { shred.copy_to_packet(&mut packet); let mut stats = ShredFetchStats::default(); let ret = get_shred_slot_index_type(&packet, &mut stats); - assert_eq!(Some((1, 3, true)), ret); + assert_eq!(Some((1, 3, ShredType::Data)), ret); assert_eq!(stats, ShredFetchStats::default()); packet.meta.size = OFFSET_OF_SHRED_TYPE; @@ -1837,7 +1866,7 @@ pub mod tests { packet.meta.size = OFFSET_OF_SHRED_INDEX + SIZE_OF_SHRED_INDEX; assert_eq!( - Some((1, 3, true)), + Some((1, 3, ShredType::Data)), get_shred_slot_index_type(&packet, &mut stats) ); assert_eq!(stats.index_overrun, 4); @@ -1853,7 +1882,7 @@ pub mod tests { ); shred.copy_to_packet(&mut packet); assert_eq!( - Some((8, 2, false)), + Some((8, 2, ShredType::Code)), get_shred_slot_index_type(&packet, &mut stats) ); @@ -1862,7 +1891,7 @@ pub mod tests { assert_eq!(None, get_shred_slot_index_type(&packet, &mut stats)); assert_eq!(1, stats.index_out_of_bounds); - let (mut header, coding_header) = Shredder::new_coding_shred_header( + let (header, coding_header) = Shredder::new_coding_shred_header( 8, // slot 2, // index 10, // fec_set_index @@ -1871,11 +1900,38 @@ pub mod tests { 3, // position 200, // version ); - header.shred_type = ShredType(u8::MAX); let shred = Shred::new_empty_from_header(header, DataShredHeader::default(), coding_header); shred.copy_to_packet(&mut packet); + packet.data[OFFSET_OF_SHRED_TYPE] = u8::MAX; assert_eq!(None, get_shred_slot_index_type(&packet, &mut stats)); assert_eq!(1, stats.bad_shred_type); } + + // Asserts that ShredType is backward compatible with u8. + #[test] + fn test_shred_type_compat() { + assert_eq!(std::mem::size_of::(), std::mem::size_of::()); + assert_eq!(ShredType::from_u8(0), None); + assert_eq!(ShredType::from_u8(1), None); + assert_matches!(bincode::deserialize::(&[0u8]), Err(_)); + // data shred + assert_eq!(ShredType::Data as u8, 0b1010_0101); + assert_eq!(ShredType::from_u8(0b1010_0101), Some(ShredType::Data)); + let buf = bincode::serialize(&ShredType::Data).unwrap(); + assert_eq!(buf, vec![0b1010_0101]); + assert_matches!( + bincode::deserialize::(&[0b1010_0101]), + Ok(ShredType::Data) + ); + // coding shred + assert_eq!(ShredType::Code as u8, 0b0101_1010); + assert_eq!(ShredType::from_u8(0b0101_1010), Some(ShredType::Code)); + let buf = bincode::serialize(&ShredType::Code).unwrap(); + assert_eq!(buf, vec![0b0101_1010]); + assert_matches!( + bincode::deserialize::(&[0b0101_1010]), + Ok(ShredType::Code) + ); + } }