diff --git a/sdk/program/src/borsh.rs b/sdk/program/src/borsh.rs index a2df53bfdd..38454eb129 100644 --- a/sdk/program/src/borsh.rs +++ b/sdk/program/src/borsh.rs @@ -1,6 +1,12 @@ //! Borsh utils -use borsh::schema::{BorshSchema, Declaration, Definition, Fields}; -use std::collections::HashMap; +use { + borsh::{ + maybestd::io::Error, + schema::{BorshSchema, Declaration, Definition, Fields}, + BorshDeserialize, + }, + std::collections::HashMap, +}; /// Get packed length for the given BorchSchema Declaration fn get_declaration_packed_len( @@ -53,3 +59,101 @@ pub fn get_packed_len() -> usize { let schema_container = S::schema_container(); get_declaration_packed_len(&schema_container.declaration, &schema_container.definitions) } + +/// Deserializes without checking that the entire slice has been consumed +/// +/// Normally, `try_from_slice` checks the length of the final slice to ensure +/// that the deserialization uses up all of the bytes in the slice. +/// +/// Note that there is a potential issue with this function. Any buffer greater than +/// or equal to the expected size will properly deserialize. For example, if the +/// user passes a buffer destined for a different type, the error won't get caught +/// as easily. +pub fn try_from_slice_unchecked(data: &[u8]) -> Result { + let mut data_mut = data; + let result = T::deserialize(&mut data_mut)?; + Ok(result) +} + +#[cfg(test)] +mod tests { + use { + super::*, + borsh::{maybestd::io::ErrorKind, BorshSchema, BorshSerialize}, + std::mem::size_of, + }; + + #[derive(BorshSerialize, BorshDeserialize, BorshSchema)] + enum TestEnum { + NoValue, + Value(u32), + StructValue { + #[allow(dead_code)] + number: u64, + #[allow(dead_code)] + array: [u8; 8], + }, + } + + #[derive(BorshSerialize, BorshDeserialize, BorshSchema)] + struct TestStruct { + pub array: [u64; 16], + pub number: u128, + pub tuple: (u8, u16), + pub enumeration: TestEnum, + } + + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, BorshSchema)] + struct Child { + pub data: [u8; 64], + } + + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, BorshSchema)] + struct Parent { + pub data: Vec, + } + + #[test] + fn unchecked_deserialization() { + let data = vec![ + Child { data: [0u8; 64] }, + Child { data: [1u8; 64] }, + Child { data: [2u8; 64] }, + ]; + let parent = Parent { data }; + + // exact size, both work + let mut byte_vec = vec![0u8; 4 + get_packed_len::() * 3]; + let mut bytes = byte_vec.as_mut_slice(); + parent.serialize(&mut bytes).unwrap(); + let deserialized = Parent::try_from_slice(&byte_vec).unwrap(); + assert_eq!(deserialized, parent); + let deserialized = try_from_slice_unchecked::(&byte_vec).unwrap(); + assert_eq!(deserialized, parent); + + // too big, only unchecked works + let mut byte_vec = vec![0u8; 4 + get_packed_len::() * 10]; + let mut bytes = byte_vec.as_mut_slice(); + parent.serialize(&mut bytes).unwrap(); + let err = Parent::try_from_slice(&byte_vec).unwrap_err(); + assert_eq!(err.kind(), ErrorKind::InvalidData); + let deserialized = try_from_slice_unchecked::(&byte_vec).unwrap(); + assert_eq!(deserialized, parent); + } + + #[test] + fn packed_len() { + assert_eq!( + get_packed_len::(), + size_of::() + size_of::() + size_of::() * 8 + ); + assert_eq!( + get_packed_len::(), + size_of::() * 16 + + size_of::() + + size_of::() + + size_of::() + + get_packed_len::() + ); + } +}