diff --git a/core/src/cluster_info.rs b/core/src/cluster_info.rs index 41be1131de..7c387d3815 100644 --- a/core/src/cluster_info.rs +++ b/core/src/cluster_info.rs @@ -12,8 +12,6 @@ //! * layer 2 - Everyone else, if layer 1 is `2^10`, layer 2 should be able to fit `2^20` number of nodes. //! //! Bank needs to provide an interface for us to query the stake weight -use crate::crds_value::CompressionType::*; -use crate::crds_value::EpochIncompleteSlots; use crate::packet::limited_deserialize; use crate::streamer::{PacketReceiver, PacketSender}; use crate::{ @@ -21,7 +19,9 @@ use crate::{ crds_gossip::CrdsGossip, crds_gossip_error::CrdsGossipError, crds_gossip_pull::{CrdsFilter, CRDS_GOSSIP_PULL_CRDS_TIMEOUT_MS}, - crds_value::{self, CrdsData, CrdsValue, CrdsValueLabel, EpochSlots, SnapshotHash, Vote}, + crds_value::{ + self, CrdsData, CrdsValue, CrdsValueLabel, EpochSlots, SnapshotHash, Vote, MAX_WALLCLOCK, + }, packet::{Packet, PACKET_DATA_SIZE}, result::{Error, Result}, sendmmsg::{multicast, send_mmsg}, @@ -31,9 +31,9 @@ use crate::{ use rand::distributions::{Distribution, WeightedIndex}; use rand::SeedableRng; use rand_chacha::ChaChaRng; +use solana_sdk::sanitize::{Sanitize, SanitizeError}; use bincode::{serialize, serialized_size}; -use compression::prelude::*; use core::cmp; use itertools::Itertools; use rayon::iter::IntoParallelIterator; @@ -87,9 +87,6 @@ const MAX_PROTOCOL_HEADER_SIZE: u64 = 214; /// 128MB/PACKET_DATA_SIZE const MAX_GOSSIP_TRAFFIC: usize = 128_000_000 / PACKET_DATA_SIZE; -const NUM_BITS_PER_BYTE: u64 = 8; -const MIN_SIZE_TO_COMPRESS_GZIP: u64 = 64; - /// Keep the number of snapshot hashes a node publishes under MAX_PROTOCOL_PAYLOAD_SIZE pub const MAX_SNAPSHOT_HASHES: usize = 16; @@ -157,6 +154,15 @@ pub struct PruneData { pub wallclock: u64, } +impl Sanitize for PruneData { + fn sanitize(&self) -> std::result::Result<(), SanitizeError> { + if self.wallclock >= MAX_WALLCLOCK { + return Err(SanitizeError::ValueOutOfRange); + } + Ok(()) + } +} + impl Signable for PruneData { fn pubkey(&self) -> Pubkey { self.pubkey @@ -221,6 +227,20 @@ enum Protocol { PruneMessage(Pubkey, PruneData), } +impl Sanitize for Protocol { + fn sanitize(&self) -> std::result::Result<(), SanitizeError> { + match self { + Protocol::PullRequest(filter, val) => { + filter.sanitize()?; + val.sanitize() + } + Protocol::PullResponse(_, val) => val.sanitize(), + Protocol::PushMessage(_, val) => val.sanitize(), + Protocol::PruneMessage(_, val) => val.sanitize(), + } + } +} + // Rating for pull requests // A response table is generated as a // 2-d table arranged by target nodes and a @@ -373,115 +393,17 @@ impl ClusterInfo { ) } - pub fn compress_incomplete_slots(incomplete_slots: &BTreeSet) -> EpochIncompleteSlots { - if !incomplete_slots.is_empty() { - let first_slot = incomplete_slots - .iter() - .next() - .expect("expected to find at least one slot"); - let last_slot = incomplete_slots - .iter() - .next_back() - .expect("expected to find last slot"); - let num_uncompressed_bits = last_slot.saturating_sub(*first_slot) + 1; - let num_uncompressed_bytes = if num_uncompressed_bits % NUM_BITS_PER_BYTE > 0 { - 1 - } else { - 0 - } + num_uncompressed_bits / NUM_BITS_PER_BYTE; - let mut uncompressed = vec![0u8; num_uncompressed_bytes as usize]; - incomplete_slots.iter().for_each(|slot| { - let offset_from_first_slot = slot.saturating_sub(*first_slot); - let index = offset_from_first_slot / NUM_BITS_PER_BYTE; - let bit_index = offset_from_first_slot % NUM_BITS_PER_BYTE; - uncompressed[index as usize] |= 1 << bit_index; - }); - if num_uncompressed_bytes >= MIN_SIZE_TO_COMPRESS_GZIP { - if let Ok(compressed) = uncompressed - .iter() - .cloned() - .encode(&mut GZipEncoder::new(), Action::Finish) - .collect::, _>>() - { - return EpochIncompleteSlots { - first: *first_slot, - compression: GZip, - compressed_list: compressed, - }; - } - } else { - return EpochIncompleteSlots { - first: *first_slot, - compression: Uncompressed, - compressed_list: uncompressed, - }; - } - } - EpochIncompleteSlots::default() - } - - fn bitmap_to_slot_list(first: Slot, bitmap: &[u8]) -> BTreeSet { - let mut old_incomplete_slots: BTreeSet = BTreeSet::new(); - bitmap.iter().enumerate().for_each(|(i, val)| { - if *val != 0 { - (0..8).for_each(|bit_index| { - if (1 << bit_index & *val) != 0 { - let slot = first + i as u64 * NUM_BITS_PER_BYTE + bit_index as u64; - old_incomplete_slots.insert(slot); - } - }) - } - }); - old_incomplete_slots - } - - pub fn decompress_incomplete_slots(slots: &EpochIncompleteSlots) -> BTreeSet { - match slots.compression { - Uncompressed => Self::bitmap_to_slot_list(slots.first, &slots.compressed_list), - GZip => { - if let Ok(decompressed) = slots - .compressed_list - .iter() - .cloned() - .decode(&mut GZipDecoder::new()) - .collect::, _>>() - { - Self::bitmap_to_slot_list(slots.first, &decompressed) - } else { - BTreeSet::new() - } - } - BZip2 => { - if let Ok(decompressed) = slots - .compressed_list - .iter() - .cloned() - .decode(&mut BZip2Decoder::new()) - .collect::, _>>() - { - Self::bitmap_to_slot_list(slots.first, &decompressed) - } else { - BTreeSet::new() - } - } - } - } - pub fn push_epoch_slots( &mut self, id: Pubkey, - root: Slot, + _root: Slot, min: Slot, - slots: BTreeSet, - incomplete_slots: &BTreeSet, + _slots: BTreeSet, + _incomplete_slots: &BTreeSet, ) { - let compressed = Self::compress_incomplete_slots(incomplete_slots); let now = timestamp(); let entry = CrdsValue::new_signed( - CrdsData::EpochSlots( - 0, - EpochSlots::new(id, root, min, slots, vec![compressed], now), - ), + CrdsData::EpochSlots(0, EpochSlots::new(id, min, now)), &self.keypair, ); self.gossip @@ -1358,6 +1280,7 @@ impl ClusterInfo { let from_addr = packet.meta.addr(); limited_deserialize(&packet.data[..packet.meta.size]) .into_iter() + .filter(|r: &Protocol| r.sanitize().is_ok()) .for_each(|request| match request { Protocol::PullRequest(filter, caller) => { let start = allocated.get(); @@ -2500,14 +2423,7 @@ mod tests { } let value = CrdsValue::new_unsigned(CrdsData::EpochSlots( 0, - EpochSlots { - from: Pubkey::default(), - root: 0, - lowest: 0, - slots: btree_slots, - stash: vec![], - wallclock: 0, - }, + EpochSlots::new(Pubkey::default(), 0, 0), )); test_split_messages(value); } @@ -2519,39 +2435,19 @@ mod tests { let payload: Vec = vec![]; let vec_size = serialized_size(&payload).unwrap(); let desired_size = MAX_PROTOCOL_PAYLOAD_SIZE - vec_size; - let mut value = CrdsValue::new_unsigned(CrdsData::EpochSlots( - 0, - EpochSlots { - from: Pubkey::default(), - root: 0, - lowest: 0, - slots: BTreeSet::new(), - stash: vec![], - wallclock: 0, - }, - )); + let mut value = CrdsValue::new_unsigned(CrdsData::SnapshotHashes(SnapshotHash { + from: Pubkey::default(), + hashes: vec![], + wallclock: 0, + })); let mut i = 0; while value.size() <= desired_size { - let slots = (0..i).collect::>(); - if slots.len() > 200 { - panic!( - "impossible to match size: last {:?} vs desired {:?}", - serialized_size(&value).unwrap(), - desired_size - ); - } - value.data = CrdsData::EpochSlots( - 0, - EpochSlots { - from: Pubkey::default(), - root: 0, - lowest: 0, - slots, - stash: vec![], - wallclock: 0, - }, - ); + value.data = CrdsData::SnapshotHashes(SnapshotHash { + from: Pubkey::default(), + hashes: vec![(0, Hash::default()); i], + wallclock: 0, + }); i += 1; } let split = ClusterInfo::split_gossip_messages(vec![value.clone()]); @@ -2681,11 +2577,9 @@ mod tests { node_keypair, ); for i in 0..10 { - let mut peer_root = 5; let mut peer_lowest = 0; if i >= 5 { // make these invalid for the upcoming repair request - peer_root = 15; peer_lowest = 10; } let other_node_pubkey = Pubkey::new_rand(); @@ -2693,14 +2587,7 @@ mod tests { cluster_info.insert_info(other_node.clone()); let value = CrdsValue::new_unsigned(CrdsData::EpochSlots( 0, - EpochSlots::new( - other_node_pubkey, - peer_root, - peer_lowest, - BTreeSet::new(), - vec![], - timestamp(), - ), + EpochSlots::new(other_node_pubkey, peer_lowest, timestamp()), )); let _ = cluster_info.gossip.crds.insert(value, timestamp()); } @@ -2749,6 +2636,14 @@ mod tests { assert_eq!(MAX_PROTOCOL_HEADER_SIZE, max_protocol_size); } + #[test] + fn test_protocol_sanitize() { + let mut pd = PruneData::default(); + pd.wallclock = MAX_WALLCLOCK; + let msg = Protocol::PruneMessage(Pubkey::default(), pd); + assert_eq!(msg.sanitize(), Err(SanitizeError::ValueOutOfRange)); + } + // computes the maximum size for pull request blooms fn max_bloom_size() -> usize { let filter_size = serialized_size(&CrdsFilter::default()) @@ -2761,38 +2656,4 @@ mod tests { serialized_size(&protocol).expect("unable to serialize gossip protocol") as usize; PACKET_DATA_SIZE - (protocol_size - filter_size) } - - #[test] - fn test_compress_incomplete_slots() { - let mut incomplete_slots: BTreeSet = BTreeSet::new(); - - assert_eq!( - EpochIncompleteSlots::default(), - ClusterInfo::compress_incomplete_slots(&incomplete_slots) - ); - - incomplete_slots.insert(100); - let compressed = ClusterInfo::compress_incomplete_slots(&incomplete_slots); - assert_eq!(100, compressed.first); - let decompressed = ClusterInfo::decompress_incomplete_slots(&compressed); - assert_eq!(incomplete_slots, decompressed); - - incomplete_slots.insert(104); - let compressed = ClusterInfo::compress_incomplete_slots(&incomplete_slots); - assert_eq!(100, compressed.first); - let decompressed = ClusterInfo::decompress_incomplete_slots(&compressed); - assert_eq!(incomplete_slots, decompressed); - - incomplete_slots.insert(80); - let compressed = ClusterInfo::compress_incomplete_slots(&incomplete_slots); - assert_eq!(80, compressed.first); - let decompressed = ClusterInfo::decompress_incomplete_slots(&compressed); - assert_eq!(incomplete_slots, decompressed); - - incomplete_slots.insert(10000); - let compressed = ClusterInfo::compress_incomplete_slots(&incomplete_slots); - assert_eq!(80, compressed.first); - let decompressed = ClusterInfo::decompress_incomplete_slots(&compressed); - assert_eq!(incomplete_slots, decompressed); - } } diff --git a/core/src/contact_info.rs b/core/src/contact_info.rs index 44fa9dd793..d77ba52d80 100644 --- a/core/src/contact_info.rs +++ b/core/src/contact_info.rs @@ -1,6 +1,8 @@ +use crate::crds_value::MAX_WALLCLOCK; use solana_sdk::pubkey::Pubkey; #[cfg(test)] use solana_sdk::rpc_port; +use solana_sdk::sanitize::{Sanitize, SanitizeError}; #[cfg(test)] use solana_sdk::signature::{Keypair, Signer}; use solana_sdk::timing::timestamp; @@ -37,6 +39,15 @@ pub struct ContactInfo { pub shred_version: u16, } +impl Sanitize for ContactInfo { + fn sanitize(&self) -> std::result::Result<(), SanitizeError> { + if self.wallclock >= MAX_WALLCLOCK { + return Err(SanitizeError::Failed); + } + Ok(()) + } +} + impl Ord for ContactInfo { fn cmp(&self, other: &Self) -> Ordering { self.id.cmp(&other.id) diff --git a/core/src/crds_gossip_pull.rs b/core/src/crds_gossip_pull.rs index af40826a21..8ad9dcb8de 100644 --- a/core/src/crds_gossip_pull.rs +++ b/core/src/crds_gossip_pull.rs @@ -37,6 +37,13 @@ pub struct CrdsFilter { mask_bits: u32, } +impl solana_sdk::sanitize::Sanitize for CrdsFilter { + fn sanitize(&self) -> std::result::Result<(), solana_sdk::sanitize::SanitizeError> { + self.filter.sanitize()?; + Ok(()) + } +} + impl CrdsFilter { pub fn new_rand(num_items: usize, max_bytes: usize) -> Self { let max_bits = (max_bytes * 8) as f64; diff --git a/core/src/crds_value.rs b/core/src/crds_value.rs index 7c8161b931..83af0fb0b1 100644 --- a/core/src/crds_value.rs +++ b/core/src/crds_value.rs @@ -1,5 +1,6 @@ use crate::contact_info::ContactInfo; use bincode::{serialize, serialized_size}; +use solana_sdk::sanitize::{Sanitize, SanitizeError}; use solana_sdk::timing::timestamp; use solana_sdk::{ clock::Slot, @@ -14,10 +15,14 @@ use std::{ fmt, }; +pub const MAX_WALLCLOCK: u64 = 1_000_000_000_000_000; +pub const MAX_SLOT: u64 = 1_000_000_000_000_000; + pub type VoteIndex = u8; pub const MAX_VOTES: VoteIndex = 32; pub type EpochSlotIndex = u8; +pub const MAX_EPOCH_SLOTS: EpochSlotIndex = 1; /// CrdsValue that is replicated across the cluster #[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] @@ -26,6 +31,13 @@ pub struct CrdsValue { pub data: CrdsData, } +impl Sanitize for CrdsValue { + fn sanitize(&self) -> Result<(), SanitizeError> { + self.signature.sanitize()?; + self.data.sanitize() + } +} + impl Signable for CrdsValue { fn pubkey(&self) -> Pubkey { self.pubkey() @@ -44,14 +56,8 @@ impl Signable for CrdsValue { } fn verify(&self) -> bool { - let sig_check = self - .get_signature() - .verify(&self.pubkey().as_ref(), self.signable_data().borrow()); - let data_check = match &self.data { - CrdsData::Vote(ix, _) => *ix < MAX_VOTES, - _ => true, - }; - sig_check && data_check + self.get_signature() + .verify(&self.pubkey().as_ref(), self.signable_data().borrow()) } } @@ -87,6 +93,39 @@ pub struct EpochIncompleteSlots { pub compressed_list: Vec, } +impl Sanitize for EpochIncompleteSlots { + fn sanitize(&self) -> Result<(), SanitizeError> { + if self.first >= MAX_SLOT { + return Err(SanitizeError::Failed); + } + //rest of the data doesn't matter since we no longer decompress + //these values + Ok(()) + } +} + +impl Sanitize for CrdsData { + fn sanitize(&self) -> Result<(), SanitizeError> { + match self { + CrdsData::ContactInfo(val) => val.sanitize(), + CrdsData::Vote(ix, val) => { + if *ix >= MAX_VOTES { + return Err(SanitizeError::Failed); + } + val.sanitize() + } + CrdsData::SnapshotHashes(val) => val.sanitize(), + CrdsData::AccountsHashes(val) => val.sanitize(), + CrdsData::EpochSlots(ix, val) => { + if *ix as usize >= MAX_EPOCH_SLOTS as usize { + return Err(SanitizeError::Failed); + } + val.sanitize() + } + } + } +} + #[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] pub struct SnapshotHash { pub from: Pubkey, @@ -94,6 +133,20 @@ pub struct SnapshotHash { pub wallclock: u64, } +impl Sanitize for SnapshotHash { + fn sanitize(&self) -> Result<(), SanitizeError> { + if self.wallclock >= MAX_WALLCLOCK { + return Err(SanitizeError::Failed); + } + for (slot, _) in &self.hashes { + if *slot >= MAX_SLOT { + return Err(SanitizeError::Failed); + } + } + self.from.sanitize() + } +} + impl SnapshotHash { pub fn new(from: Pubkey, hashes: Vec<(Slot, Hash)>) -> Self { Self { @@ -107,33 +160,47 @@ impl SnapshotHash { #[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] pub struct EpochSlots { pub from: Pubkey, - pub root: Slot, + root: Slot, pub lowest: Slot, - pub slots: BTreeSet, - pub stash: Vec, + slots: BTreeSet, + stash: Vec, pub wallclock: u64, } impl EpochSlots { - pub fn new( - from: Pubkey, - root: Slot, - lowest: Slot, - slots: BTreeSet, - stash: Vec, - wallclock: u64, - ) -> Self { + pub fn new(from: Pubkey, lowest: Slot, wallclock: u64) -> Self { Self { from, - root, + root: 0, lowest, - slots, - stash, + slots: BTreeSet::new(), + stash: vec![], wallclock, } } } +impl Sanitize for EpochSlots { + fn sanitize(&self) -> Result<(), SanitizeError> { + if self.wallclock >= MAX_WALLCLOCK { + return Err(SanitizeError::Failed); + } + if self.lowest >= MAX_SLOT { + return Err(SanitizeError::Failed); + } + if self.root >= MAX_SLOT { + return Err(SanitizeError::Failed); + } + for slot in &self.slots { + if *slot >= MAX_SLOT { + return Err(SanitizeError::Failed); + } + } + self.stash.sanitize()?; + self.from.sanitize() + } +} + #[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] pub struct Vote { pub from: Pubkey, @@ -141,6 +208,16 @@ pub struct Vote { pub wallclock: u64, } +impl Sanitize for Vote { + fn sanitize(&self) -> Result<(), SanitizeError> { + if self.wallclock >= MAX_WALLCLOCK { + return Err(SanitizeError::Failed); + } + self.from.sanitize()?; + self.transaction.sanitize() + } +} + impl Vote { pub fn new(from: &Pubkey, transaction: Transaction, wallclock: u64) -> Self { Self { @@ -356,7 +433,7 @@ mod test { let v = CrdsValue::new_unsigned(CrdsData::EpochSlots( 0, - EpochSlots::new(Pubkey::default(), 0, 0, BTreeSet::new(), vec![], 0), + EpochSlots::new(Pubkey::default(), 0, 0), )); assert_eq!(v.wallclock(), 0); let key = v.clone().epoch_slots().unwrap().from; @@ -377,10 +454,9 @@ mod test { Vote::new(&keypair.pubkey(), test_tx(), timestamp()), )); verify_signatures(&mut v, &keypair, &wrong_keypair); - let btreeset: BTreeSet = vec![1, 2, 3, 6, 8].into_iter().collect(); v = CrdsValue::new_unsigned(CrdsData::EpochSlots( 0, - EpochSlots::new(keypair.pubkey(), 0, 0, btreeset, vec![], timestamp()), + EpochSlots::new(keypair.pubkey(), 0, timestamp()), )); verify_signatures(&mut v, &keypair, &wrong_keypair); } @@ -395,9 +471,21 @@ mod test { ), &keypair, ); - assert!(!vote.verify()); + assert!(vote.sanitize().is_err()); } + #[test] + fn test_max_epoch_slots_index() { + let keypair = Keypair::new(); + let item = CrdsValue::new_signed( + CrdsData::Vote( + MAX_VOTES, + Vote::new(&keypair.pubkey(), test_tx(), timestamp()), + ), + &keypair, + ); + assert!(item.sanitize().is_err()); + } #[test] fn test_compute_vote_index_empty() { for i in 0..MAX_VOTES { diff --git a/runtime/src/bank/mod.rs b/runtime/src/bank/mod.rs index f9d3b0c1c8..f52263c384 100644 --- a/runtime/src/bank/mod.rs +++ b/runtime/src/bank/mod.rs @@ -41,6 +41,7 @@ use solana_sdk::{ inflation::Inflation, native_loader, nonce, pubkey::Pubkey, + sanitize::Sanitize, signature::{Keypair, Signature}, slot_hashes::SlotHashes, slot_history::SlotHistory, @@ -1075,7 +1076,7 @@ impl Bank { OrderedIterator::new(txs, iteration_order) .zip(lock_results) .map(|(tx, lock_res)| { - if lock_res.is_ok() && !tx.verify_refs() { + if lock_res.is_ok() && tx.sanitize().is_err() { error_counters.invalid_account_index += 1; Err(TransactionError::InvalidAccountIndex) } else { diff --git a/runtime/src/bloom.rs b/runtime/src/bloom.rs index f9a212bdf0..3739fe099d 100644 --- a/runtime/src/bloom.rs +++ b/runtime/src/bloom.rs @@ -19,6 +19,8 @@ pub struct Bloom { _phantom: PhantomData, } +impl solana_sdk::sanitize::Sanitize for Bloom {} + impl Bloom { pub fn new(num_bits: usize, keys: Vec) -> Self { let bits = BitVec::new_fill(false, num_bits as u64); diff --git a/sdk/src/lib.rs b/sdk/src/lib.rs index e78081b2e0..12bbdd1dcc 100644 --- a/sdk/src/lib.rs +++ b/sdk/src/lib.rs @@ -24,6 +24,7 @@ pub mod program_utils; pub mod pubkey; pub mod rent; pub mod rpc_port; +pub mod sanitize; pub mod short_vec; pub mod slot_hashes; pub mod slot_history; diff --git a/sdk/src/message.rs b/sdk/src/message.rs index 12f71ab555..ea5b9d422e 100644 --- a/sdk/src/message.rs +++ b/sdk/src/message.rs @@ -1,5 +1,6 @@ //! A library for generating a message from a sequence of instructions +use crate::sanitize::{Sanitize, SanitizeError}; use crate::{ hash::Hash, instruction::{AccountMeta, CompiledInstruction, Instruction}, @@ -162,6 +163,31 @@ pub struct Message { pub instructions: Vec, } +impl Sanitize for Message { + fn sanitize(&self) -> std::result::Result<(), SanitizeError> { + if self.header.num_required_signatures as usize > self.account_keys.len() { + return Err(SanitizeError::IndexOutOfBounds); + } + if self.header.num_readonly_unsigned_accounts as usize + + self.header.num_readonly_signed_accounts as usize + > self.account_keys.len() + { + return Err(SanitizeError::IndexOutOfBounds); + } + for ci in &self.instructions { + if ci.program_id_index as usize >= self.account_keys.len() { + return Err(SanitizeError::IndexOutOfBounds); + } + for ai in &ci.accounts { + if *ai as usize >= self.account_keys.len() { + return Err(SanitizeError::IndexOutOfBounds); + } + } + } + Ok(()) + } +} + impl Message { pub fn new_with_compiled_instructions( num_required_signatures: u8, diff --git a/sdk/src/pubkey.rs b/sdk/src/pubkey.rs index 860bcb56ae..302fc558bd 100644 --- a/sdk/src/pubkey.rs +++ b/sdk/src/pubkey.rs @@ -6,6 +6,8 @@ pub use bs58; #[derive(Serialize, Deserialize, Clone, Copy, Default, Eq, PartialEq, Ord, PartialOrd, Hash)] pub struct Pubkey([u8; 32]); +impl crate::sanitize::Sanitize for Pubkey {} + #[derive(Debug, Clone, PartialEq, Eq)] pub enum ParsePubkeyError { WrongSize, diff --git a/sdk/src/sanitize.rs b/sdk/src/sanitize.rs new file mode 100644 index 0000000000..89681d05b5 --- /dev/null +++ b/sdk/src/sanitize.rs @@ -0,0 +1,21 @@ +#[derive(PartialEq, Debug)] +pub enum SanitizeError { + Failed, + IndexOutOfBounds, + ValueOutOfRange, +} + +pub trait Sanitize { + fn sanitize(&self) -> Result<(), SanitizeError> { + Ok(()) + } +} + +impl Sanitize for Vec { + fn sanitize(&self) -> Result<(), SanitizeError> { + for x in self.iter() { + x.sanitize()?; + } + Ok(()) + } +} diff --git a/sdk/src/signature.rs b/sdk/src/signature.rs index 8117b69d77..cb57e0dcdc 100644 --- a/sdk/src/signature.rs +++ b/sdk/src/signature.rs @@ -49,6 +49,8 @@ impl Keypair { #[derive(Serialize, Deserialize, Clone, Copy, Default, Eq, PartialEq, Ord, PartialOrd, Hash)] pub struct Signature(GenericArray); +impl crate::sanitize::Sanitize for Signature {} + impl Signature { pub fn new(signature_slice: &[u8]) -> Self { Self(GenericArray::clone_from_slice(&signature_slice)) diff --git a/sdk/src/transaction.rs b/sdk/src/transaction.rs index 01972e292d..ce36c59ea9 100644 --- a/sdk/src/transaction.rs +++ b/sdk/src/transaction.rs @@ -1,5 +1,6 @@ //! Defines a Transaction type to package an atomic sequence of instructions. +use crate::sanitize::{Sanitize, SanitizeError}; use crate::{ hash::Hash, instruction::{CompiledInstruction, Instruction, InstructionError}, @@ -83,6 +84,18 @@ pub struct Transaction { pub message: Message, } +impl Sanitize for Transaction { + fn sanitize(&self) -> std::result::Result<(), SanitizeError> { + if self.message.header.num_required_signatures as usize > self.signatures.len() { + return Err(SanitizeError::IndexOutOfBounds); + } + if self.signatures.len() > self.message.account_keys.len() { + return Err(SanitizeError::IndexOutOfBounds); + } + self.message.sanitize() + } +} + impl Transaction { pub fn new_unsigned(message: Message) -> Self { Self { @@ -361,22 +374,6 @@ impl Transaction { .iter() .all(|signature| *signature != Signature::default()) } - - /// Verify that references in the instructions are valid - pub fn verify_refs(&self) -> bool { - let message = self.message(); - for instruction in &message.instructions { - if (instruction.program_id_index as usize) >= message.account_keys.len() { - return false; - } - for account_index in &instruction.accounts { - if (*account_index as usize) >= message.account_keys.len() { - return false; - } - } - } - true - } } #[cfg(test)] @@ -415,7 +412,7 @@ mod tests { vec![prog1, prog2], instructions, ); - assert!(tx.verify_refs()); + assert!(tx.sanitize().is_ok()); assert_eq!(tx.key(0, 0), Some(&key.pubkey())); assert_eq!(tx.signer_key(0, 0), Some(&key.pubkey())); @@ -449,7 +446,7 @@ mod tests { vec![], instructions, ); - assert!(!tx.verify_refs()); + assert_eq!(tx.sanitize(), Err(SanitizeError::IndexOutOfBounds)); } #[test] fn test_refs_invalid_account() { @@ -463,7 +460,54 @@ mod tests { instructions, ); assert_eq!(*get_program_id(&tx, 0), Pubkey::default()); - assert!(!tx.verify_refs()); + assert_eq!(tx.sanitize(), Err(SanitizeError::IndexOutOfBounds)); + } + + #[test] + fn test_sanitize_txs() { + let key = Keypair::new(); + let id0 = Pubkey::default(); + let program_id = Pubkey::new_rand(); + let ix = Instruction::new( + program_id, + &0, + vec![ + AccountMeta::new(key.pubkey(), true), + AccountMeta::new(id0, true), + ], + ); + let ixs = vec![ix]; + let mut tx = Transaction::new_with_payer(ixs, Some(&key.pubkey())); + let o = tx.clone(); + assert_eq!(tx.sanitize(), Ok(())); + assert_eq!(tx.message.account_keys.len(), 3); + + tx = o.clone(); + tx.message.header.num_required_signatures = 3; + assert_eq!(tx.sanitize(), Err(SanitizeError::IndexOutOfBounds)); + + tx = o.clone(); + tx.message.header.num_readonly_signed_accounts = 4; + tx.message.header.num_readonly_unsigned_accounts = 0; + assert_eq!(tx.sanitize(), Err(SanitizeError::IndexOutOfBounds)); + + tx = o.clone(); + tx.message.header.num_readonly_signed_accounts = 2; + tx.message.header.num_readonly_unsigned_accounts = 2; + assert_eq!(tx.sanitize(), Err(SanitizeError::IndexOutOfBounds)); + + tx = o.clone(); + tx.message.header.num_readonly_signed_accounts = 0; + tx.message.header.num_readonly_unsigned_accounts = 4; + assert_eq!(tx.sanitize(), Err(SanitizeError::IndexOutOfBounds)); + + tx = o.clone(); + tx.message.instructions[0].program_id_index = 3; + assert_eq!(tx.sanitize(), Err(SanitizeError::IndexOutOfBounds)); + + tx = o.clone(); + tx.message.instructions[0].accounts[0] = 3; + assert_eq!(tx.sanitize(), Err(SanitizeError::IndexOutOfBounds)); } fn create_sample_transaction() -> Transaction {