pub mod decode; use std::borrow::Cow; use crate::is_string_char; pub use decode::*; pub const STR_INLINE_MAX: usize = 18; pub type Pos = u32; #[derive(Clone, Copy, Debug)] pub struct ITerm(Pos); #[derive(Clone, Copy, Debug)] pub(crate) struct IRaw { start: Pos, end: Pos, } #[derive(Clone, Copy, Debug)] pub(crate) struct ISeq { seq_start: Pos, seq_end: Pos, } #[derive(Clone, Copy, Debug)] pub(crate) struct IDict { dict_start: Pos, dict_end: Pos, } #[derive(Debug)] pub(crate) enum TTerm { Str(IRaw), StrInline(u8, [u8; STR_INLINE_MAX]), RawSeq(IRaw, ISeq), RawList(IRaw, ISeq), RawDict(IRaw, IDict), Seq(ISeq), List(ISeq), Dict(IDict), } #[derive(Debug)] pub struct Buf<'a> { bytes: Cow<'a, [u8]>, seqs: Vec, dicts: Vec<(IRaw, ITerm)>, terms: Vec, } #[derive(Debug)] pub enum TermError { InvalidIndex(ITerm), WrongType(&'static str, &'static str), WrongLength(usize, usize), WrongKeys, NoRawRepresentation, } #[derive(Debug)] pub enum ValueError { InvalidIndex(ITerm), DuplicateKey, BadString, SeqInSeq, } impl<'a> Buf<'a> { pub fn new() -> Self { Self { bytes: Default::default(), seqs: Vec::with_capacity(16), dicts: Vec::with_capacity(16), terms: Vec::with_capacity(16), } } // ================ READING FUNCTIONS ================== pub fn raw(&self, term: ITerm) -> Result<&[u8], TermError> { match self.get_term(term)? { TTerm::StrInline(len, bytes) => Ok(&bytes[..*len as usize]), TTerm::Str(r) | TTerm::RawSeq(r, _) | TTerm::RawList(r, _) | TTerm::RawDict(r, _) => { Ok(self.get_bytes(*r)) } _ => Err(TermError::NoRawRepresentation), } } pub fn str(&self, term: ITerm) -> Result<&str, TermError> { match self.get_term(term)? { TTerm::StrInline(len, bytes) => { let bytes = &bytes[..*len as usize]; let s = unsafe { std::str::from_utf8_unchecked(bytes) }; Ok(s) } TTerm::Str(r) => { let bytes = self.get_bytes(*r); let s = unsafe { std::str::from_utf8_unchecked(bytes) }; Ok(s) } t => Err(TermError::WrongType("string", t.typename())), } } pub fn seq<'x>(&'x self, term: &'x ITerm) -> Result<&'x [ITerm], TermError> { match self.get_term(*term)? { TTerm::RawSeq(_, s) | TTerm::Seq(s) => { Ok(&self.seqs[s.seq_start as usize..s.seq_end as usize]) } _ => Ok(std::slice::from_ref(term)), } } pub fn seq_of(&self, term: ITerm) -> Result<[ITerm; N], TermError> { match self.get_term(term)? { TTerm::RawSeq(_, s) | TTerm::Seq(s) => { let seq_len = (s.seq_end - s.seq_start) as usize; if seq_len == N { let seq = &self.seqs[s.seq_start as usize..s.seq_end as usize]; Ok(seq.try_into().unwrap()) } else { Err(TermError::WrongLength(N, seq_len)) } } t => Err(TermError::WrongType("seq", t.typename())), } } pub fn list(&self, term: ITerm) -> Result<&[ITerm], TermError> { match self.get_term(term)? { TTerm::RawList(_, s) | TTerm::List(s) => { Ok(&self.seqs[s.seq_start as usize..s.seq_end as usize]) } t => Err(TermError::WrongType("list", t.typename())), } } pub fn list_of(&self, term: ITerm) -> Result<[ITerm; N], TermError> { match self.get_term(term)? { TTerm::RawList(_, s) | TTerm::List(s) => { let list_len = (s.seq_end - s.seq_start) as usize; if list_len == N { let seq = &self.seqs[s.seq_start as usize..s.seq_end as usize]; Ok(seq.try_into().unwrap()) } else { Err(TermError::WrongLength(N, list_len)) } } t => Err(TermError::WrongType("list", t.typename())), } } pub fn dict_get(&self, term: ITerm, key: &str) -> Result, TermError> { match self.get_term(term)? { TTerm::RawDict(_, d) | TTerm::Dict(d) => { let dict = &self.dicts[d.dict_start as usize..d.dict_end as usize]; let pos_opt = dict .binary_search_by(|(k, _)| self.get_bytes(*k).cmp(key.as_bytes())) .ok(); Ok(pos_opt.map(|pos| dict[pos].1)) } t => Err(TermError::WrongType("dict", t.typename())), } } pub fn dict_of( &self, term: ITerm, keys: [&str; N], allow_other: bool, ) -> Result<[ITerm; N], TermError> { match self.get_term(term)? { TTerm::RawDict(_, d) | TTerm::Dict(d) => { let dict = &self.dicts[d.dict_start as usize..d.dict_end as usize]; if dict.len() < N || (dict.len() > N && !allow_other) { return Err(TermError::WrongKeys); } let mut ret = [ITerm(0); N]; for i in 0..N { let pos = dict .binary_search_by(|(k, _)| self.get_bytes(*k).cmp(keys[i].as_bytes())) .map_err(|_| TermError::WrongKeys)?; ret[i] = dict[pos].1; } Ok(ret) } t => Err(TermError::WrongType("dict", t.typename())), } } pub fn dict_iter( &self, term: ITerm, ) -> Result + '_, TermError> { match self.get_term(term)? { TTerm::RawDict(_, d) | TTerm::Dict(d) => { let dict = &self.dicts[d.dict_start as usize..d.dict_end as usize]; let iter = dict.iter().map(|(k, v)| { ( unsafe { std::str::from_utf8_unchecked(self.get_bytes(*k)) }, *v, ) }); Ok(iter) } t => Err(TermError::WrongType("dict", t.typename())), } } // ================= WRITING FUNCTIONS ================ pub fn push_str(&mut self, s: &str) -> Result { let b = s.as_bytes(); if !b.iter().copied().all(is_string_char) { return Err(ValueError::BadString); } let term = if b.len() <= STR_INLINE_MAX { let mut bytes = [0u8; STR_INLINE_MAX]; bytes[..b.len()].copy_from_slice(b); TTerm::StrInline(b.len() as u8, bytes) } else { TTerm::Str(self.push_bytes(b)) }; Ok(self.push_term(term)) } pub fn push_seq(&mut self, iterator: impl Iterator) -> Result { let seq_start = self.seqs.len(); for term in iterator { match self.terms.get(term.0 as usize) { None => { self.seqs.truncate(seq_start); return Err(ValueError::InvalidIndex(term)); } Some(TTerm::RawSeq(_, _)) => { self.seqs.truncate(seq_start); return Err(ValueError::SeqInSeq); } _ => { self.seqs.push(term); } } } let seq = ISeq { seq_start: seq_start as Pos, seq_end: self.seqs.len() as Pos, }; Ok(self.push_term(TTerm::Seq(seq))) } pub fn push_list( &mut self, iterator: impl Iterator, ) -> Result { let list_start = self.seqs.len(); for term in iterator { match self.terms.get(term.0 as usize) { None => { self.seqs.truncate(list_start); return Err(ValueError::InvalidIndex(term)); } _ => { self.seqs.push(term); } } } let list = ISeq { seq_start: list_start as Pos, seq_end: self.seqs.len() as Pos, }; Ok(self.push_term(TTerm::List(list))) } pub fn push_dict<'k>( &mut self, iterator: impl Iterator, ) -> Result { let bytes_start = self.bytes.len(); let dict_start = self.dicts.len(); for (key, term) in iterator { if !key.as_bytes().iter().copied().all(is_string_char) { return Err(ValueError::BadString); } let key = self.push_bytes(key.as_bytes()); match self.terms.get(term.0 as usize) { None => { self.bytes.to_mut().truncate(bytes_start); self.dicts.truncate(dict_start); return Err(ValueError::InvalidIndex(term)); } _ => { self.dicts.push((key, term)); } } } self.dicts[dict_start..] .sort_by_key(|(k, _)| (&self.bytes[k.start as usize..k.end as usize], k.start)); for ((k1, _), (k2, _)) in self.dicts[dict_start..] .iter() .zip(self.dicts[dict_start + 1..].iter()) { if self.get_bytes(*k1) == self.get_bytes(*k2) { self.bytes.to_mut().truncate(bytes_start); self.dicts.truncate(dict_start); return Err(ValueError::DuplicateKey); } } let dict = IDict { dict_start: dict_start as Pos, dict_end: self.dicts.len() as Pos, }; Ok(self.push_term(TTerm::Dict(dict))) } pub fn push_raw(&mut self, raw: &[u8]) -> Result { let bytes_len = self.bytes.len(); let seqs_len = self.seqs.len(); let dicts_len = self.dicts.len(); let terms_len = self.terms.len(); let raw = self.push_bytes(raw); let result = self.decode(raw); if result.is_err() { // reset to initial state self.bytes.to_mut().truncate(bytes_len); self.seqs.truncate(seqs_len); self.dicts.truncate(dicts_len); self.terms.truncate(terms_len); } result } // ==== Internal ==== #[inline] fn get_term(&self, term: ITerm) -> Result<&TTerm, TermError> { self.terms .get(term.0 as usize) .ok_or(TermError::InvalidIndex(term)) } #[inline] fn push_term(&mut self, term: TTerm) -> ITerm { let ret = ITerm(self.terms.len() as Pos); self.terms.push(term); ret } #[inline] fn push_bytes(&mut self, raw: &[u8]) -> IRaw { let bytes_start = self.bytes.len(); self.bytes.to_mut().extend(raw); IRaw { start: bytes_start as Pos, end: self.bytes.len() as Pos, } } } impl TTerm { fn typename(&self) -> &'static str { match self { TTerm::Str(_) | TTerm::StrInline(_, _) => "string", TTerm::RawSeq(_, _) | TTerm::Seq(_) => "seq", TTerm::RawList(_, _) | TTerm::List(_) => "list", TTerm::RawDict(_, _) | TTerm::Dict(_) => "dict", } } } #[cfg(test)] mod tests { pub use super::*; #[test] fn test_sizeof() { assert_eq!(std::mem::size_of::(), 20); } }