nettext/src/dec/mod.rs

611 lines
20 KiB
Rust

//! Functions to decode nettext and helpers to map it to data structures
mod decode;
mod error;
use std::collections::HashMap;
#[cfg(any(feature = "dryoc"))]
use crate::crypto;
use crate::debug;
pub use decode::*;
pub use error::*;
/// A parsed nettext term, with many helpers for destructuring
///
/// This is a thin public wrapper around the crate-private `AnyTerm`
/// representation; all helper methods are implemented on this type.
///
/// Lifetime 'a is the lifetime of the buffer containing the encoded data.
///
/// Lifetime 'b is the lifetime of another Term from which this one is borrowed, when it
/// is returned by one of the helper functions, or 'static when first returned from
/// `decode()`
#[derive(Eq, PartialEq, Debug)]
pub struct Term<'a, 'b>(pub(crate) AnyTerm<'a, 'b>);
/// Internal representation of a term of any kind.
///
/// Every variant carries, as its first field, the raw bytes of the term
/// (this is what `raw()` returns). The `*Ref` variants hold a borrow of the
/// collection owned by another term instead of owning it; `mkref()` turns
/// any variant into its borrowed counterpart.
#[derive(Eq, PartialEq, Clone)]
pub(crate) enum AnyTerm<'a, 'b> {
/// A single string, stored as its raw bytes.
Str(&'a [u8]),
/// A dictionary owning its key/value map.
Dict(&'a [u8], HashMap<&'a [u8], AnyTerm<'a, 'b>>),
/// A dictionary borrowing the key/value map of another term.
DictRef(&'a [u8], &'b HashMap<&'a [u8], AnyTerm<'a, 'b>>),
/// A list owning its items.
List(&'a [u8], Vec<AnyTerm<'a, 'b>>),
/// A list borrowing the items of another term.
ListRef(&'a [u8], &'b [AnyTerm<'a, 'b>]),
/// A sequence owning its items; items are `NonSeqTerm`s, so a sequence
/// never directly contains another sequence.
Seq(&'a [u8], Vec<NonSeqTerm<'a, 'b>>),
/// A sequence borrowing the items of another term.
SeqRef(&'a [u8], &'b [NonSeqTerm<'a, 'b>]),
}
/// Internal representation of a term that is not a sequence.
///
/// This mirrors `AnyTerm` without the `Seq`/`SeqRef` variants and is the
/// item type stored inside sequences. As in `AnyTerm`, the first field of
/// each variant is the raw bytes of the term, and the `*Ref` variants borrow
/// their collection from another term.
#[derive(Eq, PartialEq, Clone)]
pub(crate) enum NonSeqTerm<'a, 'b> {
/// A single string, stored as its raw bytes.
Str(&'a [u8]),
/// A dictionary owning its key/value map.
Dict(&'a [u8], HashMap<&'a [u8], AnyTerm<'a, 'b>>),
/// A dictionary borrowing the key/value map of another term.
DictRef(&'a [u8], &'b HashMap<&'a [u8], AnyTerm<'a, 'b>>),
/// A list owning its items.
List(&'a [u8], Vec<AnyTerm<'a, 'b>>),
/// A list borrowing the items of another term.
ListRef(&'a [u8], &'b [AnyTerm<'a, 'b>]),
}
/// Widening conversion: every non-sequence term is also a valid `AnyTerm`.
/// Each variant maps to the variant of the same name, fields unchanged.
impl<'a, 'b> From<NonSeqTerm<'a, 'b>> for AnyTerm<'a, 'b> {
    fn from(term: NonSeqTerm<'a, 'b>) -> Self {
        match term {
            NonSeqTerm::Str(bytes) => Self::Str(bytes),
            NonSeqTerm::Dict(raw, entries) => Self::Dict(raw, entries),
            NonSeqTerm::DictRef(raw, entries) => Self::DictRef(raw, entries),
            NonSeqTerm::List(raw, items) => Self::List(raw, items),
            NonSeqTerm::ListRef(raw, items) => Self::ListRef(raw, items),
        }
    }
}
impl<'a, 'b> TryFrom<AnyTerm<'a, 'b>> for NonSeqTerm<'a, 'b> {
type Error = ();
fn try_from(x: AnyTerm<'a, 'b>) -> Result<NonSeqTerm<'a, 'b>, ()> {
match x {
AnyTerm::Str(s) => Ok(NonSeqTerm::Str(s)),
AnyTerm::Dict(raw, d) => Ok(NonSeqTerm::Dict(raw, d)),
AnyTerm::DictRef(raw, d) => Ok(NonSeqTerm::DictRef(raw, d)),
AnyTerm::List(raw, l) => Ok(NonSeqTerm::List(raw, l)),
AnyTerm::ListRef(raw, l) => Ok(NonSeqTerm::ListRef(raw, l)),
_ => Err(()),
}
}
}
impl<'a> From<AnyTerm<'a, 'static>> for Term<'a, 'static> {
fn from(x: AnyTerm<'a, 'static>) -> Term<'a, 'static> {
Term(x)
}
}
// ---- PUBLIC IMPLS ----
impl<'a, 'b> Term<'a, 'b> {
// ---- STRUCTURAL MAPPINGS ----
/// Get the term's raw representation
///
/// Example:
///
/// ```
/// use nettext::dec::decode;
///
/// let term = decode(b"hello world").unwrap();
/// assert_eq!(term.raw(), b"hello world");
/// ```
pub fn raw(&self) -> &'a [u8] {
self.0.raw()
}
/// Returns the raw representation of this term as a `str`.
///
/// Fails with a `TypeError` if the raw bytes are not valid UTF-8.
///
/// Example:
///
/// ```
/// use nettext::dec::decode;
///
/// let term = decode(b"hello { a = x, b = y }").unwrap();
/// assert_eq!(term.raw_str().unwrap(), "hello { a = x, b = y }");
/// ```
pub fn raw_str(&self) -> Result<&'a str, TypeError> {
    let bytes = self.0.raw();
    let text = std::str::from_utf8(bytes)?;
    Ok(text)
}
/// Returns the term as a `str` if it is a single string.
///
/// Any other kind of term yields `TypeError::WrongType("STR")`; a string
/// that is not valid UTF-8 yields the corresponding `TypeError`.
///
/// Example:
///
/// ```
/// use nettext::dec::decode;
///
/// let term1 = decode(b"hello").unwrap();
/// assert_eq!(term1.str().unwrap(), "hello");
///
/// let term2 = decode(b"hello world").unwrap();
/// assert!(term2.str().is_err());
/// ```
pub fn str(&self) -> Result<&'a str, TypeError> {
    if let AnyTerm::Str(bytes) = &self.0 {
        Ok(std::str::from_utf8(bytes)?)
    } else {
        Err(TypeError::WrongType("STR"))
    }
}
/// If the term is a single string, or a sequence containing only strings,
/// get its raw representation
///
/// This works both on terms that own their items and on terms borrowed
/// from another term (for instance a dictionary value, or the remainder
/// returned by `.seq_of_first()`).
///
/// Returns `TypeError::WrongType("STRING")` for any other kind of term, and
/// a `TypeError` if the raw bytes are not valid UTF-8.
///
/// Example:
///
/// ```
/// use nettext::dec::decode;
///
/// let term1 = decode(b"hello world").unwrap();
/// assert_eq!(term1.string().unwrap(), "hello world");
///
/// let term2 = decode(b"hello { a= 5}").unwrap();
/// assert!(term2.string().is_err());
///
/// let term3 = decode(b"hello mighty world").unwrap();
/// let [_first, rest] = term3.seq_of_first().unwrap();
/// assert_eq!(rest.string().unwrap(), "mighty world");
/// ```
pub fn string(&self) -> Result<&'a str, TypeError> {
    // Go through `mkref()` so that owned (`Seq`) and borrowed (`SeqRef`)
    // sequences are both handled by the single `SeqRef` arm below.
    match self.0.mkref() {
        AnyTerm::Str(s) => Ok(std::str::from_utf8(s)?),
        AnyTerm::SeqRef(r, l) if l.iter().all(|x| matches!(x, NonSeqTerm::Str(_))) => {
            Ok(std::str::from_utf8(r)?)
        }
        _ => Err(TypeError::WrongType("STRING")),
    }
}
/// Views this term as a sequence of terms.
/// A sequence yields its items; any other term (str, dict or list)
/// yields a vec containing only that term.
///
/// Example:
///
/// ```
/// use nettext::dec::decode;
///
/// let term1 = decode(b"hello").unwrap();
/// let seq1 = term1.seq();
/// assert_eq!(seq1.len(), 1);
/// assert_eq!(seq1[0].str().unwrap(), "hello");
///
/// let term2 = decode(b"hello world").unwrap();
/// let seq2 = term2.seq();
/// assert_eq!(seq2.len(), 2);
/// assert_eq!(seq2[0].str().unwrap(), "hello");
/// assert_eq!(seq2[1].str().unwrap(), "world");
/// ```
pub fn seq(&self) -> Vec<Term<'a, '_>> {
    let items = match self.0.mkref() {
        AnyTerm::SeqRef(_, items) => items,
        // Not a sequence: the term itself is the only item.
        single => return vec![Term(single)],
    };
    items
        .iter()
        .map(|item| Term(AnyTerm::from(item.mkref())))
        .collect()
}
/// Like `.seq()`, but returns the items as a fixed-size array so that they
/// can be bound directly to individual variables.
///
/// The number of items is checked at runtime: if it differs from `N`,
/// `TypeError::WrongLength(found, N)` is returned.
///
/// Example:
///
/// ```
/// use nettext::dec::decode;
///
/// let term1 = decode(b"hello").unwrap();
/// let [s1] = term1.seq_of().unwrap();
/// assert_eq!(s1.str().unwrap(), "hello");
///
/// let term2 = decode(b"hello world").unwrap();
/// let [s2a, s2b] = term2.seq_of().unwrap();
/// assert_eq!(s2a.str().unwrap(), "hello");
/// assert_eq!(s2b.str().unwrap(), "world");
/// ```
pub fn seq_of<const N: usize>(&self) -> Result<[Term<'a, '_>; N], TypeError> {
    let terms = self.seq();
    // Remember the length now: the conversion below consumes the vec.
    let found = terms.len();
    let converted: Result<[Term<'a, '_>; N], _> = terms.try_into();
    match converted {
        Ok(array) => Ok(array),
        Err(_) => Err(TypeError::WrongLength(found, N)),
    }
}
/// Same as `.seq_of()`, but only binds the first N-1 terms.
/// If there are exactly N terms, the last one is bound to the Nth return variable.
/// If there are more than N terms, the remaining terms are bound to a new seq term
/// that is returned as the Nth return variable.
///
/// Returns `TypeError::WrongLength(found, N)` if there are fewer than N
/// terms, or if `N` is 0 while the term is not empty (there is then no
/// return variable to hold the remaining terms).
///
/// Example:
///
/// ```
/// use nettext::dec::decode;
///
/// let term1 = decode(b"hello world").unwrap();
/// let [s1a, s1b] = term1.seq_of_first().unwrap();
/// assert_eq!(s1a.str().unwrap(), "hello");
/// assert_eq!(s1b.str().unwrap(), "world");
///
/// let term2 = decode(b"hello mighty world").unwrap();
/// let [s2a, s2b] = term2.seq_of_first().unwrap();
/// assert_eq!(s2a.str().unwrap(), "hello");
/// assert_eq!(s2b.seq().len(), 2);
/// assert_eq!(s2b.raw(), b"mighty world");
/// ```
pub fn seq_of_first<const N: usize>(&self) -> Result<[Term<'a, '_>; N], TypeError> {
    match self.0.mkref() {
        AnyTerm::SeqRef(raw, seq) => {
            let terms: Vec<_> = match seq.len().cmp(&N) {
                std::cmp::Ordering::Less => {
                    return Err(TypeError::WrongLength(seq.len(), N));
                }
                std::cmp::Ordering::Equal => {
                    seq.iter().map(|x| Term(x.mkref().into())).collect()
                }
                // With N == 0 there is no slot for the remainder, and
                // `N - 1` below would underflow: report a length mismatch.
                std::cmp::Ordering::Greater if N == 0 => {
                    return Err(TypeError::WrongLength(seq.len(), N));
                }
                std::cmp::Ordering::Greater => {
                    let mut ret = Vec::with_capacity(N);
                    for item in seq[..N - 1].iter() {
                        ret.push(Term(item.mkref().into()));
                    }
                    // The remainder's raw bytes start where the raw bytes of
                    // item N-1 start; this assumes every item's raw slice lies
                    // inside `raw` (expected from the decoder — not checked here).
                    let remaining_begin = seq[N - 1].raw().as_ptr() as usize;
                    let remaining_offset = remaining_begin - raw.as_ptr() as usize;
                    let remaining_raw = &raw[remaining_offset..];
                    ret.push(Term(AnyTerm::SeqRef(remaining_raw, &seq[N - 1..])));
                    ret
                }
            };
            Ok(terms
                .try_into()
                .expect("exactly N terms were collected above"))
        }
        // A non-sequence term counts as a sequence of exactly one item.
        x if N == 1 => Ok(vec![Term(x)]
            .try_into()
            .expect("a one-element vec converts to an array of length N == 1")),
        _ => Err(TypeError::WrongLength(1, N)),
    }
}
/// Checks that the term is a dictionary and returns a hashmap of its inner
/// terms, keyed by the entry names as `str`.
///
/// Returns `TypeError::WrongType("DICT")` if the term is not a dictionary,
/// and a `TypeError` if a key is not valid UTF-8.
///
/// Example:
///
/// ```
/// use nettext::dec::decode;
///
/// let term = decode(b"{ k1 = v1, k2 = v2 }").unwrap();
/// let dict = term.dict().unwrap();
/// assert_eq!(dict.get("k1").unwrap().str().unwrap(), "v1");
/// assert_eq!(dict.get("k2").unwrap().str().unwrap(), "v2");
/// ```
pub fn dict(&self) -> Result<HashMap<&'a str, Term<'a, '_>>, TypeError> {
    let entries = match self.0.mkref() {
        AnyTerm::DictRef(_, entries) => entries,
        _ => return Err(TypeError::WrongType("DICT")),
    };
    let mut decoded = HashMap::with_capacity(entries.len());
    for (key, value) in entries.iter() {
        let key_str = std::str::from_utf8(key)?;
        decoded.insert(key_str, Term(value.mkref()));
    }
    Ok(decoded)
}
/// Checks that the term is a dictionary containing all of the supplied keys,
/// and returns the associated values as an array, in the order of `keys`.
///
/// If `allow_extra_keys` is false, the dictionary must additionally contain
/// no key other than those supplied.
///
/// Errors: `WrongType("DICT")` if the term is not a dictionary,
/// `MissingKey` if a supplied key is absent, `UnexpectedKey` if an extra key
/// is present while extra keys are not allowed.
///
/// Example:
///
/// ```
/// use nettext::dec::decode;
///
/// let term = decode(b"{ k1 = v1, k2 = v2, k3 = v3 }").unwrap();
/// let [s1, s2] = term.dict_of(["k1", "k2"], true).unwrap();
/// assert_eq!(s1.str().unwrap(), "v1");
/// assert_eq!(s2.str().unwrap(), "v2");
///
/// assert!(term.dict_of(["k1", "k2"], false).is_err());
/// ```
pub fn dict_of<const N: usize, T: AsRef<[u8]>>(
    &self,
    keys: [T; N],
    allow_extra_keys: bool,
) -> Result<[Term<'a, '_>; N], TypeError> {
    let dict = match self.0.mkref() {
        AnyTerm::DictRef(_, dict) => dict,
        _ => return Err(TypeError::WrongType("DICT")),
    };

    // Every requested key must be present in the dictionary.
    for wanted in keys.iter() {
        let wanted: &[u8] = wanted.as_ref();
        if !dict.contains_key(wanted) {
            return Err(TypeError::MissingKey(debug(wanted).to_string()));
        }
    }

    // Unless extra keys are tolerated, every dictionary key must be requested.
    if !allow_extra_keys {
        for present in dict.keys() {
            let requested = keys.iter().any(|wanted| wanted.as_ref() == *present);
            if !requested {
                return Err(TypeError::UnexpectedKey(debug(present).to_string()));
            }
        }
    }

    Ok(keys.map(|wanted| {
        let value = dict
            .get(wanted.as_ref())
            .expect("presence of every requested key was checked above");
        Term(value.mkref())
    }))
}
/// Checks that the term is a dictionary and returns, for each supplied key,
/// the associated value if present (`None` otherwise), in the order of `keys`.
///
/// If `allow_extra_keys` is false, the dictionary must contain no key other
/// than those supplied.
///
/// Errors: `WrongType("DICT")` if the term is not a dictionary,
/// `UnexpectedKey` if an extra key is present while extra keys are not allowed.
///
/// Example:
///
/// ```
/// use nettext::dec::decode;
///
/// let term = decode(b"{ k1 = v1, k2 = v2, k4 = v4 }").unwrap();
/// let [s1, s2, s3] = term.dict_of_opt(["k1", "k2", "k3"], true).unwrap();
/// assert_eq!(s1.unwrap().str().unwrap(), "v1");
/// assert_eq!(s2.unwrap().str().unwrap(), "v2");
/// assert!(s3.is_none());
///
/// assert!(term.dict_of_opt(["k1", "k2", "k3"], false).is_err());
/// ```
pub fn dict_of_opt<const N: usize, T: AsRef<[u8]>>(
    &self,
    keys: [T; N],
    allow_extra_keys: bool,
) -> Result<[Option<Term<'a, '_>>; N], TypeError> {
    let dict = match self.0.mkref() {
        AnyTerm::DictRef(_, dict) => dict,
        _ => return Err(TypeError::WrongType("DICT")),
    };

    // Unless extra keys are tolerated, every dictionary key must be requested.
    if !allow_extra_keys {
        for present in dict.keys() {
            let requested = keys.iter().any(|wanted| wanted.as_ref() == *present);
            if !requested {
                return Err(TypeError::UnexpectedKey(debug(present).to_string()));
            }
        }
    }

    let lookup = |wanted: T| dict.get(wanted.as_ref()).map(|value| Term(value.mkref()));
    Ok(keys.map(lookup))
}
/// Checks that the term is a list and returns its elements in a vec.
///
/// Returns `TypeError::WrongType("LIST")` for any other kind of term.
///
/// Example:
///
/// ```
/// use nettext::dec::decode;
///
/// let term2 = decode(b"[ hello, world ]").unwrap();
/// let seq2 = term2.list().unwrap();
/// assert_eq!(seq2.len(), 2);
/// assert_eq!(seq2[0].str().unwrap(), "hello");
/// assert_eq!(seq2[1].str().unwrap(), "world");
/// ```
pub fn list(&self) -> Result<Vec<Term<'a, '_>>, TypeError> {
    let items = match self.0.mkref() {
        AnyTerm::ListRef(_, items) => items,
        _ => return Err(TypeError::WrongType("LIST")),
    };
    let terms = items.iter().map(|item| Term(item.mkref())).collect();
    Ok(terms)
}
// ---- TYPE CASTS ----
/// Parses this term as an `i64`.
///
/// The term must be a single string (see `.str()`); if it does not parse
/// as an `i64`, `TypeError::WrongType("INT")` is returned.
///
/// Example:
///
/// ```
/// use nettext::dec::decode;
///
/// let term = decode(b"42").unwrap();
/// assert_eq!(term.int().unwrap(), 42);
/// ```
pub fn int(&self) -> Result<i64, TypeError> {
    let text = self.str()?;
    match text.parse::<i64>() {
        Ok(value) => Ok(value),
        Err(_) => Err(TypeError::WrongType("INT")),
    }
}
/// Decodes this term as base64-encoded bytes (URL-safe, no-padding encoding).
///
/// * A single string is decoded as a whole; the string `.` stands for an
///   empty byte vector.
/// * A sequence of strings is decoded item by item and the results are
///   concatenated.
///
/// Anything else, or invalid base64, yields `TypeError::WrongType("BYTES")`.
///
/// Example:
///
/// ```
/// use nettext::dec::decode;
///
/// let term = decode(b"aGVsbG8sIHdvcmxkIQ").unwrap();
/// assert_eq!(term.bytes().unwrap(), b"hello, world!");
/// ```
pub fn bytes(&self) -> Result<Vec<u8>, TypeError> {
    // Decodes one self-contained base64 chunk.
    fn decode_chunk(encoded: &[u8]) -> Result<Vec<u8>, TypeError> {
        base64::decode_config(encoded, base64::URL_SAFE_NO_PAD)
            .map_err(|_| TypeError::WrongType("BYTES"))
    }

    match self.0.mkref() {
        AnyTerm::Str(encoded) if encoded == b"." => Ok(vec![]),
        AnyTerm::Str(encoded) => decode_chunk(encoded),
        AnyTerm::SeqRef(_, chunks) => {
            let mut decoded = Vec::with_capacity(128);
            for chunk in chunks.iter() {
                match chunk {
                    NonSeqTerm::Str(encoded) => decoded.extend(decode_chunk(encoded)?),
                    _ => return Err(TypeError::WrongType("BYTES")),
                }
            }
            Ok(decoded)
        }
        _ => Err(TypeError::WrongType("BYTES")),
    }
}
/// Decodes this term as base64-encoded bytes (see `.bytes()`) and checks
/// that exactly `N` bytes were decoded.
///
/// Returns `TypeError::WrongLength(found, N)` if the length differs.
///
/// Example:
///
/// ```
/// use nettext::dec::decode;
///
/// let term = decode(b"aGVsbG8sIHdvcmxkIQ").unwrap();
/// assert_eq!(&term.bytes_exact::<13>().unwrap(), b"hello, world!");
/// ```
pub fn bytes_exact<const N: usize>(&self) -> Result<[u8; N], TypeError> {
    let decoded = self.bytes()?;
    // Remember the length now: the conversion below consumes the vec.
    let found = decoded.len();
    let converted: Result<[u8; N], _> = decoded.try_into();
    match converted {
        Ok(array) => Ok(array),
        Err(_) => Err(TypeError::WrongLength(found, N)),
    }
}
}
// ---- CRYPTO HELPERS ----
// All helpers below decode the term with `bytes_exact()`; the expected byte
// length is not written here but inferred from the array type accepted by the
// `From` conversion of each target type.
#[cfg(feature = "dryoc")]
impl<'a, 'b> Term<'a, 'b> {
/// Try to interpret this string as a Blake2b512 digest (32-bytes base64 encoded)
///
/// Fails with a `TypeError` if the term is not valid base64 or does not
/// have the length expected by `crypto::generichash::Hash`.
///
/// Example:
///
/// ```
/// use nettext::dec::decode;
/// use nettext::crypto::generichash::GenericHash;
///
/// let term = decode(b"{
/// message = hello,
/// hash = Mk3PAn3UowqTLEQfNlol6GsXPe-kuOWJSCU0cbgbcs8,
/// }").unwrap();
/// let [msg, hash] = term.dict_of(["message", "hash"], false).unwrap();
/// let expected_hash = GenericHash::hash_with_defaults(msg.raw(), None::<&Vec<u8>>).unwrap();
/// assert_eq!(hash.hash().unwrap(), expected_hash);
/// ```
pub fn hash(&self) -> Result<crypto::generichash::Hash, TypeError> {
Ok(crypto::generichash::Hash::from(self.bytes_exact()?))
}
/// Try to interpret this string as an ed25519 keypair (64 bytes base64 encoded)
///
/// The bytes are decoded as a `crypto::sign::SecretKey`, from which the
/// keypair is then built with `SigningKeyPair::from_secret_key`.
pub fn keypair(&self) -> Result<crypto::SigningKeyPair, TypeError> {
let secret_key = crypto::sign::SecretKey::from(self.bytes_exact()?);
Ok(crypto::SigningKeyPair::from_secret_key(secret_key))
}
/// Try to interpret this string as an ed25519 public key (32 bytes base64 encoded)
pub fn public_key(&self) -> Result<crypto::sign::PublicKey, TypeError> {
Ok(crypto::sign::PublicKey::from(self.bytes_exact()?))
}
/// Try to interpret this string as an ed25519 secret key (base64 encoded)
///
/// The bytes are decoded into the same `crypto::sign::SecretKey` type as in
/// `keypair()`, so both accept the same length.
/// NOTE(review): this doc previously said 32 bytes while `keypair()` says
/// 64 bytes for the same conversion — confirm the actual length against
/// `crypto::sign::SecretKey`.
pub fn secret_key(&self) -> Result<crypto::sign::SecretKey, TypeError> {
Ok(crypto::sign::SecretKey::from(self.bytes_exact()?))
}
/// Try to interpret this string as an ed25519 signature (64 bytes base64 encoded)
pub fn signature(&self) -> Result<crypto::sign::Signature, TypeError> {
Ok(crypto::sign::Signature::from(self.bytes_exact()?))
}
}
// ---- INTERNAL IMPLS ----
impl<'a, 'b> AnyTerm<'a, 'b> {
    /// Returns the raw bytes of the term: the string itself for `Str`, and
    /// the first field of every other variant.
    fn raw(&self) -> &'a [u8] {
        match *self {
            AnyTerm::Str(bytes)
            | AnyTerm::Dict(bytes, _)
            | AnyTerm::DictRef(bytes, _)
            | AnyTerm::List(bytes, _)
            | AnyTerm::ListRef(bytes, _)
            | AnyTerm::Seq(bytes, _)
            | AnyTerm::SeqRef(bytes, _) => bytes,
        }
    }

    /// Returns a borrowed view of this term: owning variants are mapped to
    /// their `*Ref` counterpart pointing at this term's collection, while
    /// `Str` and `*Ref` variants are copied as-is.
    pub(crate) fn mkref(&self) -> AnyTerm<'a, '_> {
        match self {
            Self::Str(bytes) => AnyTerm::Str(*bytes),
            Self::Dict(raw, entries) => AnyTerm::DictRef(*raw, entries),
            Self::DictRef(raw, entries) => AnyTerm::DictRef(*raw, *entries),
            Self::List(raw, items) => AnyTerm::ListRef(*raw, items.as_slice()),
            Self::ListRef(raw, items) => AnyTerm::ListRef(*raw, *items),
            Self::Seq(raw, items) => AnyTerm::SeqRef(*raw, items.as_slice()),
            Self::SeqRef(raw, items) => AnyTerm::SeqRef(*raw, *items),
        }
    }
}
impl<'a, 'b> NonSeqTerm<'a, 'b> {
    /// Returns the raw bytes of the term: the string itself for `Str`, and
    /// the first field of every other variant.
    fn raw(&self) -> &'a [u8] {
        match *self {
            NonSeqTerm::Str(bytes)
            | NonSeqTerm::Dict(bytes, _)
            | NonSeqTerm::DictRef(bytes, _)
            | NonSeqTerm::List(bytes, _)
            | NonSeqTerm::ListRef(bytes, _) => bytes,
        }
    }

    /// Returns a borrowed view of this term: owning variants are mapped to
    /// their `*Ref` counterpart pointing at this term's collection, while
    /// `Str` and `*Ref` variants are copied as-is.
    fn mkref(&self) -> NonSeqTerm<'a, '_> {
        match self {
            Self::Str(bytes) => NonSeqTerm::Str(*bytes),
            Self::Dict(raw, entries) => NonSeqTerm::DictRef(*raw, entries),
            Self::DictRef(raw, entries) => NonSeqTerm::DictRef(*raw, *entries),
            Self::List(raw, items) => NonSeqTerm::ListRef(*raw, items.as_slice()),
            Self::ListRef(raw, items) => NonSeqTerm::ListRef(*raw, *items),
        }
    }
}
// ---- DISPLAY REPR = Raw nettext representation ----
/// Displays the term as its raw nettext representation.
/// Formatting fails if the raw bytes are not valid UTF-8.
impl<'a, 'b> std::fmt::Display for AnyTerm<'a, 'b> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match std::str::from_utf8(self.raw()) {
            Ok(text) => f.write_str(text),
            Err(_) => Err(std::fmt::Error),
        }
    }
}
impl<'a, 'b> std::fmt::Display for Term<'a, 'b> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
write!(f, "{}", self.0)
}
}
// ---- DEBUG REPR ----
/// Debug representation: the kind of term, its raw bytes between backquotes
/// and, for collections, one line per entry/item with its own debug output.
impl<'a, 'b> std::fmt::Debug for AnyTerm<'a, 'b> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `mkref()` folds each owning variant into its `*Ref` counterpart, so
        // only the borrowed variants need to be handled below.
        match self.mkref() {
            AnyTerm::Str(bytes) => write!(f, "Str(`{}`)", debug(bytes)),
            AnyTerm::DictRef(raw, entries) => {
                write!(f, "Dict<`{}`", debug(raw))?;
                for (key, value) in entries.iter() {
                    write!(f, "\n `{}`={:?}", debug(key), value)?;
                }
                f.write_str(">")
            }
            AnyTerm::ListRef(raw, items) => {
                write!(f, "List[`{}`", debug(raw))?;
                for item in items.iter() {
                    write!(f, "\n {:?}", item)?;
                }
                f.write_str("]")
            }
            AnyTerm::SeqRef(raw, items) => {
                write!(f, "Seq[`{}`", debug(raw))?;
                for item in items.iter() {
                    write!(f, "\n {:?}", item)?;
                }
                f.write_str("]")
            }
            AnyTerm::Dict(..) | AnyTerm::List(..) | AnyTerm::Seq(..) => unreachable!(),
        }
    }
}
/// Debug representation of a non-sequence term: identical to the debug
/// representation of the equivalent `AnyTerm`.
impl<'a, 'b> std::fmt::Debug for NonSeqTerm<'a, 'b> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let as_any = AnyTerm::from(self.mkref());
        // `AnyTerm` implements both `Display` and `Debug`: name the trait
        // explicitly so that this never depends on which one is in scope.
        std::fmt::Debug::fmt(&as_any, f)
    }
}