From 1bcefaa1ae1fa063b1187a51b41c765cd24528ed Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Fri, 18 Nov 2022 16:26:44 +0100 Subject: [PATCH] Deserializer seems to work a bit --- src/dec/decode.rs | 34 +--- src/dec/error.rs | 40 ++++ src/dec/mod.rs | 6 +- src/enc/mod.rs | 2 +- src/serde/de.rs | 493 +++++++++++++++++++++++++++++++++++++++++++++ src/serde/error.rs | 32 +++ src/serde/mod.rs | 6 +- src/serde/ser.rs | 39 ++-- 8 files changed, 595 insertions(+), 57 deletions(-) create mode 100644 src/serde/de.rs diff --git a/src/dec/decode.rs b/src/dec/decode.rs index 33aadee..73f6b26 100644 --- a/src/dec/decode.rs +++ b/src/dec/decode.rs @@ -8,43 +8,11 @@ use nom::{ IResult, InputLength, }; -use crate::dec::{debug, AnyTerm, NonListTerm, Term}; +use crate::dec::{AnyTerm, DecodeError, NonListTerm, Term}; use crate::{is_string_char, is_whitespace, DICT_ASSIGN, DICT_CLOSE, DICT_DELIM, DICT_OPEN}; // ---- -/// The error kind returned by the `decode` function. -#[derive(Eq, PartialEq)] -pub enum DecodeError<'a> { - /// Indicates that there is trailing garbage at the end of the decoded string - Garbage(&'a [u8]), - /// Indicates that the entered string does not represent a complete nettext term - IncompleteInput, - /// Indicates a syntax error in the decoded term - NomError(&'a [u8], nom::error::ErrorKind), -} - -impl<'a> std::fmt::Debug for DecodeError<'a> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { - match self { - DecodeError::Garbage(g) => write!(f, "Garbage: `{}`", debug(g)), - DecodeError::IncompleteInput => write!(f, "Incomplete input"), - DecodeError::NomError(s, e) => write!(f, "Nom: {:?}, at: `{}`", e, debug(s)), - } - } -} - -impl<'a> From>> for DecodeError<'a> { - fn from(e: nom::Err>) -> DecodeError<'a> { - match e { - nom::Err::Incomplete(_) => DecodeError::IncompleteInput, - nom::Err::Error(e) | nom::Err::Failure(e) => DecodeError::NomError(e.input, e.code), - } - } -} - -// ---- - /// Decodes a nettext string into the term it represents. pub fn decode(input: &[u8]) -> std::result::Result, DecodeError<'_>> { let (rest, term) = decode_term(input)?; diff --git a/src/dec/error.rs b/src/dec/error.rs index e168016..9f7e224 100644 --- a/src/dec/error.rs +++ b/src/dec/error.rs @@ -1,5 +1,7 @@ use std::fmt; +use crate::dec::debug; + /// The type of errors returned by helper functions on `Term` #[derive(Debug, Clone)] pub enum TypeError { @@ -32,3 +34,41 @@ impl std::fmt::Display for TypeError { } } } + +// ---- + +/// The error kind returned by the `decode` function. +#[derive(Eq, PartialEq)] +pub enum DecodeError<'a> { + /// Indicates that there is trailing garbage at the end of the decoded string + Garbage(&'a [u8]), + /// Indicates that the entered string does not represent a complete nettext term + IncompleteInput, + /// Indicates a syntax error in the decoded term + NomError(&'a [u8], nom::error::ErrorKind), +} + +impl<'a> std::fmt::Display for DecodeError<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { + match self { + DecodeError::Garbage(g) => write!(f, "Garbage: `{}`", debug(g)), + DecodeError::IncompleteInput => write!(f, "Incomplete input"), + DecodeError::NomError(s, e) => write!(f, "Nom: {:?}, at: `{}`", e, debug(s)), + } + } +} + +impl<'a> std::fmt::Debug for DecodeError<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { + std::fmt::Display::fmt(self, f) + } +} + +impl<'a> From>> for DecodeError<'a> { + fn from(e: nom::Err>) -> DecodeError<'a> { + match e { + nom::Err::Incomplete(_) => DecodeError::IncompleteInput, + nom::Err::Error(e) | nom::Err::Failure(e) => DecodeError::NomError(e.input, e.code), + } + } +} diff --git a/src/dec/mod.rs b/src/dec/mod.rs index dca9fec..f872617 100644 --- a/src/dec/mod.rs +++ b/src/dec/mod.rs @@ -7,7 +7,7 @@ use std::collections::HashMap; use crate::crypto; pub use decode::*; -pub use error::TypeError; +pub use error::*; /// A parsed nettext term, with many helpers for destructuring /// @@ -17,7 +17,7 @@ pub use error::TypeError; /// is returned by one of the helper functions, or 'static when first returned from /// `decode()` #[derive(Eq, PartialEq, Debug)] -pub struct Term<'a, 'b>(AnyTerm<'a, 'b>); +pub struct Term<'a, 'b>(pub(crate) AnyTerm<'a, 'b>); #[derive(Eq, PartialEq, Clone)] pub(crate) enum AnyTerm<'a, 'b> { @@ -501,7 +501,7 @@ impl<'a, 'b> AnyTerm<'a, 'b> { } } - fn mkref(&self) -> AnyTerm<'a, '_> { + pub(crate) fn mkref(&self) -> AnyTerm<'a, '_> { match &self { AnyTerm::Str(s) => AnyTerm::Str(s), AnyTerm::Dict(r, d) => AnyTerm::DictRef(r, d), diff --git a/src/enc/mod.rs b/src/enc/mod.rs index 37e300a..73ab16c 100644 --- a/src/enc/mod.rs +++ b/src/enc/mod.rs @@ -19,8 +19,8 @@ mod error; +use std::borrow::{Borrow, Cow}; use std::collections::HashMap; -use std::borrow::{Cow, Borrow}; use crate::dec::{self, decode}; use crate::{is_string_char, is_whitespace}; diff --git a/src/serde/de.rs b/src/serde/de.rs new file mode 100644 index 0000000..cf4fa4d --- /dev/null +++ b/src/serde/de.rs @@ -0,0 +1,493 @@ +use std::collections::{hash_map, HashMap}; + +use serde::de::{ + self, DeserializeSeed, EnumAccess, Error as DeError, MapAccess, SeqAccess, VariantAccess, + Visitor, +}; +use serde::Deserialize; + +use crate::dec::debug as fmtdebug; +use crate::dec::*; +use crate::serde::error::{Error, Result}; + +pub struct Deserializer<'de, 'a>(Term<'de, 'a>); + +impl<'de, 'a> Deserializer<'de, 'a> { + pub fn from_term(input: &'a Term<'de, 'a>) -> Deserializer<'de, 'a> { + if let Ok(nested) = input.nested() { + Deserializer(nested) + } else { + Deserializer(Term(input.0.mkref())) + } + } +} + +pub fn from_bytes<'a, T>(s: &'a [u8]) -> Result +where + T: Deserialize<'a>, +{ + let term = decode(s)?; + let mut deserializer = Deserializer::from_term(&term); + Ok(T::deserialize(&mut deserializer)?) +} + +// ---- + +impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de, 'a> { + type Error = Error; + + fn deserialize_any(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + Err(Error::custom("derialize_any not supported")) + } + + fn deserialize_bool(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let value = match self.0.string()? { + "true" => true, + "false" => false, + _ => return Err(Error::custom("Invalid boolean litteral")), + }; + visitor.visit_bool(value) + } + + fn deserialize_i8(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_i8(self.0.string()?.parse()?) + } + + fn deserialize_i16(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_i16(self.0.string()?.parse()?) + } + + fn deserialize_i32(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_i32(self.0.string()?.parse()?) + } + + fn deserialize_i64(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_i64(self.0.string()?.parse()?) + } + + fn deserialize_u8(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_u8(self.0.string()?.parse()?) + } + + fn deserialize_u16(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_u16(self.0.string()?.parse()?) + } + + fn deserialize_u32(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_u32(self.0.string()?.parse()?) + } + + fn deserialize_u64(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_u64(self.0.string()?.parse()?) + } + + fn deserialize_f32(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_f32(self.0.string()?.parse()?) + } + + fn deserialize_f64(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_f64(self.0.string()?.parse()?) + } + + fn deserialize_char(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let s = String::from_utf8(self.0.bytes()?)?; + let mut chars = s.chars(); + let c = chars + .next() + .ok_or(Error::custom("invalid char: empty string"))?; + if chars.next().is_some() { + Err(Error::custom("invalid char: too many chars")) + } else { + visitor.visit_char(c) + } + } + + fn deserialize_str(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let s = String::from_utf8(self.0.bytes()?)?; + visitor.visit_string(s) + } + + fn deserialize_string(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.deserialize_str(visitor) + } + + fn deserialize_bytes(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_byte_buf(self.0.bytes()?) + } + + fn deserialize_byte_buf(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_byte_buf(self.0.bytes()?) + } + + fn deserialize_option(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + if self.0.dict_of::<0, &[u8]>([], false).is_ok() { + visitor.visit_none() + } else { + visitor.visit_some(self) + } + } + + fn deserialize_unit(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + if self.0.dict_of::<0, &[u8]>([], false).is_ok() { + visitor.visit_unit() + } else { + Err(Error::custom(format!( + "Expected unit, got: `{}`", + fmtdebug(self.0.raw()) + ))) + } + } + + fn deserialize_unit_struct(self, name: &'static str, visitor: V) -> Result + where + V: Visitor<'de>, + { + if self.0.string()? == name { + visitor.visit_unit() + } else { + Err(Error::custom(format!( + "Expected {}, got: `{}`", + name, + fmtdebug(self.0.raw()) + ))) + } + } + + fn deserialize_newtype_struct(self, name: &'static str, visitor: V) -> Result + where + V: Visitor<'de>, + { + let [variant, args] = self.0.list_of_first()?; + if variant.string()? == name { + visitor.visit_newtype_struct(&mut Deserializer(args)) + } else { + Err(Error::custom(format!( + "Expected {}, got: `{}`", + name, + fmtdebug(variant.raw()) + ))) + } + } + + fn deserialize_seq(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_seq(&mut Seq(&self.0.list())) + } + + fn deserialize_tuple(self, _len: usize, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_seq(&mut Seq(&self.0.list())) + } + + fn deserialize_tuple_struct( + self, + name: &'static str, + _len: usize, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + let [variant, args] = self.0.list_of_first()?; + if variant.string()? == name { + visitor.visit_seq(&mut Seq(&args.list())) + } else { + Err(Error::custom(format!( + "Expected {}, got: `{}`", + name, + fmtdebug(variant.raw()) + ))) + } + } + + fn deserialize_map(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_map(&mut Dict::from_term(&self.0)?) + } + + fn deserialize_struct( + self, + name: &'static str, + _fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + let [variant, data] = self.0.list_of()?; + if variant.string()? != name { + return Err(Error::custom(format!( + "Expected {}, got: `{}`", + name, + fmtdebug(variant.raw()) + ))); + } + + Deserializer(data).deserialize_map(visitor) + } + + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_enum(Enum(&self.0)) + } + + fn deserialize_identifier(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let s = self.0.string()?; + match s.split_once('.') { + Some((_a, b)) => visitor.visit_borrowed_str(b), + None => visitor.visit_borrowed_str(s), + } + } + + fn deserialize_ignored_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.deserialize_any(visitor) + } +} + +pub struct Seq<'de, 'a>(&'a [Term<'de, 'a>]); + +impl<'de, 'a> SeqAccess<'de> for &'a mut Seq<'de, 'a> { + type Error = Error; + + fn next_element_seed(&mut self, seed: T) -> Result> + where + T: DeserializeSeed<'de>, + { + // Check if there are no more elements. + if self.0.is_empty() { + return Ok(None); + } + let first = Term(self.0[0].0.mkref()); + self.0 = &self.0[1..]; + + seed.deserialize(&mut Deserializer(first)).map(Some) + } +} + +pub struct Dict<'de, 'a>( + &'a HashMap<&'de [u8], AnyTerm<'de, 'a>>, + Option<&'de [u8]>, + hash_map::Keys<'a, &'de [u8], AnyTerm<'de, 'a>>, +); + +impl<'de, 'a> Dict<'de, 'a> { + fn from_term(t: &'a Term<'de, 'a>) -> Result> { + match t.0.mkref() { + AnyTerm::DictRef(_, d) => Ok(Dict(d, None, d.keys())), + _ => Err(Error::custom("expected a DICT")), + } + } +} + +impl<'de, 'a> MapAccess<'de> for &'a mut Dict<'de, 'a> { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: DeserializeSeed<'de>, + { + match self.2.next() { + None => Ok(None), + Some(k) => { + self.1 = Some(k); + seed.deserialize(&mut Deserializer(Term(AnyTerm::Str(k)))) + .map(Some) + } + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: DeserializeSeed<'de>, + { + let k = self.1.ok_or(Error::custom("invald mapaccess order"))?; + let v = self + .0 + .get(k) + .ok_or(Error::custom("invald mapaccess order"))?; + seed.deserialize(&mut Deserializer(Term(v.mkref()))) + } +} + +// ---- + +struct Enum<'de, 'a>(&'a Term<'de, 'a>); + +impl<'de, 'a> EnumAccess<'de> for Enum<'de, 'a> { + type Error = Error; + type Variant = Self; + + fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant)> + where + V: DeserializeSeed<'de>, + { + let value = seed.deserialize(&mut Deserializer::from_term(&self.0.list()[0]))?; + Ok((value, self)) + } +} + +impl<'de, 'a> VariantAccess<'de> for Enum<'de, 'a> { + type Error = Error; + + fn unit_variant(self) -> Result<()> { + if self.0.list().len() > 1 { + Err(Error::custom("Spurrious data in unit variant")) + } else { + Ok(()) + } + } + + fn newtype_variant_seed(self, seed: T) -> Result + where + T: DeserializeSeed<'de>, + { + let [_, rest] = self.0.list_of_first()?; + seed.deserialize(&mut Deserializer::from_term(&rest)) + } + + fn tuple_variant(self, _len: usize, visitor: V) -> Result + where + V: Visitor<'de>, + { + let [_, rest] = self.0.list_of_first()?; + visitor.visit_seq(&mut Seq(&rest.list())) + } + + fn struct_variant(self, _fields: &'static [&'static str], visitor: V) -> Result + where + V: Visitor<'de>, + { + let [_, rest] = self.0.list_of_first()?; + visitor.visit_map(&mut Dict::from_term(&rest)?) + } +} + +//////////////////////////////////////////////////////////////////////////////// + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_struct() { + #[derive(Deserialize, PartialEq, Debug)] + struct Test { + int: u32, + seq: Vec, + } + + let nt = br#"Test { int = 1, seq = YQ Yg }"#; + let expected = Test { + int: 1, + seq: vec!["a".to_owned(), "b".to_owned()], + }; + assert_eq!(expected, from_bytes(nt).unwrap()); + } + + #[test] + fn test_enum() { + #[derive(Deserialize, PartialEq, Debug)] + enum E { + Unit, + Newtype(u32), + Tuple(u32, u32), + Struct { a: u32 }, + } + + let nt = br#"E.Unit"#; + let expected = E::Unit; + assert_eq!(expected, from_bytes(nt).unwrap()); + eprintln!("UNIT OK"); + + let nt = br#"E.Newtype 1"#; + let expected = E::Newtype(1); + assert_eq!(expected, from_bytes(nt).unwrap()); + eprintln!("NEWTYPE OK"); + + let nt = br#"E.Tuple 1 2"#; + let expected = E::Tuple(1, 2); + assert_eq!(expected, from_bytes(nt).unwrap()); + eprintln!("TUPLE OK"); + + let nt = br#"E.Struct { a = 1 }"#; + let expected = E::Struct { a: 1 }; + assert_eq!(expected, from_bytes(nt).unwrap()); + eprintln!("STRUCT OK"); + } +} diff --git a/src/serde/error.rs b/src/serde/error.rs index dec2609..e10bacf 100644 --- a/src/serde/error.rs +++ b/src/serde/error.rs @@ -11,7 +11,11 @@ pub type Result = std::result::Result; pub enum Error { Message(String), Encode(enc::Error), + Decode(String), Type(dec::TypeError), + ParseInt(std::num::ParseIntError), + ParseFloat(std::num::ParseFloatError), + Utf8(std::string::FromUtf8Error), } impl From for Error { @@ -20,6 +24,12 @@ impl From for Error { } } +impl<'a> From> for Error { + fn from(e: dec::DecodeError) -> Self { + Error::Decode(e.to_string()) + } +} + impl From for Error { fn from(e: dec::TypeError) -> Self { Error::Type(e) @@ -32,6 +42,24 @@ impl ser::Error for Error { } } +impl From for Error { + fn from(x: std::num::ParseIntError) -> Error { + Error::ParseInt(x) + } +} + +impl From for Error { + fn from(x: std::num::ParseFloatError) -> Error { + Error::ParseFloat(x) + } +} + +impl From for Error { + fn from(x: std::string::FromUtf8Error) -> Error { + Error::Utf8(x) + } +} + impl de::Error for Error { fn custom(msg: T) -> Self { Error::Message(msg.to_string()) @@ -43,7 +71,11 @@ impl Display for Error { match self { Error::Message(msg) => formatter.write_str(msg), Error::Encode(err) => write!(formatter, "Encode: {}", err), + Error::Decode(err) => write!(formatter, "Decode: {}", err), Error::Type(err) => write!(formatter, "Type: {}", err), + Error::ParseInt(err) => write!(formatter, "Parse (int): {}", err), + Error::ParseFloat(err) => write!(formatter, "Parse (float): {}", err), + Error::Utf8(err) => write!(formatter, "Invalid UTF-8 byte sequnence: {}", err), } } } diff --git a/src/serde/mod.rs b/src/serde/mod.rs index 8062f43..74f4d35 100644 --- a/src/serde/mod.rs +++ b/src/serde/mod.rs @@ -1,7 +1,7 @@ -//mod de; +mod de; mod error; mod ser; -//pub use de::{from_str, Deserializer}; +pub use de::{from_bytes, Deserializer}; pub use error::{Error, Result}; -pub use ser::{to_bytes, Serializer}; +pub use ser::{to_bytes, to_term, Serializer}; diff --git a/src/serde/ser.rs b/src/serde/ser.rs index 590c5ec..9f51d9d 100644 --- a/src/serde/ser.rs +++ b/src/serde/ser.rs @@ -6,6 +6,13 @@ use serde::ser::Error as SerError; pub struct Serializer; +pub fn to_term(value: &T) -> Result> +where + T: Serialize, +{ + Ok(value.serialize(&mut Serializer)?) +} + pub fn to_bytes(value: &T) -> Result> where T: Serialize, @@ -94,7 +101,7 @@ impl<'a> ser::Serializer for &'a mut Serializer { where T: ?Sized + Serialize, { - value.serialize(self) + Ok(value.serialize(self)?.nested()) } fn serialize_unit(self) -> Result { @@ -142,7 +149,9 @@ impl<'a> ser::Serializer for &'a mut Serializer { } fn serialize_tuple(self, len: usize) -> Result { - Ok(SeqSerializer { items: Vec::with_capacity(len) }) + Ok(SeqSerializer { + items: Vec::with_capacity(len), + }) } fn serialize_tuple_struct( @@ -150,11 +159,9 @@ impl<'a> ser::Serializer for &'a mut Serializer { name: &'static str, len: usize, ) -> Result { - let mut items = Vec::with_capacity(len + 1); - items.push(string(name)?); - Ok(SeqSerializer { - items, - }) + let mut items = Vec::with_capacity(len + 1); + items.push(string(name)?); + Ok(SeqSerializer { items }) } fn serialize_tuple_variant( @@ -164,11 +171,9 @@ impl<'a> ser::Serializer for &'a mut Serializer { variant: &'static str, len: usize, ) -> Result { - let mut items = Vec::with_capacity(len + 1); - items.push(string_owned(format!("{}.{}", name, variant))?); - Ok(SeqSerializer { - items, - }) + let mut items = Vec::with_capacity(len + 1); + items.push(string_owned(format!("{}.{}", name, variant))?); + Ok(SeqSerializer { items }) } fn serialize_map(self, _len: Option) -> Result { @@ -360,7 +365,7 @@ impl<'a> ser::SerializeStructVariant for StructVariantSerializer { mod tests { use super::*; use serde::Serialize; - use std::collections::HashMap; + use std::collections::HashMap; #[test] fn test_struct() { @@ -430,8 +435,8 @@ mod tests { #[test] fn test_dict() { let mut d = HashMap::new(); - d.insert("hello", "world"); - d.insert("dont", "panic"); + d.insert("hello", "world"); + d.insert("dont", "panic"); let expected = br#"{ ZG9udA = cGFuaWM, aGVsbG8 = d29ybGQ, @@ -439,8 +444,8 @@ mod tests { assert_eq!(&to_bytes(&d).unwrap(), expected); let mut d = HashMap::new(); - d.insert(12, vec![42, 125]); - d.insert(33, vec![19, 22, 21]); + d.insert(12, vec![42, 125]); + d.insert(33, vec![19, 22, 21]); let expected = br#"{ 12 = 42 125, 33 = 19 22 21,