use std::time::{Duration, Instant}; use anyhow::{anyhow, bail, Result}; use rand::prelude::*; use serde::{Deserialize, Serialize}; use tokio::io::AsyncReadExt; use k2v_client::{BatchReadOp, Filter, K2vClient, K2vValue}; use rusoto_core::HttpClient; use rusoto_credential::{AwsCredentials, StaticProvider}; use rusoto_s3::{GetObjectRequest, ListObjectsV2Request, S3Client, S3}; use rusoto_signature::Region; use crate::cryptoblob::*; use crate::time::now_msec; const SAVE_STATE_EVERY: usize = 64; // Checkpointing interval constants: a checkpoint is not made earlier // than CHECKPOINT_INTERVAL time after the last one, and is not made // if there are less than CHECKPOINT_MIN_OPS new operations since last one. const CHECKPOINT_INTERVAL: Duration = Duration::from_secs(3600); const CHECKPOINT_MIN_OPS: usize = 16; pub trait BayouState: Default + Clone + Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static { type Op: Clone + Serialize + for<'de> Deserialize<'de> + std::fmt::Debug + Send + Sync + 'static; fn apply(&self, op: &Self::Op) -> Self; } pub struct Bayou { bucket: String, path: String, key: Key, k2v: K2vClient, s3: S3Client, checkpoint: (Timestamp, S), history: Vec<(Timestamp, S::Op, Option)>, last_sync: Option, } impl Bayou { pub fn new( creds: AwsCredentials, k2v_region: Region, s3_region: Region, bucket: String, path: String, key: Key, ) -> Result { let k2v_client = K2vClient::new(k2v_region, bucket.clone(), creds.clone(), None)?; let static_creds = StaticProvider::new( creds.aws_access_key_id().to_string(), creds.aws_secret_access_key().to_string(), creds.token().clone(), None, ); let s3_client = S3Client::new_with(HttpClient::new()?, static_creds, s3_region); Ok(Self { bucket, path, key, k2v: k2v_client, s3: s3_client, checkpoint: (Timestamp::zero(), S::default()), history: vec![], last_sync: None, }) } /// Re-reads the state from persistent storage backend pub async fn sync(&mut self) -> Result<()> { // 1. List checkpoints let prefix = format!("{}/checkpoint/", self.path); let mut lor = ListObjectsV2Request::default(); lor.bucket = self.bucket.clone(); lor.max_keys = Some(1000); lor.prefix = Some(prefix.clone()); let checkpoints_res = self.s3.list_objects_v2(lor).await?; let mut checkpoints = vec![]; for object in checkpoints_res.contents.unwrap_or_default() { if let Some(key) = object.key { if let Some(ckid) = key.strip_prefix(&prefix) { if let Some(ts) = Timestamp::parse(ckid) { checkpoints.push((ts, key)); } } } } checkpoints.sort_by_key(|(ts, _)| *ts); eprintln!("(sync) listed checkpoints: {:?}", checkpoints); // 2. Load last checkpoint if different from currently used one let checkpoint = if let Some((ts, key)) = checkpoints.last() { if *ts == self.checkpoint.0 { (*ts, None) } else { eprintln!("(sync) loading checkpoint: {}", key); let mut gor = GetObjectRequest::default(); gor.bucket = self.bucket.clone(); gor.key = key.to_string(); let obj_res = self.s3.get_object(gor).await?; let obj_body = obj_res.body.ok_or(anyhow!("Missing object body"))?; let mut buf = Vec::with_capacity(obj_res.content_length.unwrap_or(128) as usize); obj_body.into_async_read().read_to_end(&mut buf).await?; let ck = open_deserialize::(&buf, &self.key)?; (*ts, Some(ck)) } } else { (Timestamp::zero(), None) }; if self.checkpoint.0 > checkpoint.0 { bail!("Existing checkpoint is more recent than stored one"); } if let Some(ck) = checkpoint.1 { eprintln!( "(sync) updating checkpoint to loaded state at {:?}", checkpoint.0 ); self.checkpoint = (checkpoint.0, ck); }; // remove from history events before checkpoint self.history = std::mem::take(&mut self.history) .into_iter() .skip_while(|(ts, _, _)| *ts < self.checkpoint.0) .collect(); // 3. List all operations starting from checkpoint let ts_ser = self.checkpoint.0.serialize(); eprintln!("(sync) looking up operations starting at {}", ts_ser); let ops_map = self .k2v .read_batch(&[BatchReadOp { partition_key: &self.path, filter: Filter { start: Some(&ts_ser), end: None, prefix: None, limit: None, reverse: false, }, single_item: false, conflicts_only: false, include_tombstones: false, }]) .await? .into_iter() .next() .ok_or(anyhow!("Missing K2V result"))? .items; let mut ops = vec![]; for (tsstr, val) in ops_map { let ts = Timestamp::parse(&tsstr) .ok_or(anyhow!("Invalid operation timestamp: {}", tsstr))?; if val.value.len() != 1 { bail!("Invalid operation, has {} values", val.value.len()); } match &val.value[0] { K2vValue::Value(v) => { let op = open_deserialize::(&v, &self.key)?; eprintln!("(sync) operation {}: {} {:?}", tsstr, base64::encode(v), op); ops.push((ts, op)); } K2vValue::Tombstone => { unreachable!(); } } } ops.sort_by_key(|(ts, _)| *ts); eprintln!("(sync) {} operations", ops.len()); // if no operations, clean up and return now if ops.is_empty() { self.history.clear(); return Ok(()); } // 4. Check that first operation has same timestamp as checkpoint (if not zero) if self.checkpoint.0 != Timestamp::zero() && ops[0].0 != self.checkpoint.0 { bail!( "First operation in listing doesn't have timestamp that corresponds to checkpoint" ); } // 5. Apply all operations in order // Hypothesis: before the loaded checkpoint, operations haven't changed // between what's on storage and what we used to calculate the state in RAM here. let i0 = self .history .iter() .enumerate() .zip(ops.iter()) .skip_while(|((i, (ts1, _, _)), (ts2, _))| ts1 == ts2) .map(|((i, _), _)| i) .next() .unwrap_or(self.history.len()); if ops.len() > i0 { // Remove operations from first position where histories differ self.history.truncate(i0); // Look up last calculated state which we have saved and start from there. let mut last_state = (0, &self.checkpoint.1); for (i, (_, _, state_opt)) in self.history.iter().enumerate().rev() { if let Some(state) = state_opt { last_state = (i + 1, state); break; } } // Calculate state at the end of this common part of the history let mut state = last_state.1.clone(); for (_, op, _) in self.history[last_state.0..].iter() { state = state.apply(op); } // Now, apply all operations retrieved from storage after the common part for (ts, op) in ops.drain(i0..) { state = state.apply(&op); if (self.history.len() + 1) % SAVE_STATE_EVERY == 0 { self.history.push((ts, op, Some(state.clone()))); } else { self.history.push((ts, op, None)); } } // Always save final state as result of last operation self.history.last_mut().unwrap().2 = Some(state); } self.last_sync = Some(Instant::now()); Ok(()) } async fn check_recent_sync(&mut self) -> Result<()> { match self.last_sync { Some(t) if (Instant::now() - t) < CHECKPOINT_INTERVAL / 10 => Ok(()), _ => self.sync().await, } } /// Applies a new operation on the state. Once this function returns, /// the option has been safely persisted to storage backend pub async fn push(&mut self, op: S::Op) -> Result<()> { self.check_recent_sync().await?; let ts = Timestamp::after( self.history .last() .map(|(ts, _, _)| ts) .unwrap_or(&self.checkpoint.0), ); self.k2v .insert_item( &self.path, &ts.serialize(), seal_serialize(&op, &self.key)?, None, ) .await?; let new_state = self.state().apply(&op); self.history.push((ts, op, Some(new_state))); // Clear previously saved state in history if not required let hlen = self.history.len(); if hlen >= 2 && (hlen - 1) % SAVE_STATE_EVERY != 0 { self.history[hlen - 2].2 = None; } self.checkpoint().await?; Ok(()) } /// Save a new checkpoint if previous checkpoint is too old pub async fn checkpoint(&mut self) -> Result<()> { self.check_recent_sync().await?; eprintln!("Mock checkpointing, not implemented"); Ok(()) } pub fn state(&self) -> &S { if let Some(last) = self.history.last() { last.2.as_ref().unwrap() } else { &self.checkpoint.1 } } } #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)] pub struct Timestamp { pub msec: u64, pub rand: u64, } impl Timestamp { pub fn now() -> Self { let mut rng = thread_rng(); Self { msec: now_msec(), rand: rng.gen::(), } } pub fn after(other: &Self) -> Self { let mut rng = thread_rng(); Self { msec: std::cmp::max(now_msec(), other.msec + 1), rand: rng.gen::(), } } pub fn zero() -> Self { Self { msec: 0, rand: 0 } } pub fn serialize(&self) -> String { let mut bytes = [0u8; 16]; bytes[0..8].copy_from_slice(&u64::to_be_bytes(self.msec)); bytes[8..16].copy_from_slice(&u64::to_be_bytes(self.rand)); hex::encode(&bytes) } pub fn parse(v: &str) -> Option { let bytes = hex::decode(v).ok()?; if bytes.len() != 16 { return None; } Some(Self { msec: u64::from_be_bytes(bytes[0..8].try_into().unwrap()), rand: u64::from_be_bytes(bytes[8..16].try_into().unwrap()), }) } }