#![allow(clippy::integer_arithmetic)] //! Borsh utils use { borsh::{ maybestd::io::{Error, Write}, schema::{BorshSchema, Declaration, Definition, Fields}, BorshDeserialize, BorshSerialize, }, std::collections::HashMap, }; /// Get packed length for the given BorchSchema Declaration fn get_declaration_packed_len( declaration: &str, definitions: &HashMap, ) -> usize { match definitions.get(declaration) { Some(Definition::Array { length, elements }) => { *length as usize * get_declaration_packed_len(elements, definitions) } Some(Definition::Enum { variants }) => { 1 + variants .iter() .map(|(_, declaration)| get_declaration_packed_len(declaration, definitions)) .max() .unwrap_or(0) } Some(Definition::Struct { fields }) => match fields { Fields::NamedFields(named_fields) => named_fields .iter() .map(|(_, declaration)| get_declaration_packed_len(declaration, definitions)) .sum(), Fields::UnnamedFields(declarations) => declarations .iter() .map(|declaration| get_declaration_packed_len(declaration, definitions)) .sum(), Fields::Empty => 0, }, Some(Definition::Sequence { elements: _elements, }) => panic!("Missing support for Definition::Sequence"), Some(Definition::Tuple { elements }) => elements .iter() .map(|element| get_declaration_packed_len(element, definitions)) .sum(), None => match declaration { "bool" | "u8" | "i8" => 1, "u16" | "i16" => 2, "u32" | "i32" => 4, "u64" | "i64" => 8, "u128" | "i128" => 16, "nil" => 0, _ => panic!("Missing primitive type: {}", declaration), }, } } /// Get the worst-case packed length for the given BorshSchema /// /// Note: due to the serializer currently used by Borsh, this function cannot /// be used on-chain in the Solana BPF execution environment. pub fn get_packed_len() -> usize { let schema_container = S::schema_container(); get_declaration_packed_len(&schema_container.declaration, &schema_container.definitions) } /// Deserializes without checking that the entire slice has been consumed /// /// Normally, `try_from_slice` checks the length of the final slice to ensure /// that the deserialization uses up all of the bytes in the slice. /// /// Note that there is a potential issue with this function. Any buffer greater than /// or equal to the expected size will properly deserialize. For example, if the /// user passes a buffer destined for a different type, the error won't get caught /// as easily. pub fn try_from_slice_unchecked(data: &[u8]) -> Result { let mut data_mut = data; let result = T::deserialize(&mut data_mut)?; Ok(result) } /// Helper struct which to count how much data would be written during serialization #[derive(Default)] struct WriteCounter { count: usize, } impl Write for WriteCounter { fn write(&mut self, data: &[u8]) -> Result { let amount = data.len(); self.count += amount; Ok(amount) } fn flush(&mut self) -> Result<(), Error> { Ok(()) } } /// Get the packed length for the serialized form of this object instance. /// /// Useful when working with instances of types that contain a variable-length /// sequence, such as a Vec or HashMap. Since it is impossible to know the packed /// length only from the type's schema, this can be used when an instance already /// exists, to figure out how much space to allocate in an account. pub fn get_instance_packed_len(instance: &T) -> Result { let mut counter = WriteCounter::default(); instance.serialize(&mut counter)?; Ok(counter.count) } #[cfg(test)] mod tests { use { super::*, borsh::{maybestd::io::ErrorKind, BorshSchema, BorshSerialize}, std::{collections::HashMap, mem::size_of}, }; #[derive(PartialEq, Clone, Debug, BorshSerialize, BorshDeserialize, BorshSchema)] enum TestEnum { NoValue, Number(u32), Struct { #[allow(dead_code)] number: u64, #[allow(dead_code)] array: [u8; 8], }, } // for test simplicity impl Default for TestEnum { fn default() -> Self { Self::NoValue } } #[derive(Default, BorshSerialize, BorshDeserialize, BorshSchema)] struct TestStruct { pub array: [u64; 16], pub number_u128: u128, pub number_u32: u32, pub tuple: (u8, u16), pub enumeration: TestEnum, pub r#bool: bool, } #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, BorshSchema)] struct Child { pub data: [u8; 64], } #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, BorshSchema)] struct Parent { pub data: Vec, } #[test] fn unchecked_deserialization() { let data = vec![ Child { data: [0u8; 64] }, Child { data: [1u8; 64] }, Child { data: [2u8; 64] }, ]; let parent = Parent { data }; // exact size, both work let mut byte_vec = vec![0u8; 4 + get_packed_len::() * 3]; let mut bytes = byte_vec.as_mut_slice(); parent.serialize(&mut bytes).unwrap(); let deserialized = Parent::try_from_slice(&byte_vec).unwrap(); assert_eq!(deserialized, parent); let deserialized = try_from_slice_unchecked::(&byte_vec).unwrap(); assert_eq!(deserialized, parent); // too big, only unchecked works let mut byte_vec = vec![0u8; 4 + get_packed_len::() * 10]; let mut bytes = byte_vec.as_mut_slice(); parent.serialize(&mut bytes).unwrap(); let err = Parent::try_from_slice(&byte_vec).unwrap_err(); assert_eq!(err.kind(), ErrorKind::InvalidData); let deserialized = try_from_slice_unchecked::(&byte_vec).unwrap(); assert_eq!(deserialized, parent); } #[test] fn packed_len() { assert_eq!( get_packed_len::(), size_of::() + size_of::() + size_of::() * 8 ); assert_eq!( get_packed_len::(), size_of::() * 16 + size_of::() + size_of::() + size_of::() + size_of::() + size_of::() + get_packed_len::() ); } #[test] fn instance_packed_len_matches_packed_len() { let enumeration = TestEnum::Struct { number: u64::MAX, array: [255; 8], }; assert_eq!( get_packed_len::(), get_instance_packed_len(&enumeration).unwrap(), ); let test_struct = TestStruct { enumeration, ..TestStruct::default() }; assert_eq!( get_packed_len::(), get_instance_packed_len(&test_struct).unwrap(), ); assert_eq!( get_packed_len::(), get_instance_packed_len(&0u8).unwrap(), ); assert_eq!( get_packed_len::(), get_instance_packed_len(&0u16).unwrap(), ); assert_eq!( get_packed_len::(), get_instance_packed_len(&0u32).unwrap(), ); assert_eq!( get_packed_len::(), get_instance_packed_len(&0u64).unwrap(), ); assert_eq!( get_packed_len::(), get_instance_packed_len(&0u128).unwrap(), ); assert_eq!( get_packed_len::<[u8; 10]>(), get_instance_packed_len(&[0u8; 10]).unwrap(), ); assert_eq!( get_packed_len::<(i8, i16, i32, i64, i128)>(), get_instance_packed_len(&(i8::MAX, i16::MAX, i32::MAX, i64::MAX, i128::MAX)).unwrap(), ); } #[test] fn instance_packed_len_with_vec() { let data = vec![ Child { data: [0u8; 64] }, Child { data: [1u8; 64] }, Child { data: [2u8; 64] }, Child { data: [3u8; 64] }, Child { data: [4u8; 64] }, Child { data: [5u8; 64] }, ]; let parent = Parent { data }; assert_eq!( get_instance_packed_len(&parent).unwrap(), 4 + parent.data.len() * get_packed_len::() ); } #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, BorshSchema)] struct StructWithHashMap { data: HashMap, } #[test] fn instance_packed_len_with_varying_sizes_in_hashmap() { let mut data = HashMap::new(); let string1 = "the first string, it's actually really really long".to_string(); let enum1 = TestEnum::NoValue; let string2 = "second string, shorter".to_string(); let enum2 = TestEnum::Number(u32::MAX); let string3 = "third".to_string(); let enum3 = TestEnum::Struct { number: 0, array: [0; 8], }; data.insert(string1.clone(), enum1.clone()); data.insert(string2.clone(), enum2.clone()); data.insert(string3.clone(), enum3.clone()); let instance = StructWithHashMap { data }; assert_eq!( get_instance_packed_len(&instance).unwrap(), 4 + get_instance_packed_len(&string1).unwrap() + get_instance_packed_len(&enum1).unwrap() + get_instance_packed_len(&string2).unwrap() + get_instance_packed_len(&enum2).unwrap() + get_instance_packed_len(&string3).unwrap() + get_instance_packed_len(&enum3).unwrap() ); } }