templar_common/versioned_state/
core.rs

1use std::io::{Error, ErrorKind};
2
3use borsh::{BorshDeserialize, BorshSerialize};
4use near_sdk::env;
5
/// Storage key under which the state version number is written and read.
const VERSION_KEY: &[u8] = b"__v";
7
8pub fn write_state_version(version: u32) {
9    env::storage_write(VERSION_KEY, &version.to_le_bytes());
10}
11
12pub fn read_state_version() -> Result<u32, std::io::Error> {
13    let Some(bytes) = env::storage_read(VERSION_KEY) else {
14        return Ok(0);
15    };
16
17    borsh::from_slice(&bytes)
18}
19
/// Newtype wrapper around a state value `T` that implements [`StateVersion`].
///
/// The single field is private; within this file the wrapper is built through
/// `VersionedState::new`, which also records `T::VERSION` in storage. The
/// `near` attribute requests borsh serializers for the wrapper.
#[derive(Debug)]
#[near_sdk::near(serializers = [borsh])]
pub struct VersionedState<T: StateVersion>(T);
23
24impl<T: StateVersion> VersionedState<T> {
25    pub fn new(state: T) -> Self {
26        write_state_version(T::VERSION);
27        Self(state)
28    }
29
30    pub fn version(&self) -> u32 {
31        T::VERSION
32    }
33}
34
35impl<T: StateVersion> std::ops::Deref for VersionedState<T> {
36    type Target = T;
37
38    fn deref(&self) -> &Self::Target {
39        &self.0
40    }
41}
42
43impl<T: StateVersion> std::ops::DerefMut for VersionedState<T> {
44    fn deref_mut(&mut self) -> &mut Self::Target {
45        &mut self.0
46    }
47}
48
49pub trait StateVersion {
50    const VERSION: u32;
51    type NewArgs;
52
53    fn new(args: Self::NewArgs) -> VersionedState<Self>
54    where
55        Self: Sized;
56
57    fn needs_migration() -> Result<bool, std::io::Error> {
58        let stored = read_state_version()?;
59        if stored > Self::VERSION {
60            return Err(Error::new(
61                ErrorKind::InvalidData,
62                format!(
63                    "Stored state version {stored} is newer than supported version {}",
64                    Self::VERSION
65                ),
66            ));
67        }
68
69        Ok(stored < Self::VERSION)
70    }
71}
72
73pub trait StateTransformer {
74    type Input: StateVersion + BorshDeserialize;
75    type Output: StateVersion + BorshSerialize;
76    type Error;
77
78    fn input_version(&self) -> u32 {
79        Self::Input::VERSION
80    }
81
82    fn output_version(&self) -> u32 {
83        Self::Output::VERSION
84    }
85
86    fn run(&self) -> Result<Self::Output, MigrationError<Self::Error>> {
87        let stored = read_state_version()?;
88        let expected = self.input_version();
89        if stored != expected {
90            return Err(MigrationError::StoredVersionMismatch { stored, expected });
91        }
92        let old_state =
93            env::state_read::<Self::Input>().ok_or(MigrationError::FailedToDeserializeOldState)?;
94        let new_state = self
95            .transform(old_state)
96            .map_err(MigrationError::Transformation)?;
97        env::state_write(&new_state);
98        write_state_version(self.output_version());
99        Ok(new_state)
100    }
101
102    fn transform(&self, input: Self::Input) -> Result<Self::Output, Self::Error>;
103}
104
/// Failure modes of `StateTransformer::run`, generic over the error type `E`
/// returned by the transformer's `transform` step.
#[derive(thiserror::Error, Debug)]
pub enum MigrationError<E> {
    /// The stored version number could not be read; wraps the underlying
    /// `std::io::Error` (converted automatically through `#[from]`).
    #[error("Failed to deserialize stored state version: {0}")]
    StoredVersionDeserialization(#[from] std::io::Error),
    /// The version found in storage (`stored`) differs from the version the
    /// transformer expects as its input (`expected`).
    #[error("Stored state version {stored} != args `from_version` {expected}")]
    StoredVersionMismatch { stored: u32, expected: u32 },
    /// Returned by `StateTransformer::run` when `env::state_read` yields
    /// `None` for the old state.
    #[error("Failed to deserialize old state")]
    FailedToDeserializeOldState,
    /// The `transform` step failed; carries its error. The display message
    /// does not interpolate the inner value.
    #[error("Failed to transform old state")]
    Transformation(E),
}
116
/// A migration entry point that is consumed when run.
///
/// `run` takes `self` by value and returns nothing, so any failure handling is
/// up to the implementor. No implementors are visible in this file.
pub trait Migrator {
    /// Executes the migration, consuming the migrator.
    fn run(self);
}
120
#[cfg(test)]
mod tests {
    use near_sdk::{test_utils::VMContextBuilder, testing_env};

    use super::*;

    /// Minimal state type pinned at version 2, used to exercise the default
    /// `needs_migration` implementation.
    struct TestState;

    impl StateVersion for TestState {
        const VERSION: u32 = 2;
        type NewArgs = ();

        fn new(_args: Self::NewArgs) -> VersionedState<Self> {
            VersionedState::new(Self)
        }
    }

    /// Installs a default mocked VM context for the current test.
    fn context() {
        let vm_context = VMContextBuilder::new().build();
        testing_env!(vm_context);
    }

    /// With nothing stored under the version key, the version reads as 0.
    #[test]
    fn stored_version_defaults_to_zero() {
        context();

        let stored = read_state_version().unwrap();

        assert_eq!(stored, 0);
    }

    /// A three-byte value under the version key is not a valid `u32`.
    #[test]
    fn malformed_stored_version_errors() {
        context();
        write_state_version(7);
        let truncated = [1u8, 2, 3];
        env::storage_write(VERSION_KEY, &truncated);

        let outcome = read_state_version();

        assert!(outcome.is_err());
    }

    /// A stored version above `TestState::VERSION` is rejected.
    #[test]
    fn future_stored_version_errors() {
        context();
        write_state_version(9);

        let failure = TestState::needs_migration().unwrap_err();
        let message = failure.to_string();

        assert_eq!(failure.kind(), ErrorKind::InvalidData);
        assert!(message.contains("Stored state version 9 is newer"));
    }
}