From bc319d3089d83586824cfac98a9ffe9d10664d59 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Fri, 18 Nov 2022 12:49:13 +0100 Subject: [PATCH] Fix stupid encode API --- src/crypto/b2.rs | 4 +- src/crypto/ed25519.rs | 16 ++--- src/enc/mod.rs | 150 +++++++++++++++++++----------------------- src/lib.rs | 22 +++---- 4 files changed, 89 insertions(+), 103 deletions(-) diff --git a/src/crypto/b2.rs b/src/crypto/b2.rs index 48e52ec..120af38 100644 --- a/src/crypto/b2.rs +++ b/src/crypto/b2.rs @@ -36,7 +36,7 @@ impl Blake2Sum { } impl enc::Encode for Blake2Sum { - fn term(&self) -> enc::Term<'_> { - enc::bytes(self.as_bytes()) + fn term(&self) -> enc::Result<'_> { + Ok(enc::bytes(self.as_bytes())) } } diff --git a/src/crypto/ed25519.rs b/src/crypto/ed25519.rs index 8c73e69..ce477e4 100644 --- a/src/crypto/ed25519.rs +++ b/src/crypto/ed25519.rs @@ -10,25 +10,25 @@ pub fn generate_keypair() -> Keypair { } impl enc::Encode for Keypair { - fn term(&self) -> enc::Term<'_> { - enc::bytes(&self.to_bytes()) + fn term(&self) -> enc::Result<'_> { + Ok(enc::bytes(&self.to_bytes())) } } impl enc::Encode for PublicKey { - fn term(&self) -> enc::Term<'_> { - enc::bytes(self.as_bytes()) + fn term(&self) -> enc::Result<'_> { + Ok(enc::bytes(self.as_bytes())) } } impl enc::Encode for SecretKey { - fn term(&self) -> enc::Term<'_> { - enc::bytes(self.as_bytes()) + fn term(&self) -> enc::Result<'_> { + Ok(enc::bytes(self.as_bytes())) } } impl enc::Encode for Signature { - fn term(&self) -> enc::Term<'_> { - enc::bytes(&self.to_bytes()) + fn term(&self) -> enc::Result<'_> { + Ok(enc::bytes(&self.to_bytes())) } } diff --git a/src/enc/mod.rs b/src/enc/mod.rs index 711bff7..eb7218b 100644 --- a/src/enc/mod.rs +++ b/src/enc/mod.rs @@ -13,7 +13,6 @@ enum T<'a> { OwnedStr(Vec), Dict(HashMap<&'a [u8], T<'a>>), List(Vec>), - Err(Error), } /// An error that happenned when creating a nettext encoder term @@ -25,16 +24,18 @@ pub enum Error { ListInList, } +pub type Result<'a> = std::result::Result, Error>; + // ---- helpers to transform datatypes into encoder terms ---- /// Trait for anything that can be encoded as nettext pub trait Encode { - fn term(&self) -> Term<'_>; + fn term(&self) -> Result<'_>; } impl<'a, 'b> Encode for dec::Term<'a, 'b> { - fn term(&self) -> Term<'_> { - Term(T::Str(self.raw())) + fn term(&self) -> Result<'_> { + Ok(Term(T::Str(self.raw()))) } } @@ -45,15 +46,15 @@ impl<'a, 'b> Encode for dec::Term<'a, 'b> { /// ``` /// use nettext::enc::*; /// -/// assert_eq!(encode(string("Hello world .")).unwrap(), b"Hello world ."); +/// assert_eq!(encode(string("Hello world .").unwrap()), b"Hello world ."); /// ``` -pub fn string(s: &str) -> Term<'_> { +pub fn string(s: &str) -> Result<'_> { for c in s.as_bytes().iter() { if !(is_string_char(*c) || is_whitespace(*c)) { - return Term(T::Err(Error::InvalidCharacter(*c))); + return Err(Error::InvalidCharacter(*c)); } } - Term(T::Str(s.as_bytes())) + Ok(Term(T::Str(s.as_bytes()))) } /// Include a raw nettext value @@ -61,13 +62,13 @@ pub fn string(s: &str) -> Term<'_> { /// ``` /// use nettext::enc::*; /// -/// assert_eq!(encode(raw(b"Hello { a = b, c = d} .")).unwrap(), b"Hello { a = b, c = d} ."); +/// assert_eq!(encode(raw(b"Hello { a = b, c = d} .").unwrap()), b"Hello { a = b, c = d} ."); /// ``` -pub fn raw(bytes: &[u8]) -> Term<'_> { +pub fn raw(bytes: &[u8]) -> Result<'_> { if decode(bytes).is_err() { - return Term(T::Err(Error::InvalidRaw)); + return Err(Error::InvalidRaw); } - Term(T::Str(bytes)) + Ok(Term(T::Str(bytes))) } /// Term corresponding to a list of terms @@ -76,20 +77,19 @@ pub fn raw(bytes: &[u8]) -> Term<'_> { /// use nettext::enc::*; /// /// assert_eq!(encode(list([ -/// string("Hello"), -/// string("world") -/// ])).unwrap(), b"Hello world"); +/// string("Hello").unwrap(), +/// string("world").unwrap() +/// ]).unwrap()), b"Hello world"); /// ``` -pub fn list<'a, I: IntoIterator>>(terms: I) -> Term<'a> { +pub fn list<'a, I: IntoIterator>>(terms: I) -> Result<'a> { let mut tmp = Vec::with_capacity(8); for t in terms { match t.0 { - T::Err(e) => return Term(T::Err(e)), - T::List(_) => return Term(T::Err(Error::ListInList)), + T::List(_) => return Err(Error::ListInList), x => tmp.push(x), } } - Term(T::List(tmp)) + Ok(Term(T::List(tmp))) } /// Term corresponding to a dictionnary of items @@ -98,19 +98,14 @@ pub fn list<'a, I: IntoIterator>>(terms: I) -> Term<'a> { /// use nettext::enc::*; /// /// assert_eq!(encode(dict([ -/// ("a", string("Hello")), -/// ("b", string("world")) -/// ])).unwrap(), b"{\n a = Hello,\n b = world,\n}"); +/// ("a", string("Hello").unwrap()), +/// ("b", string("world").unwrap()) +/// ])), b"{\n a = Hello,\n b = world,\n}"); /// ``` pub fn dict<'a, I: IntoIterator)>>(pairs: I) -> Term<'a> { let mut tmp = HashMap::new(); for (k, v) in pairs { - match v.0 { - T::Err(e) => return Term(T::Err(e)), - vv => { - tmp.insert(k.as_bytes(), vv); - } - } + tmp.insert(k.as_bytes(), v.0); } Term(T::Dict(tmp)) } @@ -123,7 +118,7 @@ pub fn dict<'a, I: IntoIterator)>>(pairs: I) -> Term<' /// ``` /// use nettext::enc::*; /// -/// assert_eq!(encode(bytes(b"hello, world!")).unwrap(), b"aGVsbG8sIHdvcmxkIQ"); +/// assert_eq!(encode(bytes(b"hello, world!")), b"aGVsbG8sIHdvcmxkIQ"); /// ``` pub fn bytes(bytes: &[u8]) -> Term<'static> { Term(T::OwnedStr( @@ -152,31 +147,24 @@ impl<'a> Term<'a> { /// Transforms the initial term into a list if necessary. #[must_use] pub fn append(self, t: Term<'a>) -> Term<'a> { - match t.0 { - T::Err(e) => Term(T::Err(e)), - tt => match self.0 { - T::List(mut v) => { - v.push(tt); - Term(T::List(v)) - } - x => Term(T::List(vec![x, tt])), - }, + match self.0 { + T::List(mut v) => { + v.push(t.0); + Term(T::List(v)) + } + x => Term(T::List(vec![x, t.0])), } } /// Inserts a key-value pair into a term that is a dictionnary. /// Fails if `self` is not a dictionnary. - #[must_use] - pub fn insert(self, k: &'a str, v: Term<'a>) -> Term<'a> { - match v.0 { - T::Err(e) => Term(T::Err(e)), - vv => match self.0 { - T::Dict(mut d) => { - d.insert(k.as_bytes(), vv); - Term(T::Dict(d)) - } - _ => Term(T::Err(Error::NotADictionnary)), - }, + pub fn insert(self, k: &'a str, v: Term<'a>) -> Result<'a> { + match self.0 { + T::Dict(mut d) => { + d.insert(k.as_bytes(), v.0); + Ok(Term(T::Dict(d))) + } + _ => Err(Error::NotADictionnary), } } @@ -188,7 +176,7 @@ impl<'a> Term<'a> { /// ``` /// use nettext::enc::*; /// - /// assert_eq!(encode(list([string("hello"), string("world")]).nested()).unwrap(), b"{ . = hello world }"); + /// assert_eq!(encode(list([string("hello").unwrap(), string("world").unwrap()]).unwrap().nested()), b"{ . = hello world }"); /// ``` #[must_use] pub fn nested(self) -> Term<'a> { @@ -199,18 +187,13 @@ impl<'a> Term<'a> { // ---- encoding function ---- /// Generate the nettext representation of a term -pub fn encode(t: Term<'_>) -> Result, Error> { +pub fn encode(t: Term<'_>) -> Vec { let mut buf = Vec::with_capacity(128); - encode_aux(&mut buf, t.0, 0, true)?; - Ok(buf) + encode_aux(&mut buf, t.0, 0, true); + buf } -fn encode_aux( - buf: &mut Vec, - term: T<'_>, - indent: usize, - is_toplevel: bool, -) -> Result<(), Error> { +fn encode_aux(buf: &mut Vec, term: T<'_>, indent: usize, is_toplevel: bool) { match term { T::Str(s) => buf.extend_from_slice(s), T::OwnedStr(s) => buf.extend_from_slice(&s), @@ -222,7 +205,7 @@ fn encode_aux( let (k, v) = d.into_iter().next().unwrap(); buf.extend_from_slice(k); buf.extend_from_slice(b" = "); - encode_aux(buf, v, indent + 2, false)?; + encode_aux(buf, v, indent + 2, false); buf.extend_from_slice(b" }"); } else { buf.extend_from_slice(b"{\n"); @@ -236,7 +219,7 @@ fn encode_aux( } buf.extend_from_slice(k); buf.extend_from_slice(b" = "); - encode_aux(buf, v, indent2, false)?; + encode_aux(buf, v, indent2, false); buf.extend_from_slice(b",\n"); } for _ in 0..indent { @@ -256,12 +239,10 @@ fn encode_aux( } else if i > 0 { buf.push(b' '); } - encode_aux(buf, v, indent2, is_toplevel)?; + encode_aux(buf, v, indent2, is_toplevel); } } - T::Err(e) => return Err(e), } - Ok(()) } #[cfg(test)] @@ -271,20 +252,21 @@ mod tests { #[test] fn complex1() { let input = list([ - string("HELLO"), - string("alexhelloworld"), + string("HELLO").unwrap(), + string("alexhelloworld").unwrap(), dict([ - ("from", string("jxx")), - ("subject", string("hello")), - ("data", raw(b"{ f1 = plop, f2 = kuko }")), + ("from", string("jxx").unwrap()), + ("subject", string("hello").unwrap()), + ("data", raw(b"{ f1 = plop, f2 = kuko }").unwrap()), ]), - ]); + ]) + .unwrap(); let expected = b"HELLO alexhelloworld { data = { f1 = plop, f2 = kuko }, from = jxx, subject = hello, }"; - let enc = encode(input).unwrap(); + let enc = encode(input); eprintln!("{}", std::str::from_utf8(&enc).unwrap()); eprintln!("{}", std::str::from_utf8(&expected[..]).unwrap()); assert_eq!(&enc, &expected[..]); @@ -292,20 +274,24 @@ mod tests { #[test] fn nested() { - assert!(encode(list([ - string("a"), - string("b"), - list([string("c"), string("d")]) - ])) + assert!(list([ + string("a").unwrap(), + string("b").unwrap(), + list([string("c").unwrap(), string("d").unwrap()]).unwrap() + ]) .is_err()); assert_eq!( - encode(list([ - string("a"), - string("b"), - list([string("c"), string("d")]).nested() - ])) - .unwrap(), + encode( + list([ + string("a").unwrap(), + string("b").unwrap(), + list([string("c").unwrap(), string("d").unwrap()]) + .unwrap() + .nested() + ]) + .unwrap() + ), b"a b { . = c d }" ); } diff --git a/src/lib.rs b/src/lib.rs index c565bc5..cfbd447 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,16 +9,16 @@ //! //! // Encode a fist object that represents a payload that will be hashed and signed //! let text1 = encode(list([ -//! string("CALL"), -//! string("myfunction"), +//! string("CALL").unwrap(), +//! string("myfunction").unwrap(), //! dict([ -//! ("a", string("hello")), -//! ("b", string("world")), -//! ("c", raw(b"{ a = 12, b = 42 }")), +//! ("a", string("hello").unwrap()), +//! ("b", string("world").unwrap()), +//! ("c", raw(b"{ a = 12, b = 42 }").unwrap()), //! ("d", bytes_split(&((0..128u8).collect::>()))), //! ]), -//! keypair.public.term(), -//! ])).unwrap(); +//! keypair.public.term().unwrap(), +//! ]).unwrap()); //! eprintln!("{}", std::str::from_utf8(&text1).unwrap()); //! //! let hash = crypto::Blake2Sum::compute(&text1); @@ -26,10 +26,10 @@ //! //! // Encode a second object that represents the signed and hashed payload //! let text2 = encode(dict([ -//! ("hash", hash.term()), -//! ("signature", sign.term()), -//! ("payload", raw(&text1)), -//! ])).unwrap(); +//! ("hash", hash.term().unwrap()), +//! ("signature", sign.term().unwrap()), +//! ("payload", raw(&text1).unwrap()), +//! ])); //! eprintln!("{}", std::str::from_utf8(&text2).unwrap()); //! //! // Decode and check everything is fine