From 7a79994b0053bce3a348cc6f7bcfdf02052e0217 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Fri, 18 Nov 2022 18:59:06 +0100 Subject: [PATCH] Cleanup --- src/dec/mod.rs | 14 ++----- src/enc/error.rs | 2 + src/enc/mod.rs | 73 ++++++++++++++++++-------------- src/lib.rs | 4 +- src/serde/de.rs | 65 ++++------------------------- src/serde/mod.rs | 21 +++++----- src/serde/ser.rs | 105 +++-------------------------------------------- 7 files changed, 73 insertions(+), 211 deletions(-) diff --git a/src/dec/mod.rs b/src/dec/mod.rs index f872617..6982cb0 100644 --- a/src/dec/mod.rs +++ b/src/dec/mod.rs @@ -5,7 +5,9 @@ mod error; use std::collections::HashMap; +#[cfg(any(feature = "blake2", feature = "ed25519-dalek"))] use crate::crypto; + pub use decode::*; pub use error::*; @@ -578,16 +580,6 @@ impl<'a, 'b> std::fmt::Debug for AnyTerm<'a, 'b> { impl<'a, 'b> std::fmt::Debug for NonListTerm<'a, 'b> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { - match self.mkref() { - NonListTerm::Str(s) => write!(f, "Str(`{}`)", debug(s)), - NonListTerm::DictRef(raw, d) => { - write!(f, "Dict<`{}`", debug(raw))?; - for (k, v) in d.iter() { - write!(f, "\n `{}`={:?}", debug(k), v)?; - } - write!(f, ">") - } - _ => unreachable!(), - } + AnyTerm::from(self.mkref()).fmt(f) } } diff --git a/src/enc/error.rs b/src/enc/error.rs index 78f9a15..487b088 100644 --- a/src/enc/error.rs +++ b/src/enc/error.rs @@ -6,6 +6,7 @@ pub enum Error { InvalidCharacter(u8), InvalidRaw, NotADictionnary, + DuplicateKey(String), ListInList, } @@ -15,6 +16,7 @@ impl std::fmt::Display for Error { Error::InvalidCharacter(c) => write!(f, "Invalid character '{}'", *c as char), Error::InvalidRaw => write!(f, "Invalid RAW nettext litteral"), Error::NotADictionnary => write!(f, "Tried to insert into a term that isn't a dictionnary"), + Error::DuplicateKey(s) => write!(f, "Duplicate dict key: {}", s), Error::ListInList => write!(f, "Refusing to build nested lists with list(), use either list_flatten() or list_nested()"), } } diff --git a/src/enc/mod.rs b/src/enc/mod.rs index 9909287..4149d35 100644 --- a/src/enc/mod.rs +++ b/src/enc/mod.rs @@ -13,7 +13,7 @@ //! ("b", string("world").unwrap()), //! ("c", raw(b"{ a = 12, b = 42 }").unwrap()), //! ("d", bytes_split(&((0..128u8).collect::>()))), -//! ]), +//! ]).unwrap(), //! ]).unwrap().encode(); //! ``` @@ -94,13 +94,6 @@ pub fn raw(bytes: &[u8]) -> Result<'_> { Ok(Term(T::Str(bytes))) } -pub(crate) fn safe_raw(bytes: &[u8]) -> Term<'_> { - Term(T::Str(bytes)) -} -pub(crate) fn safe_raw_owned(bytes: Vec) -> Term<'static> { - Term(T::OwnedStr(bytes)) -} - /// Term corresponding to a list of terms /// /// ``` @@ -154,22 +147,16 @@ pub fn list_nested<'a, I: IntoIterator>>(terms: I) -> Result<'a> /// assert_eq!(dict([ /// ("a", string("Hello").unwrap()), /// ("b", string("world").unwrap()) -/// ]).encode(), b"{\n a = Hello,\n b = world,\n}"); +/// ]).unwrap().encode(), b"{\n a = Hello,\n b = world,\n}"); /// ``` -pub fn dict<'a, I: IntoIterator)>>(pairs: I) -> Term<'a> { +pub fn dict<'a, I: IntoIterator)>>(pairs: I) -> Result<'a> { let mut tmp = HashMap::new(); for (k, v) in pairs { - tmp.insert(Cow::from(k.as_bytes()), v.0); + if tmp.insert(Cow::from(k.as_bytes()), v.0).is_some() { + return Err(Error::DuplicateKey(k.to_string())); + } } - Term(T::Dict(tmp)) -} - -pub(crate) fn dict_owned_u8<'a, I: IntoIterator, Term<'a>)>>(pairs: I) -> Term<'a> { - let mut tmp = HashMap::new(); - for (k, v) in pairs { - tmp.insert(Cow::from(k), v.0); - } - Term(T::Dict(tmp)) + Ok(Term(T::Dict(tmp))) } /// Term corresponding to a byte slice, @@ -223,7 +210,9 @@ impl<'a> Term<'a> { pub fn insert(self, k: &'a str, v: Term<'a>) -> Result<'a> { match self.0 { T::Dict(mut d) => { - d.insert(Cow::from(k.as_bytes()), v.0); + if d.insert(Cow::from(k.as_bytes()), v.0).is_some() { + return Err(Error::DuplicateKey(k.to_string())); + } Ok(Term(T::Dict(d))) } _ => Err(Error::NotADictionnary), @@ -242,10 +231,33 @@ impl<'a> Term<'a> { /// ``` #[must_use] pub fn nested(self) -> Term<'a> { - dict([(".", self)]) + dict([(".", self)]).unwrap() } } +// ---- additional internal functions for serde module ---- + +#[cfg(feature = "serde")] +pub(crate) fn dict_owned_u8<'a, I: IntoIterator, Term<'a>)>>( + pairs: I, +) -> Result<'a> { + let mut tmp = HashMap::new(); + for (k, v) in pairs { + tmp.insert(Cow::from(k), v.0); + } + Ok(Term(T::Dict(tmp))) +} + +#[cfg(feature = "serde")] +pub(crate) fn safe_raw(bytes: &[u8]) -> Term<'_> { + Term(T::Str(bytes)) +} + +#[cfg(feature = "serde")] +pub(crate) fn safe_raw_owned(bytes: Vec) -> Term<'static> { + Term(T::OwnedStr(bytes)) +} + // ---- encoding function ---- impl<'a> Term<'a> { @@ -267,13 +279,13 @@ impl<'a> T<'a> { buf.extend_from_slice(b"{}"); } else if d.len() == 1 { let (k, v) = d.into_iter().next().unwrap(); - if k.as_ref() == b"." { - buf.extend_from_slice(b"{.= "); - } else { - buf.extend_from_slice(b"{ "); - buf.extend_from_slice(k.borrow()); - buf.extend_from_slice(b" = "); - } + if k.as_ref() == b"." { + buf.extend_from_slice(b"{.= "); + } else { + buf.extend_from_slice(b"{ "); + buf.extend_from_slice(k.borrow()); + buf.extend_from_slice(b" = "); + } v.encode_aux(buf, indent + 2, false); buf.extend_from_slice(b" }"); } else { @@ -328,7 +340,8 @@ mod tests { ("from", string("jxx").unwrap()), ("subject", string("hello").unwrap()), ("data", raw(b"{ f1 = plop, f2 = kuko }").unwrap()), - ]), + ]) + .unwrap(), ]) .unwrap(); let expected = b"HELLO alexhelloworld { diff --git a/src/lib.rs b/src/lib.rs index ecd5161..1c96dac 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,7 +16,7 @@ //! ("b", string("world").unwrap()), //! ("c", raw(b"{ a = 12, b = 42 }").unwrap()), //! ("d", bytes_split(&((0..128u8).collect::>()))), -//! ]), +//! ]).unwrap(), //! keypair.public.term().unwrap(), //! ]).unwrap().encode(); //! eprintln!("{}", std::str::from_utf8(&text1).unwrap()); @@ -29,7 +29,7 @@ //! ("hash", hash.term().unwrap()), //! ("signature", sign.term().unwrap()), //! ("payload", raw(&text1).unwrap()), -//! ]).encode(); +//! ]).unwrap().encode(); //! eprintln!("{}", std::str::from_utf8(&text2).unwrap()); //! //! // Decode and check everything is fine diff --git a/src/serde/de.rs b/src/serde/de.rs index 7278fc6..1e75a5d 100644 --- a/src/serde/de.rs +++ b/src/serde/de.rs @@ -40,7 +40,10 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de, 'a> { where V: Visitor<'de>, { - Err(Error::custom("derialize_any not supported")) + Err(Error::custom(format!( + "deserialize_any not supported, at: `{}`", + self.0 + ))) } fn deserialize_bool(self, visitor: V) -> Result @@ -400,9 +403,9 @@ impl<'de, 'a> EnumAccess<'de> for Enum<'de, 'a> { where V: DeserializeSeed<'de>, { - let term = &self.0.list()[0]; - let value = seed.deserialize(&mut Deserializer(Term(term.0.mkref())))?; - Ok((value, self)) + let variant = &self.0.list()[0]; + let variant = seed.deserialize(&mut Deserializer(Term(variant.0.mkref())))?; + Ok((variant, self)) } } @@ -441,57 +444,3 @@ impl<'de, 'a> VariantAccess<'de> for Enum<'de, 'a> { visitor.visit_map(&mut Dict::from_term(&rest)?) } } - -// ---- - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_struct() { - #[derive(Deserialize, PartialEq, Debug)] - struct Test { - int: u32, - seq: Vec, - } - - let nt = br#"Test { int = 1, seq = YQ Yg }"#; - let expected = Test { - int: 1, - seq: vec!["a".to_owned(), "b".to_owned()], - }; - assert_eq!(expected, from_bytes(nt).unwrap()); - } - - #[test] - fn test_enum() { - #[derive(Deserialize, PartialEq, Debug)] - enum E { - Unit, - Newtype(u32), - Tuple(u32, u32), - Struct { a: u32 }, - } - - let nt = br#"E.Unit"#; - let expected = E::Unit; - assert_eq!(expected, from_bytes(nt).unwrap()); - eprintln!("UNIT OK"); - - let nt = br#"E.Newtype 1"#; - let expected = E::Newtype(1); - assert_eq!(expected, from_bytes(nt).unwrap()); - eprintln!("NEWTYPE OK"); - - let nt = br#"E.Tuple 1 2"#; - let expected = E::Tuple(1, 2); - assert_eq!(expected, from_bytes(nt).unwrap()); - eprintln!("TUPLE OK"); - - let nt = br#"E.Struct { a = 1 }"#; - let expected = E::Struct { a: 1 }; - assert_eq!(expected, from_bytes(nt).unwrap()); - eprintln!("STRUCT OK"); - } -} diff --git a/src/serde/mod.rs b/src/serde/mod.rs index 2b80470..b5bcddf 100644 --- a/src/serde/mod.rs +++ b/src/serde/mod.rs @@ -41,15 +41,15 @@ mod tests { test_bidir(input, expected); let input = vec![ - Test { - int: 1, - seq: vec!["a".to_string(), "b".to_string()], - }, - Test { - int: 2, - seq: vec!["c".to_string(), "d".to_string()], - }, - ]; + Test { + int: 1, + seq: vec!["a".to_string(), "b".to_string()], + }, + Test { + int: 2, + seq: vec!["c".to_string(), "d".to_string()], + }, + ]; let expected = br#"{.= Test { int = 1, seq = YQ Yg, @@ -94,7 +94,8 @@ mod tests { E::Struct { a: 1 }, E::Tuple(3, 2), ]; - let expected = br#"E.Unit E.Unit {.= E.Newtype 1 } {.= E.Tuple 1 2 } {.= E.Struct { a = 1 } } {.= + let expected = + br#"E.Unit E.Unit {.= E.Newtype 1 } {.= E.Tuple 1 2 } {.= E.Struct { a = 1 } } {.= E.Tuple 3 2 }"#; test_bidir(input, expected); } diff --git a/src/serde/ser.rs b/src/serde/ser.rs index e26aabf..e4c43d2 100644 --- a/src/serde/ser.rs +++ b/src/serde/ser.rs @@ -94,7 +94,7 @@ impl<'a> ser::Serializer for &'a mut Serializer { } fn serialize_none(self) -> Result { - Ok(dict([])) + Ok(dict([])?) } fn serialize_some(self, value: &T) -> Result @@ -105,7 +105,7 @@ impl<'a> ser::Serializer for &'a mut Serializer { } fn serialize_unit(self) -> Result { - Ok(dict([])) + Ok(dict([])?) } fn serialize_unit_struct(self, name: &'static str) -> Result { @@ -309,7 +309,7 @@ impl ser::SerializeMap for MapSerializer { } fn end(self) -> Result> { - Ok(dict_owned_u8(self.fields.into_iter())) + Ok(dict_owned_u8(self.fields.into_iter())?) } } @@ -331,7 +331,7 @@ impl ser::SerializeStruct for StructSerializer { } fn end(self) -> Result> { - Ok(list([string(self.name)?, dict(self.fields.into_iter())])?) + Ok(list([string(self.name)?, dict(self.fields.into_iter())?])?) } } @@ -356,102 +356,7 @@ impl ser::SerializeStructVariant for StructVariantSerializer { fn end(self) -> Result> { Ok(list([ string_owned(format!("{}.{}", self.name, self.variant))?, - dict(self.fields.into_iter()), + dict(self.fields.into_iter())?, ])?) } } - -// ---- - -#[cfg(test)] -mod tests { - use super::*; - use serde::Serialize; - use std::collections::HashMap; - - #[test] - fn test_struct() { - #[derive(Serialize)] - struct Test { - int: u32, - seq: Vec<&'static str>, - } - - let test = Test { - int: 1, - seq: vec!["a", "b"], - }; - let expected = br#"Test { - int = 1, - seq = YQ Yg, - }"#; - assert_eq!(&to_bytes(&test).unwrap(), expected); - } - - #[test] - fn test_enum() { - #[derive(Serialize)] - enum E { - Unit, - Newtype(u32), - Tuple(u32, u32), - Struct { a: u32 }, - } - - let u = E::Unit; - let expected = br#"E.Unit"#; - assert_eq!(&to_bytes(&u).unwrap(), expected); - - let n = E::Newtype(1); - let expected = br#"E.Newtype 1"#; - assert_eq!(&to_bytes(&n).unwrap(), expected); - - let t = E::Tuple(1, 2); - let expected = br#"E.Tuple 1 2"#; - assert_eq!(&to_bytes(&t).unwrap(), expected); - - let s = E::Struct { a: 1 }; - let expected = br#"E.Struct { a = 1 }"#; - assert_eq!(&to_bytes(&s).unwrap(), expected); - } - - #[test] - fn test_seq() { - let u = (1, 2, 3, 4); - let expected = br#"1 2 3 4"#; - assert_eq!(&to_bytes(&u).unwrap(), expected); - - let n = (1, 2, (2, 3, 4), 5, 6); - let expected = br#"1 2 {.= 2 3 4 } 5 6"#; - assert_eq!(&to_bytes(&n).unwrap(), expected); - - let t = [1, 2, 3, 4]; - let expected = br#"1 2 3 4"#; - assert_eq!(&to_bytes(&t).unwrap(), expected); - - let s = [[1, 2], [2, 3], [3, 4]]; - let expected = br#"{.= 1 2 } {.= 2 3 } {.= 3 4 }"#; - assert_eq!(&to_bytes(&s).unwrap(), expected); - } - - #[test] - fn test_dict() { - let mut d = HashMap::new(); - d.insert("hello", "world"); - d.insert("dont", "panic"); - let expected = br#"{ - ZG9udA = cGFuaWM, - aGVsbG8 = d29ybGQ, -}"#; - assert_eq!(&to_bytes(&d).unwrap(), expected); - - let mut d = HashMap::new(); - d.insert(12, vec![42, 125]); - d.insert(33, vec![19, 22, 21]); - let expected = br#"{ - 12 = 42 125, - 33 = 19 22 21, -}"#; - assert_eq!(&to_bytes(&d).unwrap(), expected); - } -}