diff --git a/src/model/block.rs b/src/model/block.rs index 090ffbc22..699ff32d3 100644 --- a/src/model/block.rs +++ b/src/model/block.rs @@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize}; use tokio::fs; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::sync::{watch, Mutex, Notify}; -use zstd::stream::{decode_all as zstd_decode, encode_all as zstd_encode}; +use zstd::stream::{decode_all as zstd_decode, Encoder}; use garage_util::data::*; use garage_util::error::Error; @@ -43,7 +43,10 @@ pub enum Message { GetBlock(Hash), /// Message to send a block of data, either because requested, of for first delivery of new /// block - PutBlock { hash: Hash, data: BlockData }, + PutBlock { + hash: Hash, + data: BlockData, + }, /// Ask other node if they should have this block, but don't actually have it NeedBlockQuery(Hash), /// Response : whether the node do require that block @@ -64,6 +67,13 @@ impl BlockData { BlockData::Compressed(_) => true, } } + + pub fn buffer(&self) -> &Vec { + match self { + BlockData::Plain(b) => b, + BlockData::Compressed(b) => b, + } + } } impl RpcMessage for Message {} @@ -164,6 +174,10 @@ impl BlockManager { /// Write a block to disk pub async fn write_block(&self, hash: &Hash, data: &BlockData) -> Result { + let mut path = self.block_dir(hash); + + let _lock = self.data_dir_lock.lock().await; + let clean_plain = match self.is_block_compressed(hash).await { Ok(true) => return Ok(Message::Ok), Ok(false) if !data.is_compressed() => return Ok(Message::Ok), // we have a plain block, and the provided block is not compressed either @@ -171,29 +185,17 @@ impl BlockManager { Err(_) => false, }; - let mut path = self.block_dir(hash); - - let (buffer, checksum) = match data { - BlockData::Plain(b) => (b, None), - BlockData::Compressed(b) => { - let checksum = blake2sum(&b); - (b, Some(checksum)) - } - }; - - let _lock = self.data_dir_lock.lock().await; - fs::create_dir_all(&path).await?; path.push(hex::encode(hash)); - if checksum.is_some() { - path.set_extension("zst_b2"); + + if data.is_compressed() { + path.set_extension("zst"); } + let buffer = data.buffer(); + let mut f = fs::File::create(path.clone()).await?; f.write_all(&buffer).await?; - if let Some(checksum) = checksum { - f.write_all(checksum.as_slice()).await?; - } if clean_plain { path.set_extension(""); @@ -215,7 +217,7 @@ impl BlockManager { f.map(|f| (f, false)).map_err(Into::into) } Ok(true) => { - path.set_extension("zst_b2"); + path.set_extension("zst"); let f = fs::File::open(&path).await; f.map(|f| (f, true)).map_err(Into::into) } @@ -233,14 +235,7 @@ impl BlockManager { drop(f); let sum_ok = if compressed { - if data.len() >= 32 { - let data_len = data.len() - 32; - let checksum = data.split_off(data_len); - blake2sum(&data[..]).as_slice() == &checksum - } else { - // the file is too short to be valid - false - } + zstd_check_checksum(&data[..]) } else { blake2sum(&data[..]) == *hash }; @@ -287,7 +282,7 @@ impl BlockManager { async fn is_block_compressed(&self, hash: &Hash) -> Result { let mut path = self.block_path(hash); - path.set_extension("zst_b2"); + path.set_extension("zst"); if fs::metadata(&path).await.is_ok() { return Ok(true); } @@ -634,3 +629,16 @@ fn u64_from_be_bytes>(bytes: T) -> u64 { x8.copy_from_slice(bytes.as_ref()); u64::from_be_bytes(x8) } + +fn zstd_check_checksum(source: R) -> bool { + zstd::stream::copy_decode(source, std::io::sink()).is_ok() +} + +fn zstd_encode(mut source: R, level: i32) -> std::io::Result> { + let mut result = Vec::::new(); + let mut encoder = Encoder::new(&mut result, level)?; + encoder.include_checksum(true)?; + std::io::copy(&mut source, &mut encoder)?; + encoder.finish()?; + Ok(result) +} diff --git a/src/util/config.rs b/src/util/config.rs index ab9cae6c8..29901d462 100644 --- a/src/util/config.rs +++ b/src/util/config.rs @@ -45,7 +45,7 @@ pub struct Config { #[serde(default = "default_replication_factor")] pub data_replication_factor: usize, - /// Zstd compression level used on data blocks + /// Zstd compression level used on data blocks #[serde(default)] pub compression_level: i32,