From 4b8555711e1d8da4b18bcef13fc15f70619da44e Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Fri, 18 Nov 2022 14:15:30 +0100 Subject: [PATCH] Add serde serializer --- Cargo.toml | 3 +- src/dec/error.rs | 34 ++++ src/dec/mod.rs | 23 +-- src/enc/error.rs | 21 +++ src/enc/mod.rs | 75 ++++++-- src/lib.rs | 3 + src/serde/error.rs | 51 +++++ src/serde/mod.rs | 7 + src/serde/ser.rs | 450 +++++++++++++++++++++++++++++++++++++++++++++ 9 files changed, 630 insertions(+), 37 deletions(-) create mode 100644 src/dec/error.rs create mode 100644 src/enc/error.rs create mode 100644 src/serde/error.rs create mode 100644 src/serde/mod.rs create mode 100644 src/serde/ser.rs diff --git a/Cargo.toml b/Cargo.toml index 5087561..13d1382 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ base64 = "0.13" blake2 = { version = "0.10", optional = true } rand = "0.7" ed25519-dalek = { version = "1.0", optional = true } +serde = { version = "1.0", optional = true, features = ["derive"] } [features] -default = [ "blake2", "ed25519-dalek" ] +default = [ "blake2", "ed25519-dalek", "serde" ] diff --git a/src/dec/error.rs b/src/dec/error.rs new file mode 100644 index 0000000..e168016 --- /dev/null +++ b/src/dec/error.rs @@ -0,0 +1,34 @@ +use std::fmt; + +/// The type of errors returned by helper functions on `Term` +#[derive(Debug, Clone)] +pub enum TypeError { + /// The term could not be decoded in the given type + WrongType(&'static str), + /// The term is not an array of the requested length + WrongLength(usize, usize), + /// The dictionnary is missing a key + MissingKey(String), + /// The dictionnary contains an invalid key + UnexpectedKey(String), + /// The underlying raw string contains garbage (should not happen in theory) + Garbage, +} + +impl From for TypeError { + fn from(_x: std::str::Utf8Error) -> TypeError { + TypeError::Garbage + } +} + +impl std::fmt::Display for TypeError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + TypeError::WrongType(t) => write!(f, "Not a {}", t), + TypeError::WrongLength(n, m) => write!(f, "Expected {} items, got {}", m, n), + TypeError::MissingKey(k) => write!(f, "Missing key `{}` in dict", k), + TypeError::UnexpectedKey(k) => write!(f, "Spurrious/unexpected key `{}` in dict", k), + TypeError::Garbage => write!(f, "Garbage in underlying data"), + } + } +} diff --git a/src/dec/mod.rs b/src/dec/mod.rs index 4d9e6f7..dca9fec 100644 --- a/src/dec/mod.rs +++ b/src/dec/mod.rs @@ -1,11 +1,13 @@ //! Functions to decode nettext and helpers to map it to data structures mod decode; +mod error; use std::collections::HashMap; use crate::crypto; pub use decode::*; +pub use error::TypeError; /// A parsed nettext term, with many helpers for destructuring /// @@ -63,27 +65,6 @@ impl<'a> From> for Term<'a, 'static> { // ---- PUBLIC IMPLS ---- -/// The type of errors returned by helper functions on `Term` -#[derive(Debug, Clone)] -pub enum TypeError { - /// The term could not be decoded in the given type - WrongType(&'static str), - /// The term is not an array of the requested length - WrongLength(usize, usize), - /// The dictionnary is missing a key - MissingKey(String), - /// The dictionnary contains an invalid key - UnexpectedKey(String), - /// The underlying raw string contains garbage (should not happen in theory) - Garbage, -} - -impl From for TypeError { - fn from(_x: std::str::Utf8Error) -> TypeError { - TypeError::Garbage - } -} - impl<'a, 'b> Term<'a, 'b> { // ---- STRUCTURAL MAPPINGS ---- diff --git a/src/enc/error.rs b/src/enc/error.rs new file mode 100644 index 0000000..78f9a15 --- /dev/null +++ b/src/enc/error.rs @@ -0,0 +1,21 @@ +use std::fmt; + +/// An error that happenned when creating a nettext encoder term +#[derive(Debug)] +pub enum Error { + InvalidCharacter(u8), + InvalidRaw, + NotADictionnary, + ListInList, +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Error::InvalidCharacter(c) => write!(f, "Invalid character '{}'", *c as char), + Error::InvalidRaw => write!(f, "Invalid RAW nettext litteral"), + Error::NotADictionnary => write!(f, "Tried to insert into a term that isn't a dictionnary"), + Error::ListInList => write!(f, "Refusing to build nested lists with list(), use either list_flatten() or list_nested()"), + } + } +} diff --git a/src/enc/mod.rs b/src/enc/mod.rs index dd963b9..37e300a 100644 --- a/src/enc/mod.rs +++ b/src/enc/mod.rs @@ -17,30 +17,26 @@ //! ]).unwrap().encode(); //! ``` +mod error; + use std::collections::HashMap; +use std::borrow::{Cow, Borrow}; use crate::dec::{self, decode}; use crate::{is_string_char, is_whitespace}; +pub use error::Error; + /// A term meant to be encoded into a nettext representation pub struct Term<'a>(T<'a>); enum T<'a> { Str(&'a [u8]), OwnedStr(Vec), - Dict(HashMap<&'a [u8], T<'a>>), + Dict(HashMap, T<'a>>), List(Vec>), } -/// An error that happenned when creating a nettext encoder term -#[derive(Debug)] -pub enum Error { - InvalidCharacter(u8), - InvalidRaw, - NotADictionnary, - ListInList, -} - pub type Result<'a> = std::result::Result, Error>; // ---- helpers to transform datatypes into encoder terms ---- @@ -74,6 +70,16 @@ pub fn string(s: &str) -> Result<'_> { Ok(Term(T::Str(s.as_bytes()))) } +/// Same as `string` but takes an owned String +pub fn string_owned(s: String) -> Result<'static> { + for c in s.as_bytes().iter() { + if !(is_string_char(*c) || is_whitespace(*c)) { + return Err(Error::InvalidCharacter(*c)); + } + } + Ok(Term(T::OwnedStr(s.into_bytes()))) +} + /// Include a raw nettext value /// /// ``` @@ -88,6 +94,13 @@ pub fn raw(bytes: &[u8]) -> Result<'_> { Ok(Term(T::Str(bytes))) } +pub(crate) fn safe_raw(bytes: &[u8]) -> Term<'_> { + Term(T::Str(bytes)) +} +pub(crate) fn safe_raw_owned(bytes: Vec) -> Term<'static> { + Term(T::OwnedStr(bytes)) +} + /// Term corresponding to a list of terms /// /// ``` @@ -109,6 +122,30 @@ pub fn list<'a, I: IntoIterator>>(terms: I) -> Result<'a> { Ok(Term(T::List(tmp))) } +/// Term corresponding to a list of terms. Sub-lists are flattenned. +pub fn list_flatten<'a, I: IntoIterator>>(terms: I) -> Result<'a> { + let mut tmp = Vec::with_capacity(8); + for t in terms { + match t.0 { + T::List(t) => tmp.extend(t), + x => tmp.push(x), + } + } + Ok(Term(T::List(tmp))) +} + +/// Term corresponding to a list of terms. Sub-lists are represented as NESTED: `{ . = sub list items }`. +pub fn list_nested<'a, I: IntoIterator>>(terms: I) -> Result<'a> { + let mut tmp = Vec::with_capacity(8); + for t in terms { + match t.0 { + T::List(t) => tmp.push(Term(T::List(t)).nested().0), + x => tmp.push(x), + } + } + Ok(Term(T::List(tmp))) +} + /// Term corresponding to a dictionnary of items /// /// ``` @@ -122,7 +159,15 @@ pub fn list<'a, I: IntoIterator>>(terms: I) -> Result<'a> { pub fn dict<'a, I: IntoIterator)>>(pairs: I) -> Term<'a> { let mut tmp = HashMap::new(); for (k, v) in pairs { - tmp.insert(k.as_bytes(), v.0); + tmp.insert(Cow::from(k.as_bytes()), v.0); + } + Term(T::Dict(tmp)) +} + +pub(crate) fn dict_owned_u8<'a, I: IntoIterator, Term<'a>)>>(pairs: I) -> Term<'a> { + let mut tmp = HashMap::new(); + for (k, v) in pairs { + tmp.insert(Cow::from(k), v.0); } Term(T::Dict(tmp)) } @@ -178,7 +223,7 @@ impl<'a> Term<'a> { pub fn insert(self, k: &'a str, v: Term<'a>) -> Result<'a> { match self.0 { T::Dict(mut d) => { - d.insert(k.as_bytes(), v.0); + d.insert(Cow::from(k.as_bytes()), v.0); Ok(Term(T::Dict(d))) } _ => Err(Error::NotADictionnary), @@ -223,7 +268,7 @@ impl<'a> T<'a> { } else if d.len() == 1 { buf.extend_from_slice(b"{ "); let (k, v) = d.into_iter().next().unwrap(); - buf.extend_from_slice(k); + buf.extend_from_slice(k.borrow()); buf.extend_from_slice(b" = "); v.encode_aux(buf, indent + 2, false); buf.extend_from_slice(b" }"); @@ -233,11 +278,11 @@ impl<'a> T<'a> { let mut keys = d.keys().cloned().collect::>(); keys.sort(); for k in keys { - let v = d.remove(k).unwrap(); + let v = d.remove(&k).unwrap(); for _ in 0..indent2 { buf.push(b' '); } - buf.extend_from_slice(k); + buf.extend_from_slice(k.borrow()); buf.extend_from_slice(b" = "); v.encode_aux(buf, indent2, false); buf.extend_from_slice(b",\n"); diff --git a/src/lib.rs b/src/lib.rs index cbf1acc..ecd5161 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -85,6 +85,9 @@ pub mod crypto; pub mod dec; pub mod enc; +#[cfg(feature = "serde")] +pub mod serde; + // ---- syntactic elements of the data format ---- pub(crate) const DICT_OPEN: u8 = b'{'; diff --git a/src/serde/error.rs b/src/serde/error.rs new file mode 100644 index 0000000..dec2609 --- /dev/null +++ b/src/serde/error.rs @@ -0,0 +1,51 @@ +use std; +use std::fmt::{self, Display}; + +use serde::{de, ser}; + +use crate::{dec, enc}; + +pub type Result = std::result::Result; + +#[derive(Debug)] +pub enum Error { + Message(String), + Encode(enc::Error), + Type(dec::TypeError), +} + +impl From for Error { + fn from(e: enc::Error) -> Self { + Error::Encode(e) + } +} + +impl From for Error { + fn from(e: dec::TypeError) -> Self { + Error::Type(e) + } +} + +impl ser::Error for Error { + fn custom(msg: T) -> Self { + Error::Message(msg.to_string()) + } +} + +impl de::Error for Error { + fn custom(msg: T) -> Self { + Error::Message(msg.to_string()) + } +} + +impl Display for Error { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + match self { + Error::Message(msg) => formatter.write_str(msg), + Error::Encode(err) => write!(formatter, "Encode: {}", err), + Error::Type(err) => write!(formatter, "Type: {}", err), + } + } +} + +impl std::error::Error for Error {} diff --git a/src/serde/mod.rs b/src/serde/mod.rs new file mode 100644 index 0000000..8062f43 --- /dev/null +++ b/src/serde/mod.rs @@ -0,0 +1,7 @@ +//mod de; +mod error; +mod ser; + +//pub use de::{from_str, Deserializer}; +pub use error::{Error, Result}; +pub use ser::{to_bytes, Serializer}; diff --git a/src/serde/ser.rs b/src/serde/ser.rs new file mode 100644 index 0000000..590c5ec --- /dev/null +++ b/src/serde/ser.rs @@ -0,0 +1,450 @@ +use serde::{ser, Serialize}; + +use crate::enc::*; +use crate::serde::error::{Error, Result}; +use serde::ser::Error as SerError; + +pub struct Serializer; + +pub fn to_bytes(value: &T) -> Result> +where + T: Serialize, +{ + Ok(value.serialize(&mut Serializer)?.encode()) +} + +impl<'a> ser::Serializer for &'a mut Serializer { + type Ok = Term<'static>; + + type Error = Error; + + type SerializeSeq = SeqSerializer; + type SerializeTuple = SeqSerializer; + type SerializeTupleStruct = SeqSerializer; + type SerializeTupleVariant = SeqSerializer; + type SerializeMap = MapSerializer; + type SerializeStruct = StructSerializer; + type SerializeStructVariant = StructVariantSerializer; + + fn serialize_bool(self, v: bool) -> Result { + Ok(if v { + safe_raw(b"true") + } else { + safe_raw(b"false") + }) + } + + fn serialize_i8(self, v: i8) -> Result { + self.serialize_i64(i64::from(v)) + } + + fn serialize_i16(self, v: i16) -> Result { + self.serialize_i64(i64::from(v)) + } + + fn serialize_i32(self, v: i32) -> Result { + self.serialize_i64(i64::from(v)) + } + + fn serialize_i64(self, v: i64) -> Result { + Ok(safe_raw_owned(v.to_string().into_bytes())) + } + + fn serialize_u8(self, v: u8) -> Result { + self.serialize_u64(u64::from(v)) + } + + fn serialize_u16(self, v: u16) -> Result { + self.serialize_u64(u64::from(v)) + } + + fn serialize_u32(self, v: u32) -> Result { + self.serialize_u64(u64::from(v)) + } + + fn serialize_u64(self, v: u64) -> Result { + Ok(safe_raw_owned(v.to_string().into_bytes())) + } + + fn serialize_f32(self, v: f32) -> Result { + self.serialize_f64(f64::from(v)) + } + + fn serialize_f64(self, v: f64) -> Result { + Ok(string_owned(v.to_string())?) + } + + fn serialize_char(self, v: char) -> Result { + self.serialize_str(&v.to_string()) + } + + fn serialize_str(self, v: &str) -> Result { + Ok(bytes(v.as_bytes())) + } + + fn serialize_bytes(self, v: &[u8]) -> Result { + Ok(bytes(v)) + } + + fn serialize_none(self) -> Result { + Ok(dict([])) + } + + fn serialize_some(self, value: &T) -> Result + where + T: ?Sized + Serialize, + { + value.serialize(self) + } + + fn serialize_unit(self) -> Result { + Ok(dict([])) + } + + fn serialize_unit_struct(self, name: &'static str) -> Result { + Ok(string(name)?) + } + + fn serialize_unit_variant( + self, + name: &'static str, + _variant_index: u32, + variant: &'static str, + ) -> Result { + Ok(string_owned(format!("{}.{}", name, variant))?) + } + + fn serialize_newtype_struct(self, name: &'static str, value: &T) -> Result + where + T: ?Sized + Serialize, + { + Ok(list_flatten([string(name)?, value.serialize(self)?])?) + } + + fn serialize_newtype_variant( + self, + name: &'static str, + _variant_index: u32, + variant: &'static str, + value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + Ok(list_flatten([ + string_owned(format!("{}.{}", name, variant))?, + value.serialize(self)?, + ])?) + } + + fn serialize_seq(self, _len: Option) -> Result { + Ok(SeqSerializer { items: vec![] }) + } + + fn serialize_tuple(self, len: usize) -> Result { + Ok(SeqSerializer { items: Vec::with_capacity(len) }) + } + + fn serialize_tuple_struct( + self, + name: &'static str, + len: usize, + ) -> Result { + let mut items = Vec::with_capacity(len + 1); + items.push(string(name)?); + Ok(SeqSerializer { + items, + }) + } + + fn serialize_tuple_variant( + self, + name: &'static str, + _variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result { + let mut items = Vec::with_capacity(len + 1); + items.push(string_owned(format!("{}.{}", name, variant))?); + Ok(SeqSerializer { + items, + }) + } + + fn serialize_map(self, _len: Option) -> Result { + Ok(MapSerializer { + next: None, + fields: vec![], + }) + } + + fn serialize_struct(self, name: &'static str, len: usize) -> Result { + Ok(StructSerializer { + name, + fields: Vec::with_capacity(len), + }) + } + + fn serialize_struct_variant( + self, + name: &'static str, + _variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result { + Ok(StructVariantSerializer { + name, + variant, + fields: Vec::with_capacity(len), + }) + } +} + +// -- sub-serializers -- + +pub struct SeqSerializer { + items: Vec>, +} +impl<'a> ser::SerializeSeq for SeqSerializer { + type Ok = Term<'static>; + type Error = Error; + + fn serialize_element(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + self.items.push(value.serialize(&mut Serializer)?); + Ok(()) + } + + fn end(self) -> Result { + Ok(list_nested(self.items.into_iter())?) + } +} + +impl<'a> ser::SerializeTuple for SeqSerializer { + type Ok = Term<'static>; + type Error = Error; + + fn serialize_element(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + self.items.push(value.serialize(&mut Serializer)?); + Ok(()) + } + + fn end(self) -> Result { + Ok(list_nested(self.items.into_iter())?) + } +} + +impl<'a> ser::SerializeTupleStruct for SeqSerializer { + type Ok = Term<'static>; + type Error = Error; + + fn serialize_field(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + self.items.push(value.serialize(&mut Serializer)?); + Ok(()) + } + + fn end(self) -> Result { + Ok(list_nested(self.items.into_iter())?) + } +} + +impl<'a> ser::SerializeTupleVariant for SeqSerializer { + type Ok = Term<'static>; + type Error = Error; + + fn serialize_field(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + self.items.push(value.serialize(&mut Serializer)?); + Ok(()) + } + + fn end(self) -> Result { + Ok(list_nested(self.items.into_iter())?) + } +} + +pub struct MapSerializer { + next: Option>, + fields: Vec<(Vec, Term<'static>)>, +} + +impl<'a> ser::SerializeMap for MapSerializer { + type Ok = Term<'static>; + type Error = Error; + + fn serialize_key(&mut self, key: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + self.next = Some(key.serialize(&mut Serializer)?.encode()); + Ok(()) + } + + fn serialize_value(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + self.fields.push(( + self.next.take().ok_or(Self::Error::custom("no key"))?, + value.serialize(&mut Serializer)?, + )); + Ok(()) + } + + fn end(self) -> Result> { + Ok(dict_owned_u8(self.fields.into_iter())) + } +} + +pub struct StructSerializer { + name: &'static str, + fields: Vec<(&'static str, Term<'static>)>, +} + +impl ser::SerializeStruct for StructSerializer { + type Ok = Term<'static>; + type Error = Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + self.fields.push((key, value.serialize(&mut Serializer)?)); + Ok(()) + } + + fn end(self) -> Result> { + Ok(list([string(self.name)?, dict(self.fields.into_iter())])?) + } +} + +pub struct StructVariantSerializer { + name: &'static str, + variant: &'static str, + fields: Vec<(&'static str, Term<'static>)>, +} + +impl<'a> ser::SerializeStructVariant for StructVariantSerializer { + type Ok = Term<'static>; + type Error = Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + self.fields.push((key, value.serialize(&mut Serializer)?)); + Ok(()) + } + + fn end(self) -> Result> { + Ok(list([ + string_owned(format!("{}.{}", self.name, self.variant))?, + dict(self.fields.into_iter()), + ])?) + } +} + +//////////////////////////////////////////////////////////////////////////////// + +#[cfg(test)] +mod tests { + use super::*; + use serde::Serialize; + use std::collections::HashMap; + + #[test] + fn test_struct() { + #[derive(Serialize)] + struct Test { + int: u32, + seq: Vec<&'static str>, + } + + let test = Test { + int: 1, + seq: vec!["a", "b"], + }; + let expected = br#"Test { + int = 1, + seq = YQ Yg, + }"#; + assert_eq!(&to_bytes(&test).unwrap(), expected); + } + + #[test] + fn test_enum() { + #[derive(Serialize)] + enum E { + Unit, + Newtype(u32), + Tuple(u32, u32), + Struct { a: u32 }, + } + + let u = E::Unit; + let expected = br#"E.Unit"#; + assert_eq!(&to_bytes(&u).unwrap(), expected); + + let n = E::Newtype(1); + let expected = br#"E.Newtype 1"#; + assert_eq!(&to_bytes(&n).unwrap(), expected); + + let t = E::Tuple(1, 2); + let expected = br#"E.Tuple 1 2"#; + assert_eq!(&to_bytes(&t).unwrap(), expected); + + let s = E::Struct { a: 1 }; + let expected = br#"E.Struct { a = 1 }"#; + assert_eq!(&to_bytes(&s).unwrap(), expected); + } + + #[test] + fn test_seq() { + let u = (1, 2, 3, 4); + let expected = br#"1 2 3 4"#; + assert_eq!(&to_bytes(&u).unwrap(), expected); + + let n = (1, 2, (2, 3, 4), 5, 6); + let expected = br#"1 2 { . = 2 3 4 } 5 6"#; + assert_eq!(&to_bytes(&n).unwrap(), expected); + + let t = [1, 2, 3, 4]; + let expected = br#"1 2 3 4"#; + assert_eq!(&to_bytes(&t).unwrap(), expected); + + let s = [[1, 2], [2, 3], [3, 4]]; + let expected = br#"{ . = 1 2 } { . = 2 3 } { . = 3 4 }"#; + assert_eq!(&to_bytes(&s).unwrap(), expected); + } + + #[test] + fn test_dict() { + let mut d = HashMap::new(); + d.insert("hello", "world"); + d.insert("dont", "panic"); + let expected = br#"{ + ZG9udA = cGFuaWM, + aGVsbG8 = d29ybGQ, +}"#; + assert_eq!(&to_bytes(&d).unwrap(), expected); + + let mut d = HashMap::new(); + d.insert(12, vec![42, 125]); + d.insert(33, vec![19, 22, 21]); + let expected = br#"{ + 12 = 42 125, + 33 = 19 22 21, +}"#; + assert_eq!(&to_bytes(&d).unwrap(), expected); + } +}