templar_common/versioned_state/
core.rs1use std::io::{Error, ErrorKind};
2
3use borsh::{BorshDeserialize, BorshSerialize};
4use near_sdk::env;
5
6const VERSION_KEY: &[u8] = b"__v";
7
8pub fn write_state_version(version: u32) {
9 env::storage_write(VERSION_KEY, &version.to_le_bytes());
10}
11
12pub fn read_state_version() -> Result<u32, std::io::Error> {
13 let Some(bytes) = env::storage_read(VERSION_KEY) else {
14 return Ok(0);
15 };
16
17 borsh::from_slice(&bytes)
18}
19
20#[derive(Debug)]
21#[near_sdk::near(serializers = [borsh])]
22pub struct VersionedState<T: StateVersion>(T);
23
24impl<T: StateVersion> VersionedState<T> {
25 pub fn new(state: T) -> Self {
26 write_state_version(T::VERSION);
27 Self(state)
28 }
29
30 pub fn version(&self) -> u32 {
31 T::VERSION
32 }
33}
34
35impl<T: StateVersion> std::ops::Deref for VersionedState<T> {
36 type Target = T;
37
38 fn deref(&self) -> &Self::Target {
39 &self.0
40 }
41}
42
43impl<T: StateVersion> std::ops::DerefMut for VersionedState<T> {
44 fn deref_mut(&mut self) -> &mut Self::Target {
45 &mut self.0
46 }
47}
48
49pub trait StateVersion {
50 const VERSION: u32;
51 type NewArgs;
52
53 fn new(args: Self::NewArgs) -> VersionedState<Self>
54 where
55 Self: Sized;
56
57 fn needs_migration() -> Result<bool, std::io::Error> {
58 let stored = read_state_version()?;
59 if stored > Self::VERSION {
60 return Err(Error::new(
61 ErrorKind::InvalidData,
62 format!(
63 "Stored state version {stored} is newer than supported version {}",
64 Self::VERSION
65 ),
66 ));
67 }
68
69 Ok(stored < Self::VERSION)
70 }
71}
72
73pub trait StateTransformer {
74 type Input: StateVersion + BorshDeserialize;
75 type Output: StateVersion + BorshSerialize;
76 type Error;
77
78 fn input_version(&self) -> u32 {
79 Self::Input::VERSION
80 }
81
82 fn output_version(&self) -> u32 {
83 Self::Output::VERSION
84 }
85
86 fn run(&self) -> Result<Self::Output, MigrationError<Self::Error>> {
87 let stored = read_state_version()?;
88 let expected = self.input_version();
89 if stored != expected {
90 return Err(MigrationError::StoredVersionMismatch { stored, expected });
91 }
92 let old_state =
93 env::state_read::<Self::Input>().ok_or(MigrationError::FailedToDeserializeOldState)?;
94 let new_state = self
95 .transform(old_state)
96 .map_err(MigrationError::Transformation)?;
97 env::state_write(&new_state);
98 write_state_version(self.output_version());
99 Ok(new_state)
100 }
101
102 fn transform(&self, input: Self::Input) -> Result<Self::Output, Self::Error>;
103}
104
/// Failure modes of [`StateTransformer::run`].
///
/// `E` is the transformer's own error type (`StateTransformer::Error`).
#[derive(thiserror::Error, Debug)]
pub enum MigrationError<E> {
    /// The stored version bytes could not be decoded by `read_state_version`.
    /// Also produced by `?` on a `std::io::Error` via the generated `From` impl.
    #[error("Failed to deserialize stored state version: {0}")]
    StoredVersionDeserialization(#[from] std::io::Error),
    /// The stored version does not match the version the transformer expects as input.
    // NOTE(review): the message says "args `from_version`", but `StateTransformer::run`
    // fills `expected` from `input_version()` — confirm the wording is still accurate.
    #[error("Stored state version {stored} != args `from_version` {expected}")]
    StoredVersionMismatch { stored: u32, expected: u32 },
    /// `env::state_read` returned `None` when loading the old state.
    // NOTE(review): despite the name, this appears to signal *missing* state; a state
    // that fails to deserialize likely panics inside near-sdk instead — confirm.
    #[error("Failed to deserialize old state")]
    FailedToDeserializeOldState,
    /// `StateTransformer::transform` returned an error. The inner error is kept in the
    /// variant but is not included in the `Display` message.
    #[error("Failed to transform old state")]
    Transformation(E),
}
116
/// A migration entry point that is consumed when it runs.
///
/// NOTE(review): no implementations are visible in this file; presumably a migrator
/// drives one or more [`StateTransformer`]s — confirm against implementors.
pub trait Migrator {
    /// Performs the migration, consuming `self`.
    ///
    /// Returns nothing, so any failure must be surfaced by the implementation itself
    /// (e.g. by panicking) — TODO confirm the intended contract.
    fn run(self);
}
120
#[cfg(test)]
mod tests {
    use near_sdk::{test_utils::VMContextBuilder, testing_env};

    use super::*;

    /// Installs a fresh mocked blockchain context. The tests below assume it starts
    /// with empty storage (`stored_version_defaults_to_zero` relies on this).
    fn context() {
        testing_env!(VMContextBuilder::new().build());
    }

    #[test]
    fn stored_version_defaults_to_zero() {
        context();
        assert_eq!(read_state_version().unwrap(), 0);
    }

    #[test]
    fn stored_version_round_trips() {
        context();
        write_state_version(7);
        assert_eq!(read_state_version().unwrap(), 7);

        // Overwriting replaces the previous value, including at the boundary.
        write_state_version(u32::MAX);
        assert_eq!(read_state_version().unwrap(), u32::MAX);
    }

    #[test]
    fn malformed_stored_version_errors() {
        context();
        write_state_version(7);
        env::storage_write(VERSION_KEY, &[1, 2, 3]);

        assert!(read_state_version().is_err());
    }

    #[test]
    fn future_stored_version_errors() {
        context();
        write_state_version(9);

        let error = TestState::needs_migration().unwrap_err();
        assert_eq!(error.kind(), ErrorKind::InvalidData);
        assert!(error
            .to_string()
            .contains("Stored state version 9 is newer"));
    }

    #[test]
    fn missing_stored_version_needs_migration() {
        context();
        assert!(TestState::needs_migration().unwrap());
    }

    #[test]
    fn older_stored_version_needs_migration() {
        context();
        write_state_version(TestState::VERSION - 1);
        assert!(TestState::needs_migration().unwrap());
    }

    #[test]
    fn current_stored_version_needs_no_migration() {
        context();
        write_state_version(TestState::VERSION);
        assert!(!TestState::needs_migration().unwrap());
    }

    #[test]
    fn versioned_state_new_records_version() {
        context();
        let state = TestState::new(());

        assert_eq!(state.version(), TestState::VERSION);
        assert_eq!(read_state_version().unwrap(), TestState::VERSION);
    }

    #[test]
    fn transformer_migrates_state_and_bumps_version() {
        context();
        env::state_write(&StateV1 { counter: 41 });
        write_state_version(StateV1::VERSION);

        let migrated = V1ToV2.run().unwrap();

        assert_eq!(migrated, StateV2 { counter: 42 });
        assert_eq!(env::state_read::<StateV2>(), Some(StateV2 { counter: 42 }));
        assert_eq!(read_state_version().unwrap(), StateV2::VERSION);
    }

    #[test]
    fn transformer_rejects_version_mismatch() {
        context();
        env::state_write(&StateV1 { counter: 1 });
        // No version written: it reads as 0, but the transformer expects StateV1::VERSION.

        match V1ToV2.run() {
            Err(MigrationError::StoredVersionMismatch { stored, expected }) => {
                assert_eq!(stored, 0);
                assert_eq!(expected, StateV1::VERSION);
            }
            other => panic!("expected a version mismatch, got {other:?}"),
        }
        assert_eq!(read_state_version().unwrap(), 0);
    }

    #[test]
    fn transformer_reports_missing_state() {
        context();
        write_state_version(StateV1::VERSION);

        assert!(matches!(
            V1ToV2.run(),
            Err(MigrationError::FailedToDeserializeOldState)
        ));
        assert_eq!(read_state_version().unwrap(), StateV1::VERSION);
    }

    #[test]
    fn failed_transform_leaves_storage_untouched() {
        context();
        env::state_write(&StateV1 { counter: 5 });
        write_state_version(StateV1::VERSION);

        match RejectingTransform.run() {
            Err(MigrationError::Transformation(message)) => assert_eq!(message, "rejected"),
            other => panic!("expected a transformation error, got {other:?}"),
        }
        assert_eq!(read_state_version().unwrap(), StateV1::VERSION);
        assert_eq!(env::state_read::<StateV1>(), Some(StateV1 { counter: 5 }));
    }

    struct TestState;

    impl StateVersion for TestState {
        const VERSION: u32 = 2;
        type NewArgs = ();

        fn new((): Self::NewArgs) -> VersionedState<Self> {
            VersionedState::new(Self)
        }
    }

    /// Input layout for the transformer tests.
    #[derive(Debug, PartialEq)]
    #[near_sdk::near(serializers = [borsh])]
    struct StateV1 {
        counter: u32,
    }

    impl StateVersion for StateV1 {
        const VERSION: u32 = 1;
        type NewArgs = u32;

        fn new(counter: Self::NewArgs) -> VersionedState<Self> {
            VersionedState::new(Self { counter })
        }
    }

    /// Output layout for the transformer tests (widened counter).
    #[derive(Debug, PartialEq)]
    #[near_sdk::near(serializers = [borsh])]
    struct StateV2 {
        counter: u64,
    }

    impl StateVersion for StateV2 {
        const VERSION: u32 = 2;
        type NewArgs = u64;

        fn new(counter: Self::NewArgs) -> VersionedState<Self> {
            VersionedState::new(Self { counter })
        }
    }

    /// Widens the counter and increments it, so the test can tell the output apart.
    struct V1ToV2;

    impl StateTransformer for V1ToV2 {
        type Input = StateV1;
        type Output = StateV2;
        type Error = String;

        fn transform(&self, input: Self::Input) -> Result<Self::Output, Self::Error> {
            Ok(StateV2 {
                counter: u64::from(input.counter) + 1,
            })
        }
    }

    /// Always fails, to check that `run` writes nothing on a transform error.
    struct RejectingTransform;

    impl StateTransformer for RejectingTransform {
        type Input = StateV1;
        type Output = StateV2;
        type Error = String;

        fn transform(&self, _input: Self::Input) -> Result<Self::Output, Self::Error> {
            Err("rejected".to_string())
        }
    }
}