use serde::{ de::{self, Deserializer, SeqAccess, Visitor}, ser::{self, SerializeTuple, Serializer}, {Deserialize, Serialize}, }; use std::{fmt, marker::PhantomData, mem::size_of}; /// Same as u16, but serialized with 1 to 3 bytes. If the value is above /// 0x7f, the top bit is set and the remaining value is stored in the next /// bytes. Each byte follows the same pattern until the 3rd byte. The 3rd /// byte, if needed, uses all 8 bits to store the last byte of the original /// value. #[derive(AbiExample)] pub struct ShortU16(pub u16); impl Serialize for ShortU16 { fn serialize(&self, serializer: S) -> Result where S: Serializer, { // Pass a non-zero value to serialize_tuple() so that serde_json will // generate an open bracket. let mut seq = serializer.serialize_tuple(1)?; let mut rem_len = self.0; loop { let mut elem = (rem_len & 0x7f) as u8; rem_len >>= 7; if rem_len == 0 { seq.serialize_element(&elem)?; break; } else { elem |= 0x80; seq.serialize_element(&elem)?; } } seq.end() } } enum VisitResult { Done(usize, usize), More(usize, usize), Err, } fn visit_byte(elem: u8, len: usize, size: usize) -> VisitResult { let len = len | (elem as usize & 0x7f) << (size * 7); let size = size + 1; let more = elem as usize & 0x80 == 0x80; if size > size_of::() + 1 { VisitResult::Err } else if more { VisitResult::More(len, size) } else { VisitResult::Done(len, size) } } struct ShortLenVisitor; impl<'de> Visitor<'de> for ShortLenVisitor { type Value = ShortU16; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { formatter.write_str("a multi-byte length") } fn visit_seq(self, mut seq: A) -> Result where A: SeqAccess<'de>, { let mut len: usize = 0; let mut size: usize = 0; loop { let elem: u8 = seq .next_element()? .ok_or_else(|| de::Error::invalid_length(size, &self))?; match visit_byte(elem, len, size) { VisitResult::Done(l, _) => { len = l; break; } VisitResult::More(l, s) => { len = l; size = s; } VisitResult::Err => return Err(de::Error::invalid_length(size + 1, &self)), } } Ok(ShortU16(len as u16)) } } impl<'de> Deserialize<'de> for ShortU16 { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { deserializer.deserialize_tuple(3, ShortLenVisitor) } } /// If you don't want to use the ShortVec newtype, you can do ShortVec /// serialization on an ordinary vector with the following field annotation: /// /// #[serde(with = "short_vec")] /// pub fn serialize( elements: &[T], serializer: S, ) -> Result { // Pass a non-zero value to serialize_tuple() so that serde_json will // generate an open bracket. let mut seq = serializer.serialize_tuple(1)?; let len = elements.len(); if len > std::u16::MAX as usize { return Err(ser::Error::custom("length larger than u16")); } let short_len = ShortU16(len as u16); seq.serialize_element(&short_len)?; for element in elements { seq.serialize_element(element)?; } seq.end() } struct ShortVecVisitor { _t: PhantomData, } impl<'de, T> Visitor<'de> for ShortVecVisitor where T: Deserialize<'de>, { type Value = Vec; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { formatter.write_str("a Vec with a multi-byte length") } fn visit_seq(self, mut seq: A) -> Result, A::Error> where A: SeqAccess<'de>, { let short_len: ShortU16 = seq .next_element()? .ok_or_else(|| de::Error::invalid_length(0, &self))?; let len = short_len.0 as usize; let mut result = Vec::with_capacity(len); for i in 0..len { let elem = seq .next_element()? .ok_or_else(|| de::Error::invalid_length(i, &self))?; result.push(elem); } Ok(result) } } /// If you don't want to use the ShortVec newtype, you can do ShortVec /// deserialization on an ordinary vector with the following field annotation: /// /// #[serde(with = "short_vec")] /// pub fn deserialize<'de, D, T>(deserializer: D) -> Result, D::Error> where D: Deserializer<'de>, T: Deserialize<'de>, { let visitor = ShortVecVisitor { _t: PhantomData }; deserializer.deserialize_tuple(std::usize::MAX, visitor) } pub struct ShortVec(pub Vec); impl Serialize for ShortVec { fn serialize(&self, serializer: S) -> Result where S: Serializer, { serialize(&self.0, serializer) } } impl<'de, T: Deserialize<'de>> Deserialize<'de> for ShortVec { fn deserialize(deserializer: D) -> Result, D::Error> where D: Deserializer<'de>, { deserialize(deserializer).map(ShortVec) } } /// Return the decoded value and how many bytes it consumed. pub fn decode_len(bytes: &[u8]) -> Result<(usize, usize), ()> { let mut len = 0; let mut size = 0; for byte in bytes.iter() { match visit_byte(*byte, len, size) { VisitResult::More(l, s) => { len = l; size = s; } VisitResult::Done(len, size) => return Ok((len, size)), VisitResult::Err => return Err(()), } } Err(()) } #[cfg(test)] mod tests { use super::*; use assert_matches::assert_matches; use bincode::{deserialize, serialize}; /// Return the serialized length. fn encode_len(len: u16) -> Vec { bincode::serialize(&ShortU16(len)).unwrap() } fn assert_len_encoding(len: u16, bytes: &[u8]) { assert_eq!(encode_len(len), bytes, "unexpected usize encoding"); assert_eq!( decode_len(bytes).unwrap(), (len as usize, bytes.len()), "unexpected usize decoding" ); } #[test] fn test_short_vec_encode_len() { assert_len_encoding(0x0, &[0x0]); assert_len_encoding(0x7f, &[0x7f]); assert_len_encoding(0x80, &[0x80, 0x01]); assert_len_encoding(0xff, &[0xff, 0x01]); assert_len_encoding(0x100, &[0x80, 0x02]); assert_len_encoding(0x7fff, &[0xff, 0xff, 0x01]); assert_len_encoding(0xffff, &[0xff, 0xff, 0x03]); } #[test] #[should_panic] fn test_short_vec_decode_zero_len() { decode_len(&[]).unwrap(); } #[test] fn test_short_vec_u8() { let vec = ShortVec(vec![4u8; 32]); let bytes = serialize(&vec).unwrap(); assert_eq!(bytes.len(), vec.0.len() + 1); let vec1: ShortVec = deserialize(&bytes).unwrap(); assert_eq!(vec.0, vec1.0); } #[test] fn test_short_vec_u8_too_long() { let vec = ShortVec(vec![4u8; std::u16::MAX as usize]); assert_matches!(serialize(&vec), Ok(_)); let vec = ShortVec(vec![4u8; std::u16::MAX as usize + 1]); assert_matches!(serialize(&vec), Err(_)); } #[test] fn test_short_vec_json() { let vec = ShortVec(vec![0, 1, 2]); let s = serde_json::to_string(&vec).unwrap(); assert_eq!(s, "[[3],0,1,2]"); } #[test] fn test_decode_len_aliased_values() { let one1 = [0x01]; let one2 = [0x81, 0x00]; let one3 = [0x81, 0x80, 0x00]; let one4 = [0x81, 0x80, 0x80, 0x00]; assert_eq!(decode_len(&one1).unwrap(), (1, 1)); assert_eq!(decode_len(&one2).unwrap(), (1, 2)); assert_eq!(decode_len(&one3).unwrap(), (1, 3)); assert!(decode_len(&one4).is_err()); } }