// templar_common/versioned_state.rs

1use std::io::{Error, ErrorKind};
2
3use borsh::{BorshDeserialize, BorshSerialize};
4use near_sdk::{env, ext_contract, near};
5
6const VERSION_KEY: &[u8] = b"__v";
7
8/// Writes the state version to storage.
9pub fn write_state_version(version: u32) {
10    env::storage_write(VERSION_KEY, &version.to_le_bytes());
11}
12
13/// Reads the state version from storage.
14pub fn read_state_version() -> Result<u32, std::io::Error> {
15    let Some(bytes) = env::storage_read(VERSION_KEY) else {
16        return Ok(0);
17    };
18
19    borsh::from_slice(&bytes)
20}
21
22#[derive(Debug)]
23#[near(serializers = [borsh])]
24pub struct VersionedState<T: StateVersion>(T);
25
26impl<T: StateVersion> VersionedState<T> {
27    pub fn new(state: T) -> Self {
28        write_state_version(T::VERSION);
29        Self(state)
30    }
31
32    pub fn version(&self) -> u32 {
33        T::VERSION
34    }
35}
36
37impl<T: StateVersion> std::ops::Deref for VersionedState<T> {
38    type Target = T;
39
40    fn deref(&self) -> &Self::Target {
41        &self.0
42    }
43}
44
45impl<T: StateVersion> std::ops::DerefMut for VersionedState<T> {
46    fn deref_mut(&mut self) -> &mut Self::Target {
47        &mut self.0
48    }
49}
50
51pub trait StateVersion {
52    const VERSION: u32;
53    type NewArgs;
54
55    fn new(args: Self::NewArgs) -> VersionedState<Self>
56    where
57        Self: Sized;
58
59    fn needs_migration() -> Result<bool, std::io::Error> {
60        let stored = read_state_version()?;
61        if stored > Self::VERSION {
62            return Err(Error::new(
63                ErrorKind::InvalidData,
64                format!(
65                    "Stored state version {stored} is newer than supported version {}",
66                    Self::VERSION
67                ),
68            ));
69        }
70
71        Ok(stored < Self::VERSION)
72    }
73}
74
75pub trait StateTransformer {
76    type Input: StateVersion + BorshDeserialize;
77    type Output: StateVersion + BorshSerialize;
78    type Error;
79
80    fn input_version(&self) -> u32 {
81        Self::Input::VERSION
82    }
83
84    fn output_version(&self) -> u32 {
85        Self::Output::VERSION
86    }
87
88    fn run(&self) -> Result<Self::Output, MigrationError<Self::Error>> {
89        let stored = read_state_version()?;
90        let expected = self.input_version();
91        if stored != expected {
92            return Err(MigrationError::StoredVersionMismatch { stored, expected });
93        }
94        let old_state =
95            env::state_read::<Self::Input>().ok_or(MigrationError::FailedToDeserializeOldState)?;
96        let new_state = self
97            .transform(old_state)
98            .map_err(MigrationError::Transformation)?;
99        env::state_write(&new_state);
100        write_state_version(self.output_version());
101        Ok(new_state)
102    }
103
104    fn transform(&self, input: Self::Input) -> Result<Self::Output, Self::Error>;
105}
106
107#[derive(thiserror::Error, Debug)]
108pub enum MigrationError<E> {
109    #[error("Failed to deserialize stored state version: {0}")]
110    StoredVersionDeserialization(#[from] std::io::Error),
111    #[error("Stored state version {stored} != args `from_version` {expected}")]
112    StoredVersionMismatch { stored: u32, expected: u32 },
113    #[error("Failed to deserialize old state")]
114    FailedToDeserializeOldState,
115    #[error("Failed to transform old state")]
116    Transformation(E),
117}
118
/// A contract's migration plan.
///
/// The `migrate` entry point generated by `impl_versioned_state!` parses its JSON input
/// into the implementing type and then calls [`Migrator::run`] on it.
pub trait Migrator {
    /// Executes the migration, consuming the parsed arguments.
    fn run(self);
}

123#[ext_contract]
124pub trait MigrateExternalInterface {
125    fn get_stored_state_version() -> u32;
126    fn get_target_state_version() -> u32;
127    fn needs_migration() -> bool;
128}
129
130#[macro_export]
131macro_rules! impl_versioned_state {
132    ($contract: ident, $current_state: ty, $migrations: ty) => {
133        #[::near_sdk::near]
134        impl $crate::versioned_state::MigrateExternalInterface for $contract {
135            fn get_stored_state_version() -> u32 {
136                $crate::versioned_state::read_state_version()
137                    .unwrap_or_else(|e| ::near_sdk::env::panic_str(&e.to_string()))
138            }
139
140            fn get_target_state_version() -> u32 {
141                <$current_state as $crate::versioned_state::StateVersion>::VERSION
142            }
143
144            fn needs_migration() -> bool {
145                <$current_state as $crate::versioned_state::StateVersion>::needs_migration()
146                    .unwrap_or_else(|e| ::near_sdk::env::panic_str(&e.to_string()))
147            }
148        }
149
150        #[cfg_attr(target_arch = "wasm32", unsafe(no_mangle))]
151        #[cfg_attr(not(target_arch = "wasm32"), allow(dead_code))]
152        pub fn migrate() {
153            use ::near_sdk::env;
154            env::setup_panic_hook();
155
156            ::near_sdk::require!(
157                env::predecessor_account_id() == env::current_account_id(),
158                "migrate function is private",
159            );
160
161            let input = env::input().unwrap_or_else(|| env::panic_str("no input"));
162
163            let args: $migrations = ::near_sdk::serde_json::from_slice(&input)
164                .unwrap_or_else(|e| env::panic_str(&e.to_string()));
165
166            $crate::versioned_state::Migrator::run(args);
167        }
168    };
169}
170pub use impl_versioned_state;
171
#[cfg(test)]
mod tests {
    use near_sdk::{test_utils::VMContextBuilder, testing_env};

    use super::*;

    /// Installs a fresh mocked blockchain context for the current test.
    fn context() {
        testing_env!(VMContextBuilder::new().build());
    }

    #[test]
    fn stored_version_defaults_to_zero() {
        context();
        assert_eq!(read_state_version().unwrap(), 0);
    }

    #[test]
    fn stored_version_round_trips() {
        context();
        write_state_version(5);
        assert_eq!(read_state_version().unwrap(), 5);
    }

    #[test]
    fn malformed_stored_version_errors() {
        context();
        write_state_version(7);
        // Overwrite the valid 4-byte version with a payload that cannot decode as a `u32`.
        env::storage_write(VERSION_KEY, &[1, 2, 3]);

        assert!(read_state_version().is_err());
    }

    #[test]
    fn future_stored_version_errors() {
        context();
        write_state_version(9);

        let error = TestState::needs_migration().unwrap_err();
        assert_eq!(error.kind(), ErrorKind::InvalidData);
        assert!(error
            .to_string()
            .contains("Stored state version 9 is newer"));
    }

    #[test]
    fn needs_migration_tracks_stored_version() {
        context();
        // Nothing stored yet, so the stored version reads as 0, which is older than VERSION.
        assert!(TestState::needs_migration().unwrap());

        write_state_version(TestState::VERSION);
        assert!(!TestState::needs_migration().unwrap());
    }

    #[test]
    fn versioned_state_new_records_version() {
        context();
        let state = TestState::new(());

        assert_eq!(state.version(), TestState::VERSION);
        assert_eq!(read_state_version().unwrap(), TestState::VERSION);
    }

    #[test]
    fn transformer_migrates_state_and_bumps_version() {
        context();
        env::state_write(&StateV1 { value: 21 });
        write_state_version(StateV1::VERSION);

        let migrated = V1ToV2.run().unwrap();

        assert_eq!(migrated.value, 42);
        assert_eq!(read_state_version().unwrap(), StateV2::VERSION);
        assert_eq!(env::state_read::<StateV2>().unwrap().value, 42);
    }

    #[test]
    fn transformer_rejects_mismatched_stored_version() {
        context();
        env::state_write(&StateV1 { value: 1 });
        // No version written, so the stored version reads as 0 instead of 1.
        assert!(matches!(
            V1ToV2.run(),
            Err(MigrationError::StoredVersionMismatch { stored: 0, expected: 1 })
        ));
        // A rejected run must leave the stored version untouched.
        assert_eq!(read_state_version().unwrap(), 0);
    }

    #[test]
    fn transformer_reports_missing_state() {
        context();
        write_state_version(StateV1::VERSION);

        assert!(matches!(
            V1ToV2.run(),
            Err(MigrationError::FailedToDeserializeOldState)
        ));
    }

    struct TestState;

    impl StateVersion for TestState {
        const VERSION: u32 = 2;
        type NewArgs = ();

        fn new((): Self::NewArgs) -> VersionedState<Self> {
            VersionedState::new(Self)
        }
    }

    /// Version-1 fixture state for transformer tests.
    #[near_sdk::near(serializers = [borsh])]
    struct StateV1 {
        value: u32,
    }

    impl StateVersion for StateV1 {
        const VERSION: u32 = 1;
        type NewArgs = u32;

        fn new(value: Self::NewArgs) -> VersionedState<Self> {
            VersionedState::new(Self { value })
        }
    }

    /// Version-2 fixture state: widens `value` and stores it doubled.
    #[near_sdk::near(serializers = [borsh])]
    struct StateV2 {
        value: u64,
    }

    impl StateVersion for StateV2 {
        const VERSION: u32 = 2;
        type NewArgs = u64;

        fn new(value: Self::NewArgs) -> VersionedState<Self> {
            VersionedState::new(Self { value })
        }
    }

    /// Fixture transformer from `StateV1` to `StateV2`; it never fails.
    struct V1ToV2;

    impl StateTransformer for V1ToV2 {
        type Input = StateV1;
        type Output = StateV2;
        type Error = std::convert::Infallible;

        fn transform(&self, input: Self::Input) -> Result<Self::Output, Self::Error> {
            Ok(StateV2 {
                value: u64::from(input.value) * 2,
            })
        }
    }
}