Fix stupid encode API

This commit is contained in:
Alex 2022-11-18 12:49:13 +01:00
parent e7ea915121
commit bc319d3089
Signed by: lx
GPG key ID: 0E496D15096376BE
4 changed files with 89 additions and 103 deletions

View file

@ -36,7 +36,7 @@ impl Blake2Sum {
} }
impl enc::Encode for Blake2Sum { impl enc::Encode for Blake2Sum {
fn term(&self) -> enc::Term<'_> { fn term(&self) -> enc::Result<'_> {
enc::bytes(self.as_bytes()) Ok(enc::bytes(self.as_bytes()))
} }
} }

View file

@ -10,25 +10,25 @@ pub fn generate_keypair() -> Keypair {
} }
impl enc::Encode for Keypair { impl enc::Encode for Keypair {
fn term(&self) -> enc::Term<'_> { fn term(&self) -> enc::Result<'_> {
enc::bytes(&self.to_bytes()) Ok(enc::bytes(&self.to_bytes()))
} }
} }
impl enc::Encode for PublicKey { impl enc::Encode for PublicKey {
fn term(&self) -> enc::Term<'_> { fn term(&self) -> enc::Result<'_> {
enc::bytes(self.as_bytes()) Ok(enc::bytes(self.as_bytes()))
} }
} }
impl enc::Encode for SecretKey { impl enc::Encode for SecretKey {
fn term(&self) -> enc::Term<'_> { fn term(&self) -> enc::Result<'_> {
enc::bytes(self.as_bytes()) Ok(enc::bytes(self.as_bytes()))
} }
} }
impl enc::Encode for Signature { impl enc::Encode for Signature {
fn term(&self) -> enc::Term<'_> { fn term(&self) -> enc::Result<'_> {
enc::bytes(&self.to_bytes()) Ok(enc::bytes(&self.to_bytes()))
} }
} }

View file

@ -13,7 +13,6 @@ enum T<'a> {
OwnedStr(Vec<u8>), OwnedStr(Vec<u8>),
Dict(HashMap<&'a [u8], T<'a>>), Dict(HashMap<&'a [u8], T<'a>>),
List(Vec<T<'a>>), List(Vec<T<'a>>),
Err(Error),
} }
/// An error that happenned when creating a nettext encoder term /// An error that happenned when creating a nettext encoder term
@ -25,16 +24,18 @@ pub enum Error {
ListInList, ListInList,
} }
pub type Result<'a> = std::result::Result<Term<'a>, Error>;
// ---- helpers to transform datatypes into encoder terms ---- // ---- helpers to transform datatypes into encoder terms ----
/// Trait for anything that can be encoded as nettext /// Trait for anything that can be encoded as nettext
pub trait Encode { pub trait Encode {
fn term(&self) -> Term<'_>; fn term(&self) -> Result<'_>;
} }
impl<'a, 'b> Encode for dec::Term<'a, 'b> { impl<'a, 'b> Encode for dec::Term<'a, 'b> {
fn term(&self) -> Term<'_> { fn term(&self) -> Result<'_> {
Term(T::Str(self.raw())) Ok(Term(T::Str(self.raw())))
} }
} }
@ -45,15 +46,15 @@ impl<'a, 'b> Encode for dec::Term<'a, 'b> {
/// ``` /// ```
/// use nettext::enc::*; /// use nettext::enc::*;
/// ///
/// assert_eq!(encode(string("Hello world .")).unwrap(), b"Hello world ."); /// assert_eq!(encode(string("Hello world .").unwrap()), b"Hello world .");
/// ``` /// ```
pub fn string(s: &str) -> Term<'_> { pub fn string(s: &str) -> Result<'_> {
for c in s.as_bytes().iter() { for c in s.as_bytes().iter() {
if !(is_string_char(*c) || is_whitespace(*c)) { if !(is_string_char(*c) || is_whitespace(*c)) {
return Term(T::Err(Error::InvalidCharacter(*c))); return Err(Error::InvalidCharacter(*c));
} }
} }
Term(T::Str(s.as_bytes())) Ok(Term(T::Str(s.as_bytes())))
} }
/// Include a raw nettext value /// Include a raw nettext value
@ -61,13 +62,13 @@ pub fn string(s: &str) -> Term<'_> {
/// ``` /// ```
/// use nettext::enc::*; /// use nettext::enc::*;
/// ///
/// assert_eq!(encode(raw(b"Hello { a = b, c = d} .")).unwrap(), b"Hello { a = b, c = d} ."); /// assert_eq!(encode(raw(b"Hello { a = b, c = d} .").unwrap()), b"Hello { a = b, c = d} .");
/// ``` /// ```
pub fn raw(bytes: &[u8]) -> Term<'_> { pub fn raw(bytes: &[u8]) -> Result<'_> {
if decode(bytes).is_err() { if decode(bytes).is_err() {
return Term(T::Err(Error::InvalidRaw)); return Err(Error::InvalidRaw);
} }
Term(T::Str(bytes)) Ok(Term(T::Str(bytes)))
} }
/// Term corresponding to a list of terms /// Term corresponding to a list of terms
@ -76,20 +77,19 @@ pub fn raw(bytes: &[u8]) -> Term<'_> {
/// use nettext::enc::*; /// use nettext::enc::*;
/// ///
/// assert_eq!(encode(list([ /// assert_eq!(encode(list([
/// string("Hello"), /// string("Hello").unwrap(),
/// string("world") /// string("world").unwrap()
/// ])).unwrap(), b"Hello world"); /// ]).unwrap()), b"Hello world");
/// ``` /// ```
pub fn list<'a, I: IntoIterator<Item = Term<'a>>>(terms: I) -> Term<'a> { pub fn list<'a, I: IntoIterator<Item = Term<'a>>>(terms: I) -> Result<'a> {
let mut tmp = Vec::with_capacity(8); let mut tmp = Vec::with_capacity(8);
for t in terms { for t in terms {
match t.0 { match t.0 {
T::Err(e) => return Term(T::Err(e)), T::List(_) => return Err(Error::ListInList),
T::List(_) => return Term(T::Err(Error::ListInList)),
x => tmp.push(x), x => tmp.push(x),
} }
} }
Term(T::List(tmp)) Ok(Term(T::List(tmp)))
} }
/// Term corresponding to a dictionnary of items /// Term corresponding to a dictionnary of items
@ -98,19 +98,14 @@ pub fn list<'a, I: IntoIterator<Item = Term<'a>>>(terms: I) -> Term<'a> {
/// use nettext::enc::*; /// use nettext::enc::*;
/// ///
/// assert_eq!(encode(dict([ /// assert_eq!(encode(dict([
/// ("a", string("Hello")), /// ("a", string("Hello").unwrap()),
/// ("b", string("world")) /// ("b", string("world").unwrap())
/// ])).unwrap(), b"{\n a = Hello,\n b = world,\n}"); /// ])), b"{\n a = Hello,\n b = world,\n}");
/// ``` /// ```
pub fn dict<'a, I: IntoIterator<Item = (&'a str, Term<'a>)>>(pairs: I) -> Term<'a> { pub fn dict<'a, I: IntoIterator<Item = (&'a str, Term<'a>)>>(pairs: I) -> Term<'a> {
let mut tmp = HashMap::new(); let mut tmp = HashMap::new();
for (k, v) in pairs { for (k, v) in pairs {
match v.0 { tmp.insert(k.as_bytes(), v.0);
T::Err(e) => return Term(T::Err(e)),
vv => {
tmp.insert(k.as_bytes(), vv);
}
}
} }
Term(T::Dict(tmp)) Term(T::Dict(tmp))
} }
@ -123,7 +118,7 @@ pub fn dict<'a, I: IntoIterator<Item = (&'a str, Term<'a>)>>(pairs: I) -> Term<'
/// ``` /// ```
/// use nettext::enc::*; /// use nettext::enc::*;
/// ///
/// assert_eq!(encode(bytes(b"hello, world!")).unwrap(), b"aGVsbG8sIHdvcmxkIQ"); /// assert_eq!(encode(bytes(b"hello, world!")), b"aGVsbG8sIHdvcmxkIQ");
/// ``` /// ```
pub fn bytes(bytes: &[u8]) -> Term<'static> { pub fn bytes(bytes: &[u8]) -> Term<'static> {
Term(T::OwnedStr( Term(T::OwnedStr(
@ -152,31 +147,24 @@ impl<'a> Term<'a> {
/// Transforms the initial term into a list if necessary. /// Transforms the initial term into a list if necessary.
#[must_use] #[must_use]
pub fn append(self, t: Term<'a>) -> Term<'a> { pub fn append(self, t: Term<'a>) -> Term<'a> {
match t.0 { match self.0 {
T::Err(e) => Term(T::Err(e)),
tt => match self.0 {
T::List(mut v) => { T::List(mut v) => {
v.push(tt); v.push(t.0);
Term(T::List(v)) Term(T::List(v))
} }
x => Term(T::List(vec![x, tt])), x => Term(T::List(vec![x, t.0])),
},
} }
} }
/// Inserts a key-value pair into a term that is a dictionnary. /// Inserts a key-value pair into a term that is a dictionnary.
/// Fails if `self` is not a dictionnary. /// Fails if `self` is not a dictionnary.
#[must_use] pub fn insert(self, k: &'a str, v: Term<'a>) -> Result<'a> {
pub fn insert(self, k: &'a str, v: Term<'a>) -> Term<'a> { match self.0 {
match v.0 {
T::Err(e) => Term(T::Err(e)),
vv => match self.0 {
T::Dict(mut d) => { T::Dict(mut d) => {
d.insert(k.as_bytes(), vv); d.insert(k.as_bytes(), v.0);
Term(T::Dict(d)) Ok(Term(T::Dict(d)))
} }
_ => Term(T::Err(Error::NotADictionnary)), _ => Err(Error::NotADictionnary),
},
} }
} }
@ -188,7 +176,7 @@ impl<'a> Term<'a> {
/// ``` /// ```
/// use nettext::enc::*; /// use nettext::enc::*;
/// ///
/// assert_eq!(encode(list([string("hello"), string("world")]).nested()).unwrap(), b"{ . = hello world }"); /// assert_eq!(encode(list([string("hello").unwrap(), string("world").unwrap()]).unwrap().nested()), b"{ . = hello world }");
/// ``` /// ```
#[must_use] #[must_use]
pub fn nested(self) -> Term<'a> { pub fn nested(self) -> Term<'a> {
@ -199,18 +187,13 @@ impl<'a> Term<'a> {
// ---- encoding function ---- // ---- encoding function ----
/// Generate the nettext representation of a term /// Generate the nettext representation of a term
pub fn encode(t: Term<'_>) -> Result<Vec<u8>, Error> { pub fn encode(t: Term<'_>) -> Vec<u8> {
let mut buf = Vec::with_capacity(128); let mut buf = Vec::with_capacity(128);
encode_aux(&mut buf, t.0, 0, true)?; encode_aux(&mut buf, t.0, 0, true);
Ok(buf) buf
} }
fn encode_aux( fn encode_aux(buf: &mut Vec<u8>, term: T<'_>, indent: usize, is_toplevel: bool) {
buf: &mut Vec<u8>,
term: T<'_>,
indent: usize,
is_toplevel: bool,
) -> Result<(), Error> {
match term { match term {
T::Str(s) => buf.extend_from_slice(s), T::Str(s) => buf.extend_from_slice(s),
T::OwnedStr(s) => buf.extend_from_slice(&s), T::OwnedStr(s) => buf.extend_from_slice(&s),
@ -222,7 +205,7 @@ fn encode_aux(
let (k, v) = d.into_iter().next().unwrap(); let (k, v) = d.into_iter().next().unwrap();
buf.extend_from_slice(k); buf.extend_from_slice(k);
buf.extend_from_slice(b" = "); buf.extend_from_slice(b" = ");
encode_aux(buf, v, indent + 2, false)?; encode_aux(buf, v, indent + 2, false);
buf.extend_from_slice(b" }"); buf.extend_from_slice(b" }");
} else { } else {
buf.extend_from_slice(b"{\n"); buf.extend_from_slice(b"{\n");
@ -236,7 +219,7 @@ fn encode_aux(
} }
buf.extend_from_slice(k); buf.extend_from_slice(k);
buf.extend_from_slice(b" = "); buf.extend_from_slice(b" = ");
encode_aux(buf, v, indent2, false)?; encode_aux(buf, v, indent2, false);
buf.extend_from_slice(b",\n"); buf.extend_from_slice(b",\n");
} }
for _ in 0..indent { for _ in 0..indent {
@ -256,12 +239,10 @@ fn encode_aux(
} else if i > 0 { } else if i > 0 {
buf.push(b' '); buf.push(b' ');
} }
encode_aux(buf, v, indent2, is_toplevel)?; encode_aux(buf, v, indent2, is_toplevel);
} }
} }
T::Err(e) => return Err(e),
} }
Ok(())
} }
#[cfg(test)] #[cfg(test)]
@ -271,20 +252,21 @@ mod tests {
#[test] #[test]
fn complex1() { fn complex1() {
let input = list([ let input = list([
string("HELLO"), string("HELLO").unwrap(),
string("alexhelloworld"), string("alexhelloworld").unwrap(),
dict([ dict([
("from", string("jxx")), ("from", string("jxx").unwrap()),
("subject", string("hello")), ("subject", string("hello").unwrap()),
("data", raw(b"{ f1 = plop, f2 = kuko }")), ("data", raw(b"{ f1 = plop, f2 = kuko }").unwrap()),
]), ]),
]); ])
.unwrap();
let expected = b"HELLO alexhelloworld { let expected = b"HELLO alexhelloworld {
data = { f1 = plop, f2 = kuko }, data = { f1 = plop, f2 = kuko },
from = jxx, from = jxx,
subject = hello, subject = hello,
}"; }";
let enc = encode(input).unwrap(); let enc = encode(input);
eprintln!("{}", std::str::from_utf8(&enc).unwrap()); eprintln!("{}", std::str::from_utf8(&enc).unwrap());
eprintln!("{}", std::str::from_utf8(&expected[..]).unwrap()); eprintln!("{}", std::str::from_utf8(&expected[..]).unwrap());
assert_eq!(&enc, &expected[..]); assert_eq!(&enc, &expected[..]);
@ -292,20 +274,24 @@ mod tests {
#[test] #[test]
fn nested() { fn nested() {
assert!(encode(list([ assert!(list([
string("a"), string("a").unwrap(),
string("b"), string("b").unwrap(),
list([string("c"), string("d")]) list([string("c").unwrap(), string("d").unwrap()]).unwrap()
])) ])
.is_err()); .is_err());
assert_eq!( assert_eq!(
encode(list([ encode(
string("a"), list([
string("b"), string("a").unwrap(),
list([string("c"), string("d")]).nested() string("b").unwrap(),
])) list([string("c").unwrap(), string("d").unwrap()])
.unwrap(), .unwrap()
.nested()
])
.unwrap()
),
b"a b { . = c d }" b"a b { . = c d }"
); );
} }

View file

@ -9,16 +9,16 @@
//! //!
//! // Encode a fist object that represents a payload that will be hashed and signed //! // Encode a fist object that represents a payload that will be hashed and signed
//! let text1 = encode(list([ //! let text1 = encode(list([
//! string("CALL"), //! string("CALL").unwrap(),
//! string("myfunction"), //! string("myfunction").unwrap(),
//! dict([ //! dict([
//! ("a", string("hello")), //! ("a", string("hello").unwrap()),
//! ("b", string("world")), //! ("b", string("world").unwrap()),
//! ("c", raw(b"{ a = 12, b = 42 }")), //! ("c", raw(b"{ a = 12, b = 42 }").unwrap()),
//! ("d", bytes_split(&((0..128u8).collect::<Vec<_>>()))), //! ("d", bytes_split(&((0..128u8).collect::<Vec<_>>()))),
//! ]), //! ]),
//! keypair.public.term(), //! keypair.public.term().unwrap(),
//! ])).unwrap(); //! ]).unwrap());
//! eprintln!("{}", std::str::from_utf8(&text1).unwrap()); //! eprintln!("{}", std::str::from_utf8(&text1).unwrap());
//! //!
//! let hash = crypto::Blake2Sum::compute(&text1); //! let hash = crypto::Blake2Sum::compute(&text1);
@ -26,10 +26,10 @@
//! //!
//! // Encode a second object that represents the signed and hashed payload //! // Encode a second object that represents the signed and hashed payload
//! let text2 = encode(dict([ //! let text2 = encode(dict([
//! ("hash", hash.term()), //! ("hash", hash.term().unwrap()),
//! ("signature", sign.term()), //! ("signature", sign.term().unwrap()),
//! ("payload", raw(&text1)), //! ("payload", raw(&text1).unwrap()),
//! ])).unwrap(); //! ]));
//! eprintln!("{}", std::str::from_utf8(&text2).unwrap()); //! eprintln!("{}", std::str::from_utf8(&text2).unwrap());
//! //!
//! // Decode and check everything is fine //! // Decode and check everything is fine