nettext/src/buf/mod.rs
2023-05-10 18:30:16 +02:00

403 lines
12 KiB
Rust

pub mod decode;
use std::borrow::Cow;
use crate::is_string_char;
pub use decode::*;
/// Maximum length, in bytes, of a string that `Buf::push_str` stores inline in
/// the term itself (`TTerm::StrInline`) instead of in the buffer's byte storage.
/// Together with the length byte and the enum tag this appears to keep `TTerm`
/// at the 20 bytes asserted by the `test_sizeof` test.
pub const STR_INLINE_MAX: usize = 18;
/// Integer type used for every index/offset into a `Buf`'s internal tables:
/// term handles (`ITerm`), byte ranges (`IRaw`), and seq/dict ranges
/// (`ISeq`, `IDict`).
pub type Pos = u32;
/// Handle to a term stored in a [`Buf`]: an index into the buffer's term table.
/// A handle is only meaningful for the buffer that produced it; lookups with an
/// out-of-range handle fail with `InvalidIndex`.
#[derive(Clone, Copy, Debug)]
pub struct ITerm(Pos);
/// Half-open byte range `start..end` into a `Buf`'s byte storage.
#[derive(Clone, Copy, Debug)]
pub(crate) struct IRaw {
/// Offset of the first byte.
start: Pos,
/// Offset one past the last byte.
end: Pos,
}
/// Half-open range `seq_start..seq_end` into a `Buf`'s flat item table
/// (`Buf::seqs`). It describes the items of one sequence or list.
#[derive(Clone, Copy, Debug)]
pub(crate) struct ISeq {
/// Index of the first item.
seq_start: Pos,
/// Index one past the last item.
seq_end: Pos,
}
/// Half-open range `dict_start..dict_end` into a `Buf`'s flat `(key, value)`
/// entry table (`Buf::dicts`). It describes the entries of one dict.
#[derive(Clone, Copy, Debug)]
pub(crate) struct IDict {
/// Index of the first entry.
dict_start: Pos,
/// Index one past the last entry.
dict_end: Pos,
}
/// Internal representation of one term stored in a [`Buf`].
///
/// `Raw*` variants additionally carry the byte range of their raw
/// representation, which `Buf::raw` returns. `Seq`, `List` and `Dict` are
/// built by `Buf::push_seq` / `push_list` / `push_dict` and have no raw form.
/// The `test_sizeof` test pins this enum at 20 bytes.
#[derive(Debug)]
pub(crate) enum TTerm {
/// String whose bytes live in the buffer's byte storage.
Str(IRaw),
/// Short string stored inline as `(length, bytes)`; only the first `length`
/// bytes of the array are meaningful.
StrInline(u8, [u8; STR_INLINE_MAX]),
/// Sequence with its raw byte range and its items.
RawSeq(IRaw, ISeq),
/// List with its raw byte range and its items.
RawList(IRaw, ISeq),
/// Dict with its raw byte range and its entries.
RawDict(IRaw, IDict),
/// Sequence built without a raw representation.
Seq(ISeq),
/// List built without a raw representation.
List(ISeq),
/// Dict built without a raw representation.
Dict(IDict),
}
/// Arena holding a tree of terms in a few flat tables; terms are addressed
/// through [`ITerm`] handles.
#[derive(Debug)]
pub struct Buf<'a> {
/// Byte storage for out-of-line strings, dict keys and raw input. It is a
/// `Cow`, presumably so decoded input can be borrowed without copying (TODO
/// confirm against the decoder). Every write goes through `to_mut()`, which
/// copies borrowed data into an owned vector on first write.
bytes: Cow<'a, [u8]>,
/// Items of every sequence and list, stored back to back; each `ISeq` is a
/// range into this table.
seqs: Vec<ITerm>,
/// `(key bytes, value)` entries of every dict, stored back to back; each
/// `IDict` is a range into this table. Entries of one dict are expected to be
/// sorted by key bytes, because lookups use binary search.
dicts: Vec<(IRaw, ITerm)>,
/// Term table; an `ITerm` is an index into it.
terms: Vec<TTerm>,
}
/// Error returned by the `Buf` reading functions.
#[derive(Debug)]
pub enum TermError {
/// The handle does not refer to a term of this buffer.
InvalidIndex(ITerm),
/// The term has the wrong kind: `(expected, found)`.
WrongType(&'static str, &'static str),
/// The seq/list has the wrong number of items: `(expected, found)`.
WrongLength(usize, usize),
/// The dict's keys do not match what `Buf::dict_of` asked for.
WrongKeys,
/// The term has no raw byte representation (see `Buf::raw`).
NoRawRepresentation,
}
/// Error returned by the `Buf` writing (`push_*`) functions.
#[derive(Debug)]
pub enum ValueError {
/// A handle passed in does not refer to a term of this buffer.
InvalidIndex(ITerm),
/// `Buf::push_dict` was given the same key twice.
DuplicateKey,
/// A string or dict key contains a byte rejected by `is_string_char`.
BadString,
/// `Buf::push_seq` was given an element that is itself a sequence term.
SeqInSeq,
}
impl<'a> Buf<'a> {
    /// Creates an empty buffer, pre-allocating room for 16 sequence items,
    /// 16 dictionary entries and 16 terms.
    pub fn new() -> Self {
        Self {
            bytes: Default::default(),
            seqs: Vec::with_capacity(16),
            dicts: Vec::with_capacity(16),
            terms: Vec::with_capacity(16),
        }
    }

    // ================ READING FUNCTIONS ==================

    /// Returns the raw bytes of `term`.
    ///
    /// Strings always have a raw form. Sequences, lists and dicts only have one
    /// when stored as `RawSeq`/`RawList`/`RawDict` (presumably created by the
    /// decoder behind [`Buf::push_raw`]). Those built with `push_seq`,
    /// `push_list` or `push_dict` have none.
    ///
    /// # Errors
    /// - [`TermError::InvalidIndex`] if `term` is not a term of this buffer.
    /// - [`TermError::NoRawRepresentation`] for a seq/list/dict without raw bytes.
    pub fn raw(&self, term: ITerm) -> Result<&[u8], TermError> {
        match self.get_term(term)? {
            TTerm::StrInline(len, bytes) => Ok(&bytes[..*len as usize]),
            TTerm::Str(r) | TTerm::RawSeq(r, _) | TTerm::RawList(r, _) | TTerm::RawDict(r, _) => {
                Ok(self.get_bytes(*r))
            }
            TTerm::Seq(_) | TTerm::List(_) | TTerm::Dict(_) => {
                Err(TermError::NoRawRepresentation)
            }
        }
    }

    /// Returns `term` as a string slice.
    ///
    /// # Errors
    /// - [`TermError::InvalidIndex`] if `term` is not a term of this buffer.
    /// - [`TermError::WrongType`] if `term` is not a string.
    pub fn str(&self, term: ITerm) -> Result<&str, TermError> {
        let bytes = match self.get_term(term)? {
            TTerm::StrInline(len, bytes) => &bytes[..*len as usize],
            TTerm::Str(r) => self.get_bytes(*r),
            t => return Err(TermError::WrongType("string", t.typename())),
        };
        debug_assert!(
            std::str::from_utf8(bytes).is_ok(),
            "string term holds non-UTF-8 bytes"
        );
        // SAFETY: `push_str` only stores bytes accepted by `is_string_char`, and
        // decoded strings are assumed to pass the same check in the decoder
        // (TODO confirm). This is sound as long as `is_string_char` accepts only
        // bytes that are valid UTF-8 on their own (presumably ASCII). The
        // debug assertion above catches violations in debug builds.
        Ok(unsafe { std::str::from_utf8_unchecked(bytes) })
    }

    /// Returns the items of `term` viewed as a sequence.
    ///
    /// A term that is not a sequence is returned as a one-element sequence
    /// containing just itself.
    ///
    /// # Errors
    /// [`TermError::InvalidIndex`] if `term` is not a term of this buffer.
    pub fn seq<'x>(&'x self, term: &'x ITerm) -> Result<&'x [ITerm], TermError> {
        match self.get_term(*term)? {
            TTerm::RawSeq(_, s) | TTerm::Seq(s) => Ok(self.iseq_items(s)),
            // Every non-sequence term is a sequence of length one.
            _ => Ok(std::slice::from_ref(term)),
        }
    }

    /// Returns the items of the sequence `term` as an array of exactly `N` items.
    ///
    /// # Errors
    /// - [`TermError::InvalidIndex`] if `term` is not a term of this buffer.
    /// - [`TermError::WrongType`] if `term` is not a sequence.
    /// - [`TermError::WrongLength`] if the sequence does not have `N` items.
    pub fn seq_of<const N: usize>(&self, term: ITerm) -> Result<[ITerm; N], TermError> {
        match self.get_term(term)? {
            TTerm::RawSeq(_, s) | TTerm::Seq(s) => self.iseq_array(s),
            t => Err(TermError::WrongType("seq", t.typename())),
        }
    }

    /// Returns the items of the list `term`.
    ///
    /// # Errors
    /// - [`TermError::InvalidIndex`] if `term` is not a term of this buffer.
    /// - [`TermError::WrongType`] if `term` is not a list.
    pub fn list(&self, term: ITerm) -> Result<&[ITerm], TermError> {
        match self.get_term(term)? {
            TTerm::RawList(_, s) | TTerm::List(s) => Ok(self.iseq_items(s)),
            t => Err(TermError::WrongType("list", t.typename())),
        }
    }

    /// Returns the items of the list `term` as an array of exactly `N` items.
    ///
    /// # Errors
    /// - [`TermError::InvalidIndex`] if `term` is not a term of this buffer.
    /// - [`TermError::WrongType`] if `term` is not a list.
    /// - [`TermError::WrongLength`] if the list does not have `N` items.
    pub fn list_of<const N: usize>(&self, term: ITerm) -> Result<[ITerm; N], TermError> {
        match self.get_term(term)? {
            TTerm::RawList(_, s) | TTerm::List(s) => self.iseq_array(s),
            t => Err(TermError::WrongType("list", t.typename())),
        }
    }

    /// Looks up `key` in the dict `term`, returning `Ok(None)` if it is absent.
    ///
    /// # Errors
    /// - [`TermError::InvalidIndex`] if `term` is not a term of this buffer.
    /// - [`TermError::WrongType`] if `term` is not a dict.
    pub fn dict_get(&self, term: ITerm, key: &str) -> Result<Option<ITerm>, TermError> {
        match self.get_term(term)? {
            TTerm::RawDict(_, d) | TTerm::Dict(d) => {
                let entries = self.idict_entries(d);
                let found = entries
                    .binary_search_by(|(k, _)| self.get_bytes(*k).cmp(key.as_bytes()))
                    .ok()
                    .map(|pos| entries[pos].1);
                Ok(found)
            }
            t => Err(TermError::WrongType("dict", t.typename())),
        }
    }

    /// Extracts the values of `keys` from the dict `term`, in the order given.
    ///
    /// Every key must be present. If `allow_other` is false, the dict must
    /// contain exactly those keys and no others.
    ///
    /// # Errors
    /// - [`TermError::InvalidIndex`] if `term` is not a term of this buffer.
    /// - [`TermError::WrongType`] if `term` is not a dict.
    /// - [`TermError::WrongKeys`] if a key is missing, or (when `!allow_other`)
    ///   the dict has keys that are not in `keys`.
    pub fn dict_of<const N: usize>(
        &self,
        term: ITerm,
        keys: [&str; N],
        allow_other: bool,
    ) -> Result<[ITerm; N], TermError> {
        match self.get_term(term)? {
            TTerm::RawDict(_, d) | TTerm::Dict(d) => {
                let entries = self.idict_entries(d);
                if entries.len() < N || (entries.len() > N && !allow_other) {
                    return Err(TermError::WrongKeys);
                }
                let mut ret = [ITerm(0); N];
                // Only consulted when `!allow_other`, where `entries.len() == N`.
                // Each entry must then be matched exactly once; otherwise a
                // duplicate name in `keys` would hide an extra dict key.
                let mut matched = [false; N];
                for (slot, key) in ret.iter_mut().zip(keys) {
                    let pos = entries
                        .binary_search_by(|(k, _)| self.get_bytes(*k).cmp(key.as_bytes()))
                        .map_err(|_| TermError::WrongKeys)?;
                    if !allow_other {
                        if matched[pos] {
                            return Err(TermError::WrongKeys);
                        }
                        matched[pos] = true;
                    }
                    *slot = entries[pos].1;
                }
                Ok(ret)
            }
            t => Err(TermError::WrongType("dict", t.typename())),
        }
    }

    /// Iterates over the `(key, value)` entries of the dict `term`, in the
    /// order they are stored (sorted by key bytes).
    ///
    /// # Errors
    /// - [`TermError::InvalidIndex`] if `term` is not a term of this buffer.
    /// - [`TermError::WrongType`] if `term` is not a dict.
    pub fn dict_iter(
        &self,
        term: ITerm,
    ) -> Result<impl Iterator<Item = (&str, ITerm)> + '_, TermError> {
        match self.get_term(term)? {
            TTerm::RawDict(_, d) | TTerm::Dict(d) => {
                let iter = self.idict_entries(d).iter().map(move |(k, v)| {
                    let key = self.get_bytes(*k);
                    // SAFETY: `push_dict` only stores keys accepted by
                    // `is_string_char`, and decoded keys are assumed to pass the
                    // same check in the decoder (TODO confirm). This relies on
                    // `is_string_char` accepting only bytes that are valid UTF-8
                    // on their own (presumably ASCII).
                    (unsafe { std::str::from_utf8_unchecked(key) }, *v)
                });
                Ok(iter)
            }
            t => Err(TermError::WrongType("dict", t.typename())),
        }
    }

    // ================= WRITING FUNCTIONS ================

    /// Adds the string `s` as a new term.
    ///
    /// Strings of at most `STR_INLINE_MAX` bytes are stored inline in the term.
    /// Longer ones are copied into the buffer's byte storage.
    ///
    /// # Errors
    /// [`ValueError::BadString`] if `s` contains a byte rejected by
    /// `is_string_char`; the buffer is left unchanged.
    pub fn push_str(&mut self, s: &str) -> Result<ITerm, ValueError> {
        let b = s.as_bytes();
        if !b.iter().copied().all(is_string_char) {
            return Err(ValueError::BadString);
        }
        let term = if b.len() <= STR_INLINE_MAX {
            let mut bytes = [0u8; STR_INLINE_MAX];
            bytes[..b.len()].copy_from_slice(b);
            // `b.len() <= STR_INLINE_MAX` (18), so the cast to u8 is lossless.
            TTerm::StrInline(b.len() as u8, bytes)
        } else {
            TTerm::Str(self.push_bytes(b))
        };
        Ok(self.push_term(term))
    }

    /// Adds a sequence made of the terms yielded by `iterator`.
    ///
    /// # Errors
    /// - [`ValueError::InvalidIndex`] if an element is not a term of this buffer.
    /// - [`ValueError::SeqInSeq`] if an element is itself a sequence.
    ///
    /// On error the buffer is left unchanged.
    pub fn push_seq(&mut self, iterator: impl Iterator<Item = ITerm>) -> Result<ITerm, ValueError> {
        let seq = self.fill_seq_items(iterator, true)?;
        Ok(self.push_term(TTerm::Seq(seq)))
    }

    /// Adds a list made of the terms yielded by `iterator`.
    ///
    /// # Errors
    /// [`ValueError::InvalidIndex`] if an element is not a term of this buffer;
    /// the buffer is then left unchanged.
    pub fn push_list(
        &mut self,
        iterator: impl Iterator<Item = ITerm>,
    ) -> Result<ITerm, ValueError> {
        let list = self.fill_seq_items(iterator, false)?;
        Ok(self.push_term(TTerm::List(list)))
    }

    /// Adds a dict made of the `(key, value)` pairs yielded by `iterator`.
    ///
    /// Entries are stored sorted by key bytes so that lookups can use binary
    /// search.
    ///
    /// # Errors
    /// - [`ValueError::BadString`] if a key contains a byte rejected by
    ///   `is_string_char`.
    /// - [`ValueError::InvalidIndex`] if a value is not a term of this buffer.
    /// - [`ValueError::DuplicateKey`] if the same key appears twice.
    ///
    /// On any error, every key byte and entry pushed by this call is removed,
    /// so the buffer is left unchanged.
    pub fn push_dict<'k>(
        &mut self,
        iterator: impl Iterator<Item = (&'k str, ITerm)>,
    ) -> Result<ITerm, ValueError> {
        let bytes_start = self.bytes.len();
        let dict_start = self.dicts.len();
        match self.fill_dict_entries(dict_start, iterator) {
            Ok(dict) => Ok(self.push_term(TTerm::Dict(dict))),
            Err(e) => {
                // Only touch `bytes` if something was appended, so a borrowed
                // `Cow` is not needlessly copied by `to_mut()`.
                if self.bytes.len() > bytes_start {
                    self.bytes.to_mut().truncate(bytes_start);
                }
                self.dicts.truncate(dict_start);
                Err(e)
            }
        }
    }

    /// Copies `raw` into the buffer and decodes it into terms.
    ///
    /// # Errors
    /// Returns the decoder's [`ParseError`]. In that case the byte storage and
    /// every table are truncated back to their lengths before the call.
    pub fn push_raw(&mut self, raw: &[u8]) -> Result<ITerm, ParseError> {
        let bytes_len = self.bytes.len();
        let seqs_len = self.seqs.len();
        let dicts_len = self.dicts.len();
        let terms_len = self.terms.len();
        let raw = self.push_bytes(raw);
        let result = self.decode(raw);
        if result.is_err() {
            // Reset to the initial state.
            self.bytes.to_mut().truncate(bytes_len);
            self.seqs.truncate(seqs_len);
            self.dicts.truncate(dicts_len);
            self.terms.truncate(terms_len);
        }
        result
    }

    // ==== Internal ====

    /// Looks up the term behind a handle.
    #[inline]
    fn get_term(&self, term: ITerm) -> Result<&TTerm, TermError> {
        self.terms
            .get(term.0 as usize)
            .ok_or(TermError::InvalidIndex(term))
    }

    /// Appends `term` to the term table and returns its handle.
    #[inline]
    fn push_term(&mut self, term: TTerm) -> ITerm {
        let ret = ITerm(self.terms.len() as Pos);
        self.terms.push(term);
        ret
    }

    /// Appends `raw` to the byte storage and returns the range it occupies.
    #[inline]
    fn push_bytes(&mut self, raw: &[u8]) -> IRaw {
        let bytes_start = self.bytes.len();
        self.bytes.to_mut().extend_from_slice(raw);
        IRaw {
            start: bytes_start as Pos,
            end: self.bytes.len() as Pos,
        }
    }

    /// Items covered by a seq/list range.
    #[inline]
    fn iseq_items(&self, s: &ISeq) -> &[ITerm] {
        &self.seqs[s.seq_start as usize..s.seq_end as usize]
    }

    /// Items covered by a seq/list range, as an array of exactly `N` items.
    fn iseq_array<const N: usize>(&self, s: &ISeq) -> Result<[ITerm; N], TermError> {
        let items = self.iseq_items(s);
        <[ITerm; N]>::try_from(items).map_err(|_| TermError::WrongLength(N, items.len()))
    }

    /// Entries covered by a dict range.
    #[inline]
    fn idict_entries(&self, d: &IDict) -> &[(IRaw, ITerm)] {
        &self.dicts[d.dict_start as usize..d.dict_end as usize]
    }

    /// Appends the items yielded by `iterator` to `seqs` and returns their range.
    ///
    /// When `reject_seqs` is true, items that are themselves sequences are
    /// refused with `SeqInSeq`; both decoded (`RawSeq`) and built (`Seq`)
    /// sequences count. On error, everything this call appended is removed.
    fn fill_seq_items(
        &mut self,
        iterator: impl Iterator<Item = ITerm>,
        reject_seqs: bool,
    ) -> Result<ISeq, ValueError> {
        let seq_start = self.seqs.len();
        for term in iterator {
            let error = match self.terms.get(term.0 as usize) {
                None => Some(ValueError::InvalidIndex(term)),
                Some(TTerm::RawSeq(..) | TTerm::Seq(..)) if reject_seqs => {
                    Some(ValueError::SeqInSeq)
                }
                Some(_) => None,
            };
            if let Some(e) = error {
                self.seqs.truncate(seq_start);
                return Err(e);
            }
            self.seqs.push(term);
        }
        Ok(ISeq {
            seq_start: seq_start as Pos,
            seq_end: self.seqs.len() as Pos,
        })
    }

    /// Appends the key bytes and entries of a new dict (whose entries start at
    /// `dict_start`), sorts them by key and rejects duplicate keys.
    ///
    /// This function does not roll anything back; on error the caller must
    /// truncate `bytes` and `dicts`.
    fn fill_dict_entries<'k>(
        &mut self,
        dict_start: usize,
        iterator: impl Iterator<Item = (&'k str, ITerm)>,
    ) -> Result<IDict, ValueError> {
        for (key, term) in iterator {
            if !key.as_bytes().iter().copied().all(is_string_char) {
                return Err(ValueError::BadString);
            }
            if self.terms.get(term.0 as usize).is_none() {
                return Err(ValueError::InvalidIndex(term));
            }
            let key = self.push_bytes(key.as_bytes());
            self.dicts.push((key, term));
        }
        {
            // Borrow the byte storage separately from `dicts` so the sort key
            // can read key bytes while the entries are mutably borrowed.
            let bytes: &[u8] = &self.bytes;
            self.dicts[dict_start..]
                .sort_unstable_by_key(move |(k, _)| &bytes[k.start as usize..k.end as usize]);
        }
        // After sorting, equal keys are adjacent.
        let entries = &self.dicts[dict_start..];
        if entries
            .windows(2)
            .any(|w| self.get_bytes(w[0].0) == self.get_bytes(w[1].0))
        {
            return Err(ValueError::DuplicateKey);
        }
        Ok(IDict {
            dict_start: dict_start as Pos,
            dict_end: self.dicts.len() as Pos,
        })
    }
}
impl TTerm {
fn typename(&self) -> &'static str {
match self {
TTerm::Str(_) | TTerm::StrInline(_, _) => "string",
TTerm::RawSeq(_, _) | TTerm::Seq(_) => "seq",
TTerm::RawList(_, _) | TTerm::List(_) => "list",
TTerm::RawDict(_, _) | TTerm::Dict(_) => "dict",
}
}
}
#[cfg(test)]
mod tests {
    pub use super::*;

    /// A `Buf` stores one `TTerm` per term, so the enum's size is pinned here:
    /// any growth of a variant (or of `STR_INLINE_MAX`) shows up as a failure.
    #[test]
    fn test_sizeof() {
        const EXPECTED_TTERM_SIZE: usize = 20;
        let actual = std::mem::size_of::<TTerm>();
        assert_eq!(actual, EXPECTED_TTERM_SIZE);
    }
}