garage/src/util/migrate.rs

160 lines
3.8 KiB
Rust
Raw Normal View History

2023-01-03 13:44:47 +00:00
use serde::{Deserialize, Serialize};
2023-01-03 14:53:13 +00:00
/// Indicates that this type has an encoding that can be migrated from
/// a previous version upon upgrades of Garage.
2023-01-03 13:44:47 +00:00
pub trait Migrate: Serialize + for<'de> Deserialize<'de> + 'static {
/// A sequence of bytes to add at the beginning of the serialized
/// string, to identify that the data is of this version.
const VERSION_MARKER: &'static [u8] = b"";
2023-01-03 13:44:47 +00:00
/// The previous version of this data type, from which items of this version
2023-01-03 14:53:13 +00:00
/// can be migrated.
2023-01-03 13:44:47 +00:00
type Previous: Migrate;
2023-01-03 14:53:13 +00:00
/// The migration function that transforms a value decoded in the old format
/// to an up-to-date value.
2023-01-03 13:44:47 +00:00
fn migrate(previous: Self::Previous) -> Self;
2023-01-03 14:53:13 +00:00
/// Decode an encoded version of this type, going through a migration if necessary.
2023-01-03 13:44:47 +00:00
fn decode(bytes: &[u8]) -> Option<Self> {
let marker_len = Self::VERSION_MARKER.len();
if bytes.len() >= marker_len && &bytes[..marker_len] == Self::VERSION_MARKER {
if let Ok(value) = rmp_serde::decode::from_read_ref::<_, Self>(&bytes[marker_len..]) {
2023-01-03 13:44:47 +00:00
return Some(value);
}
}
Self::Previous::decode(bytes).map(Self::migrate)
}
2023-01-03 14:53:13 +00:00
/// Encode this type with optionnal version marker
2023-01-03 13:44:47 +00:00
fn encode(&self) -> Result<Vec<u8>, rmp_serde::encode::Error> {
let mut wr = Vec::with_capacity(128);
wr.extend_from_slice(Self::VERSION_MARKER);
2023-01-03 13:44:47 +00:00
let mut se = rmp_serde::Serializer::new(&mut wr)
.with_struct_map()
.with_string_variants();
self.serialize(&mut se)?;
Ok(wr)
}
}
2023-01-03 14:53:13 +00:00
/// Indicates that this type has no previous encoding version to be migrated from.
2023-01-03 13:44:47 +00:00
pub trait InitialFormat: Serialize + for<'de> Deserialize<'de> + 'static {
/// A sequence of bytes to add at the beginning of the serialized
/// string, to identify that the data is of this version.
const VERSION_MARKER: &'static [u8] = b"";
2023-01-03 13:44:47 +00:00
}
impl<T: InitialFormat> Migrate for T {
const VERSION_MARKER: &'static [u8] = <T as InitialFormat>::VERSION_MARKER;
2023-01-03 13:44:47 +00:00
type Previous = NoPrevious;
fn migrate(_previous: Self::Previous) -> Self {
unreachable!();
}
}
2023-01-03 14:53:13 +00:00
/// Internal type used by InitialFormat, not meant for general use.
2023-01-03 13:44:47 +00:00
#[derive(Serialize, Deserialize)]
pub struct NoPrevious;
impl Migrate for NoPrevious {
type Previous = NoPrevious;
fn migrate(_previous: Self::Previous) -> Self {
unreachable!();
}
fn decode(_bytes: &[u8]) -> Option<Self> {
None
}
fn encode(&self) -> Result<Vec<u8>, rmp_serde::encode::Error> {
unreachable!()
}
}
2023-01-03 14:53:13 +00:00
#[cfg(test)]
mod test {
    use super::*;

    /// Initial test format (empty version marker).
    #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)]
    struct V1 {
        a: usize,
        b: String,
    }

    impl InitialFormat for V1 {}

    /// Second test format, migrated from `V1`.
    #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)]
    struct V2 {
        a: usize,
        b: Vec<String>,
        c: String,
    }

    impl Migrate for V2 {
        const VERSION_MARKER: &'static [u8] = b"GtestV2";

        type Previous = V1;

        fn migrate(prev: V1) -> V2 {
            V2 {
                a: prev.a,
                b: vec![prev.b],
                c: String::new(),
            }
        }
    }

    #[test]
    fn test_v1() {
        let x = V1 {
            a: 12,
            b: "hello".into(),
        };
        let x_enc = x.encode().unwrap();
        let y = V1::decode(&x_enc).unwrap();
        assert_eq!(x, y);
    }

    #[test]
    fn test_v2() {
        let x = V2 {
            a: 12,
            b: vec!["hello".into(), "world".into()],
            c: "plop".into(),
        };
        let x_enc = x.encode().unwrap();
        assert!(x_enc.starts_with(V2::VERSION_MARKER));
        let y = V2::decode(&x_enc).unwrap();
        assert_eq!(x, y);
    }

    #[test]
    fn test_migrate() {
        let x = V1 {
            a: 12,
            b: "hello".into(),
        };
        let x_enc = x.encode().unwrap();
        let xx = V1::decode(&x_enc).unwrap();
        assert_eq!(x, xx);

        // V1-encoded bytes are read by V2 through migration.
        let y = V2::decode(&x_enc).unwrap();
        assert_eq!(
            y,
            V2 {
                a: 12,
                b: vec!["hello".into()],
                c: "".into(),
            }
        );

        // Re-encoding a migrated value produces the current (V2) format.
        let y_enc = y.encode().unwrap();
        assert!(y_enc.starts_with(V2::VERSION_MARKER));
        let z = V2::decode(&y_enc).unwrap();
        assert_eq!(y, z);
    }

    #[test]
    fn test_decode_failure() {
        // Bytes valid in no version of the chain are rejected at every level.
        assert_eq!(V1::decode(b""), None);
        assert_eq!(V2::decode(b""), None);
        assert_eq!(V2::decode(b"garbage"), None);

        // An older version cannot read data written in a newer format.
        let v2_enc = V2 {
            a: 1,
            b: vec!["x".into()],
            c: "y".into(),
        }
        .encode()
        .unwrap();
        assert_eq!(V1::decode(&v2_enc), None);
    }
}