templar_vault_kernel/math/
number.rs1use core::ops::{Add, Div, Sub};
6
7use derive_more::{From, Into};
8use primitive_types::{U256, U512};
9
10pub type WIDE = U512;
12
13#[cfg_attr(not(target_arch = "wasm32"), derive(Debug))]
18#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, From, Into)]
19pub struct Number(pub U256);
20
#[cfg(all(feature = "serde", not(feature = "postcard")))]
mod serde_impl {
    //! Serde support: `Number` is encoded as a base-10 string in every format.
    use super::*;
    use alloc::string::ToString;
    use core::fmt;
    use serde::{de, Deserialize, Deserializer, Serialize, Serializer};

    /// Parses a base-10 string into a `Number`.
    ///
    /// The empty string is rejected explicitly so it can never decode to zero,
    /// whatever `U256::from_dec_str` does with an input that has no digits.
    /// `Number::serialize` never emits an empty string (zero is written as `"0"`),
    /// so this only tightens validation of externally supplied input.
    fn parse_decimal<E: de::Error>(v: &str) -> Result<Number, E> {
        if v.is_empty() {
            return Err(E::custom("invalid decimal string for U256"));
        }
        U256::from_dec_str(v)
            .map(Number)
            .map_err(|_| E::custom("invalid decimal string for U256"))
    }

    impl Serialize for Number {
        /// Writes the value as its decimal string representation.
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            let s = self.0.to_string();
            serializer.serialize_str(&s)
        }
    }

    impl<'de> Deserialize<'de> for Number {
        /// Reads a decimal string; fails on non-digits, on values that overflow
        /// `U256`, and on the empty string.
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            struct NumberVisitor;

            impl<'de> de::Visitor<'de> for NumberVisitor {
                type Value = Number;

                fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                    formatter.write_str("a decimal string representing a U256")
                }

                fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
                where
                    E: de::Error,
                {
                    parse_decimal(v)
                }
            }

            deserializer.deserialize_str(NumberVisitor)
        }
    }
}
67
#[cfg(feature = "postcard")]
mod postcard_serde_impl {
    //! Serde support tuned for postcard.
    //!
    //! Human-readable formats use a base-10 string. Binary formats use the raw
    //! 32-byte little-endian encoding of the underlying `U256`.
    use super::*;
    use alloc::string::ToString;
    use core::fmt;
    use serde::{de, Deserialize, Deserializer, Serialize, Serializer};

    /// Byte width of the binary (non-human-readable) encoding.
    const BYTE_LEN: usize = 32;

    /// Parses a base-10 string into a `Number`.
    ///
    /// The empty string is rejected explicitly so it can never decode to zero,
    /// whatever `U256::from_dec_str` does with an input that has no digits.
    fn parse_decimal<E: de::Error>(v: &str) -> Result<Number, E> {
        if v.is_empty() {
            return Err(E::custom("invalid decimal string for U256"));
        }
        U256::from_dec_str(v)
            .map(Number)
            .map_err(|_| E::custom("invalid decimal string for U256"))
    }

    impl Serialize for Number {
        /// Decimal string for human-readable formats, otherwise 32 little-endian bytes.
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            if serializer.is_human_readable() {
                return serializer.serialize_str(&self.0.to_string());
            }

            let mut bytes = [0u8; BYTE_LEN];
            self.0.write_as_little_endian(&mut bytes);
            serializer.serialize_bytes(&bytes)
        }
    }

    impl<'de> Deserialize<'de> for Number {
        /// Mirrors `serialize`.
        ///
        /// A human-readable input must be a non-empty decimal string. A binary
        /// input must be exactly 32 little-endian bytes, delivered either as a
        /// byte buffer or as a sequence of exactly 32 `u8` elements.
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            struct NumberVisitor;

            impl<'de> de::Visitor<'de> for NumberVisitor {
                type Value = Number;

                fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                    formatter.write_str(
                        "a decimal string representing a U256 or 32 bytes little-endian U256",
                    )
                }

                fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
                where
                    E: de::Error,
                {
                    parse_decimal(v)
                }

                fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
                where
                    E: de::Error,
                {
                    if v.len() != BYTE_LEN {
                        return Err(E::custom("expected exactly 32 bytes for U256"));
                    }
                    Ok(Number(U256::from_little_endian(v)))
                }

                fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
                where
                    A: de::SeqAccess<'de>,
                {
                    let mut bytes = [0u8; BYTE_LEN];
                    for (i, byte) in bytes.iter_mut().enumerate() {
                        *byte = seq
                            .next_element()?
                            .ok_or_else(|| de::Error::invalid_length(i, &self))?;
                    }
                    // Keep this consistent with `visit_bytes`: over-long input is an
                    // error rather than being silently truncated to its first 32 bytes.
                    if seq.next_element::<u8>()?.is_some() {
                        return Err(de::Error::invalid_length(BYTE_LEN + 1, &self));
                    }
                    Ok(Number(U256::from_little_endian(&bytes)))
                }
            }

            if deserializer.is_human_readable() {
                deserializer.deserialize_str(NumberVisitor)
            } else {
                deserializer.deserialize_bytes(NumberVisitor)
            }
        }
    }
}
147
#[cfg(feature = "borsh")]
mod borsh_impl {
    //! Borsh encoding: `Number` is written as exactly 32 little-endian bytes,
    //! with no length prefix.
    use super::*;
    use borsh::{self, BorshDeserialize, BorshSerialize};

    /// Fixed width of the encoded value in bytes.
    const ENCODED_LEN: usize = 32;

    impl BorshSerialize for Number {
        fn serialize<W: borsh::io::Write>(&self, writer: &mut W) -> borsh::io::Result<()> {
            let mut le = [0u8; ENCODED_LEN];
            self.0.write_as_little_endian(&mut le);
            writer.write_all(&le)
        }
    }

    impl BorshDeserialize for Number {
        /// Reads exactly 32 bytes; a short reader yields the I/O error from `read_exact`.
        fn deserialize_reader<R: borsh::io::Read>(reader: &mut R) -> borsh::io::Result<Self> {
            let mut le = [0u8; ENCODED_LEN];
            reader.read_exact(&mut le)?;
            let value = U256::from_little_endian(&le);
            Ok(Self(value))
        }
    }
}
169
#[cfg(feature = "borsh-schema")]
mod borsh_schema_impl {
    //! Borsh schema: `Number` is declared as a 32-byte primitive named `"Number"`,
    //! matching its fixed-width Borsh encoding.
    use super::*;
    use alloc::collections::BTreeMap;
    use borsh::schema::{add_definition, Declaration, Definition};
    use borsh::BorshSchema;

    impl BorshSchema for Number {
        fn add_definitions_recursively(definitions: &mut BTreeMap<Declaration, Definition>) {
            add_definition(Self::declaration(), Definition::Primitive(32), definitions);
        }

        fn declaration() -> Declaration {
            Declaration::from("Number")
        }
    }
}
188
#[cfg(feature = "schemars")]
mod schemars_impl {
    //! JSON Schema: `Number` is described as a base-10 string without leading zeros.
    use super::*;
    use alloc::string::ToString;
    use schemars::r#gen::SchemaGenerator;
    use schemars::schema::Schema;
    use schemars::JsonSchema;

    impl JsonSchema for Number {
        fn schema_name() -> alloc::string::String {
            "Number".to_string()
        }

        /// A string schema with a description and a digit pattern of at most 78 digits.
        fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
            use schemars::schema::{InstanceType, Metadata, SchemaObject, StringValidation};

            let metadata = Metadata {
                description: Some("256-bit Unsigned Integer".to_string()),
                ..Default::default()
            };
            let string = StringValidation {
                pattern: Some("^(0|[1-9][0-9]{0,77})$".to_string()),
                ..Default::default()
            };
            Schema::Object(SchemaObject {
                metadata: Some(alloc::boxed::Box::new(metadata)),
                instance_type: Some(InstanceType::String.into()),
                string: Some(alloc::boxed::Box::new(string)),
                ..Default::default()
            })
        }
    }
}
211
212impl Number {
213 pub const ZERO: Self = Number(U256::zero());
214 pub const ONE: Self = Number(U256::one());
215
216 #[inline]
217 #[must_use]
218 pub const fn zero() -> Self {
219 Self::ZERO
220 }
221
222 #[inline]
223 #[must_use]
224 pub const fn one() -> Self {
225 Self::ONE
226 }
227
228 #[inline]
229 #[must_use]
230 pub fn is_zero(&self) -> bool {
231 self.0.is_zero()
232 }
233
234 #[inline]
235 #[must_use]
236 pub fn is_one(&self) -> bool {
237 self.0 == U256::one()
238 }
239
240 #[inline]
241 #[must_use]
242 pub fn as_u128_trunc(self) -> u128 {
243 let mut b32 = [0u8; 32];
244 self.0.write_as_little_endian(&mut b32);
245 let mut b16 = [0u8; 16];
246 b16.copy_from_slice(&b32[..16]);
247 u128::from_le_bytes(b16)
248 }
249
250 #[inline]
251 #[must_use]
252 pub fn as_u128_saturating(self) -> u128 {
253 if self.0 .0[2] != 0 || self.0 .0[3] != 0 {
254 u128::MAX
255 } else {
256 self.0.as_u128()
257 }
258 }
259
260 #[inline]
261 pub(crate) fn as_u256_trunc(q: U512) -> U256 {
262 let U512(ref limbs) = q;
263 U256([limbs[0], limbs[1], limbs[2], limbs[3]])
264 }
265
266 #[inline]
267 pub(crate) fn as_u128_if_fits(value: U256) -> Option<u128> {
268 let U256(ref limbs) = value;
269 if limbs[2] != 0 || limbs[3] != 0 {
270 return None;
271 }
272 Some((u128::from(limbs[1]) << 64) | u128::from(limbs[0]))
273 }
274
275 #[inline]
276 #[must_use]
277 pub fn saturating_add(self, other: Number) -> Number {
278 Number(self.0.saturating_add(other.0))
279 }
280
281 #[inline]
282 #[must_use]
283 pub fn saturating_sub(self, other: Number) -> Number {
284 Number(self.0.saturating_sub(other.0))
285 }
286
287 #[inline(never)]
288 fn mul_div_with_rounding(x: Number, y: Number, denom: Number, round_up: bool) -> Number {
289 if x.is_zero() || y.is_zero() {
291 return Number::zero();
292 }
293 if denom.is_zero() {
294 return Number::zero();
295 }
296 if denom.is_one() {
298 return Number(x.0.saturating_mul(y.0));
299 }
300 if x.0 == denom.0 {
302 return y;
303 }
304 if y.0 == denom.0 {
305 return x;
306 }
307 if let (Some(x128), Some(y128), Some(denom128)) = (
308 Self::as_u128_if_fits(x.0),
309 Self::as_u128_if_fits(y.0),
310 Self::as_u128_if_fits(denom.0),
311 ) {
312 if let Some(prod) = x128.checked_mul(y128) {
313 let q = prod / denom128;
314 if !round_up {
315 return Number::from(q);
316 }
317 let r = prod % denom128;
318 return if r == 0 {
319 Number::from(q)
320 } else {
321 Number::from(q.saturating_add(1))
322 };
323 }
324 }
325 let prod = x.0.full_mul(y.0);
327 let d = U512::from(denom.0);
328 let q = prod / d;
329 let base = Number(Self::as_u256_trunc(q));
330 if !round_up {
331 return base;
332 }
333 let r = prod % d;
334 if r.is_zero() {
335 base
336 } else {
337 base.saturating_add(Number::one())
338 }
339 }
340
341 #[inline]
342 #[must_use]
343 pub fn mul_div_floor(x: Number, y: Number, denom: Number) -> Number {
344 Self::mul_div_with_rounding(x, y, denom, false)
345 }
346
347 #[inline]
348 #[must_use]
349 pub fn mul_div_ceil(x: Number, y: Number, denom: Number) -> Number {
350 Self::mul_div_with_rounding(x, y, denom, true)
351 }
352}
353
354impl From<u128> for Number {
355 #[inline]
356 fn from(v: u128) -> Self {
357 Number(U256::from(v))
358 }
359}
360impl From<Number> for u128 {
361 #[inline]
362 fn from(n: Number) -> u128 {
363 n.as_u128_trunc()
364 }
365}
366impl Div<u128> for Number {
367 type Output = Number;
368 #[inline]
369 fn div(self, rhs: u128) -> Number {
370 Number(self.0 / U256::from(rhs))
371 }
372}
373impl Div<U256> for Number {
374 type Output = Number;
375 #[inline]
376 fn div(self, rhs: U256) -> Number {
377 Number(self.0 / rhs)
378 }
379}
380impl Div<Number> for Number {
381 type Output = Number;
382 #[inline]
383 fn div(self, rhs: Number) -> Number {
384 Number(self.0 / rhs.0)
385 }
386}
387impl Add<Number> for Number {
388 type Output = Number;
389 #[inline]
390 fn add(self, rhs: Number) -> Number {
391 Number(self.0 + rhs.0)
392 }
393}
394impl Sub<Number> for Number {
395 type Output = Number;
396 #[inline]
397 fn sub(self, rhs: Number) -> Number {
398 Number(self.0 - rhs.0)
399 }
400}
401
402#[cfg(test)]
403mod tests;