templar_common/
versioned_state.rs1use std::io::{Error, ErrorKind};
2
3use borsh::{BorshDeserialize, BorshSerialize};
4use near_sdk::{env, ext_contract, near};
5
6const VERSION_KEY: &[u8] = b"__v";
7
8pub fn write_state_version(version: u32) {
10 env::storage_write(VERSION_KEY, &version.to_le_bytes());
11}
12
13pub fn read_state_version() -> Result<u32, std::io::Error> {
15 let Some(bytes) = env::storage_read(VERSION_KEY) else {
16 return Ok(0);
17 };
18
19 borsh::from_slice(&bytes)
20}
21
22#[derive(Debug)]
23#[near(serializers = [borsh])]
24pub struct VersionedState<T: StateVersion>(T);
25
26impl<T: StateVersion> VersionedState<T> {
27 pub fn new(state: T) -> Self {
28 write_state_version(T::VERSION);
29 Self(state)
30 }
31
32 pub fn version(&self) -> u32 {
33 T::VERSION
34 }
35}
36
37impl<T: StateVersion> std::ops::Deref for VersionedState<T> {
38 type Target = T;
39
40 fn deref(&self) -> &Self::Target {
41 &self.0
42 }
43}
44
45impl<T: StateVersion> std::ops::DerefMut for VersionedState<T> {
46 fn deref_mut(&mut self) -> &mut Self::Target {
47 &mut self.0
48 }
49}
50
51pub trait StateVersion {
52 const VERSION: u32;
53 type NewArgs;
54
55 fn new(args: Self::NewArgs) -> VersionedState<Self>
56 where
57 Self: Sized;
58
59 fn needs_migration() -> Result<bool, std::io::Error> {
60 let stored = read_state_version()?;
61 if stored > Self::VERSION {
62 return Err(Error::new(
63 ErrorKind::InvalidData,
64 format!(
65 "Stored state version {stored} is newer than supported version {}",
66 Self::VERSION
67 ),
68 ));
69 }
70
71 Ok(stored < Self::VERSION)
72 }
73}
74
75pub trait StateTransformer {
76 type Input: StateVersion + BorshDeserialize;
77 type Output: StateVersion + BorshSerialize;
78 type Error;
79
80 fn input_version(&self) -> u32 {
81 Self::Input::VERSION
82 }
83
84 fn output_version(&self) -> u32 {
85 Self::Output::VERSION
86 }
87
88 fn run(&self) -> Result<Self::Output, MigrationError<Self::Error>> {
89 let stored = read_state_version()?;
90 let expected = self.input_version();
91 if stored != expected {
92 return Err(MigrationError::StoredVersionMismatch { stored, expected });
93 }
94 let old_state =
95 env::state_read::<Self::Input>().ok_or(MigrationError::FailedToDeserializeOldState)?;
96 let new_state = self
97 .transform(old_state)
98 .map_err(MigrationError::Transformation)?;
99 env::state_write(&new_state);
100 write_state_version(self.output_version());
101 Ok(new_state)
102 }
103
104 fn transform(&self, input: Self::Input) -> Result<Self::Output, Self::Error>;
105}
106
107#[derive(thiserror::Error, Debug)]
108pub enum MigrationError<E> {
109 #[error("Failed to deserialize stored state version: {0}")]
110 StoredVersionDeserialization(#[from] std::io::Error),
111 #[error("Stored state version {stored} != args `from_version` {expected}")]
112 StoredVersionMismatch { stored: u32, expected: u32 },
113 #[error("Failed to deserialize old state")]
114 FailedToDeserializeOldState,
115 #[error("Failed to transform old state")]
116 Transformation(E),
117}
118
/// A migration plan that can be executed.
///
/// The `migrate` function generated by `impl_versioned_state!` deserializes its
/// JSON call input into the migrations type and calls `run` on it.
pub trait Migrator {
    /// Executes the migration, consuming the deserialized arguments.
    fn run(self);
}
122
/// Contract methods reporting the state-version status.
///
/// `impl_versioned_state!` implements this for a contract; `#[ext_contract]`
/// also generates a cross-contract call helper for these methods.
#[ext_contract]
pub trait MigrateExternalInterface {
    /// Version currently recorded in storage (0 if none has been written yet).
    fn get_stored_state_version() -> u32;
    /// Version of the state type that the deployed code targets.
    fn get_target_state_version() -> u32;
    /// Whether the stored state version is older than the target version.
    fn needs_migration() -> bool;
}
129
/// Implements `MigrateExternalInterface` for `$contract` and defines a raw,
/// exported `migrate` function.
///
/// Arguments:
/// - `$contract`: identifier of the contract type.
/// - `$current_state`: the `StateVersion` type the deployed code targets.
/// - `$migrations`: type deserialized from the call's JSON input (via
///   `near_sdk::serde_json`) and then executed with `Migrator::run`.
///
/// The generated `migrate` panics in three cases: the predecessor is not the
/// contract account itself, no input is attached, or the input does not parse
/// as `$migrations`.
#[macro_export]
macro_rules! impl_versioned_state {
    ($contract: ident, $current_state: ty, $migrations: ty) => {
        #[::near_sdk::near]
        impl $crate::versioned_state::MigrateExternalInterface for $contract {
            // Panics with the decoding error if the stored version is malformed.
            fn get_stored_state_version() -> u32 {
                $crate::versioned_state::read_state_version()
                    .unwrap_or_else(|e| ::near_sdk::env::panic_str(&e.to_string()))
            }

            fn get_target_state_version() -> u32 {
                <$current_state as $crate::versioned_state::StateVersion>::VERSION
            }

            // Panics if the stored version is malformed or newer than `$current_state`.
            fn needs_migration() -> bool {
                <$current_state as $crate::versioned_state::StateVersion>::needs_migration()
                    .unwrap_or_else(|e| ::near_sdk::env::panic_str(&e.to_string()))
            }
        }

        // On wasm32 the symbol is exported unmangled, so it keeps the name `migrate`;
        // on other targets it is an ordinary function that may go unused.
        #[cfg_attr(target_arch = "wasm32", unsafe(no_mangle))]
        #[cfg_attr(not(target_arch = "wasm32"), allow(dead_code))]
        pub fn migrate() {
            use ::near_sdk::env;
            env::setup_panic_hook();

            // Only the contract account itself may trigger a migration.
            ::near_sdk::require!(
                env::predecessor_account_id() == env::current_account_id(),
                "migrate function is private",
            );

            let input = env::input().unwrap_or_else(|| env::panic_str("no input"));

            let args: $migrations = ::near_sdk::serde_json::from_slice(&input)
                .unwrap_or_else(|e| env::panic_str(&e.to_string()));

            $crate::versioned_state::Migrator::run(args);
        }
    };
}
// Re-export so the macro is also reachable through this module's path.
pub use impl_versioned_state;
171
#[cfg(test)]
mod tests {
    use near_sdk::{test_utils::VMContextBuilder, testing_env};

    use super::*;

    /// Installs a fresh mocked NEAR environment for the current test.
    fn context() {
        testing_env!(VMContextBuilder::new().build());
    }

    #[test]
    fn stored_version_defaults_to_zero() {
        context();
        assert_eq!(read_state_version().unwrap(), 0);
    }

    #[test]
    fn stored_version_round_trips() {
        context();
        write_state_version(42);
        assert_eq!(read_state_version().unwrap(), 42);
    }

    #[test]
    fn malformed_stored_version_errors() {
        context();
        write_state_version(7);
        // Overwrite a valid entry with a value of the wrong length.
        env::storage_write(VERSION_KEY, &[1, 2, 3]);

        assert!(read_state_version().is_err());
    }

    #[test]
    fn future_stored_version_errors() {
        context();
        write_state_version(9);

        let error = TestState::needs_migration().unwrap_err();
        assert_eq!(error.kind(), ErrorKind::InvalidData);
        assert!(error
            .to_string()
            .contains("Stored state version 9 is newer"));
    }

    #[test]
    fn needs_migration_tracks_stored_version() {
        context();
        // Nothing stored yet: read as version 0, which is older than 2.
        assert!(TestState::needs_migration().unwrap());

        write_state_version(TestState::VERSION - 1);
        assert!(TestState::needs_migration().unwrap());

        write_state_version(TestState::VERSION);
        assert!(!TestState::needs_migration().unwrap());
    }

    #[test]
    fn versioned_state_new_records_version() {
        context();
        let state = TestState::new(());
        assert_eq!(state.version(), TestState::VERSION);
        assert_eq!(read_state_version().unwrap(), TestState::VERSION);
    }

    #[test]
    fn transformer_migrates_state_and_version() {
        context();
        env::state_write(&StateV1 { value: 5 });
        write_state_version(StateV1::VERSION);

        let migrated = Widen.run().unwrap();

        assert_eq!(migrated, StateV2 { value: 5 });
        assert_eq!(read_state_version().unwrap(), StateV2::VERSION);
        assert_eq!(env::state_read::<StateV2>(), Some(StateV2 { value: 5 }));
    }

    #[test]
    fn transformer_rejects_mismatched_stored_version() {
        context();
        write_state_version(StateV2::VERSION);

        let error = Widen.run().unwrap_err();
        assert!(matches!(
            error,
            MigrationError::StoredVersionMismatch {
                stored: 2,
                expected: 1
            }
        ));
    }

    #[test]
    fn transformer_reports_missing_old_state() {
        context();
        write_state_version(StateV1::VERSION);

        assert!(matches!(
            Widen.run(),
            Err(MigrationError::FailedToDeserializeOldState)
        ));
        // A failed run must leave the recorded version untouched.
        assert_eq!(read_state_version().unwrap(), StateV1::VERSION);
    }

    struct TestState;

    impl StateVersion for TestState {
        const VERSION: u32 = 2;
        type NewArgs = ();

        fn new((): Self::NewArgs) -> VersionedState<Self> {
            VersionedState::new(Self)
        }
    }

    /// Minimal version-1 state used to exercise `StateTransformer`.
    #[derive(Debug, PartialEq)]
    #[near(serializers = [borsh])]
    struct StateV1 {
        value: u32,
    }

    impl StateVersion for StateV1 {
        const VERSION: u32 = 1;
        type NewArgs = u32;

        fn new(value: Self::NewArgs) -> VersionedState<Self> {
            VersionedState::new(Self { value })
        }
    }

    /// Minimal version-2 state: same data as `StateV1`, widened to `u64`.
    #[derive(Debug, PartialEq)]
    #[near(serializers = [borsh])]
    struct StateV2 {
        value: u64,
    }

    impl StateVersion for StateV2 {
        const VERSION: u32 = 2;
        type NewArgs = u64;

        fn new(value: Self::NewArgs) -> VersionedState<Self> {
            VersionedState::new(Self { value })
        }
    }

    /// Infallible `StateV1 -> StateV2` migration step.
    struct Widen;

    impl StateTransformer for Widen {
        type Input = StateV1;
        type Output = StateV2;
        type Error = std::convert::Infallible;

        fn transform(&self, input: Self::Input) -> Result<Self::Output, Self::Error> {
            Ok(StateV2 {
                value: u64::from(input.value),
            })
        }
    }
}