templar_vault_kernel/math/
number.rs

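//! Chain-agnostic `Number` type for precise 256-bit unsigned arithmetic.
//!
//! Provides a `U256`-backed wrapper for overflow-safe calculations: its multiply-divide
//! helpers form the full 512-bit product before dividing, so `x * y / d` cannot overflow
//! in the intermediate step.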
4
5use core::ops::{Add, Div, Sub};
6
7use derive_more::{From, Into};
8use primitive_types::{U256, U512};
9
10/// Wider type for intermediate calculations.
11pub type WIDE = U512;
12
13/// A 256-bit unsigned integer wrapper for precise arithmetic.
14///
15/// When the `serde` feature is enabled, serializes to/from a decimal string
16/// for compatibility with JSON-based APIs.
17#[cfg_attr(not(target_arch = "wasm32"), derive(Debug))]
18#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, From, Into)]
19pub struct Number(pub U256);
20
#[cfg(all(feature = "serde", not(feature = "postcard")))]
mod serde_impl {
    // Plain serde representation of `Number`: always a base-10 string, regardless of
    // whether the serializer is human-readable. Only compiled when `postcard` is off.
    use super::*;
    use alloc::string::ToString;
    use core::fmt;
    use serde::{de, Deserialize, Deserializer, Serialize, Serializer};

    impl Serialize for Number {
        /// Serializes as a decimal string for JSON compatibility.
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            serializer.serialize_str(&self.0.to_string())
        }
    }

    impl<'de> Deserialize<'de> for Number {
        /// Parses a decimal string.
        ///
        /// The empty string is rejected, as is any string that
        /// `U256::from_dec_str` refuses.
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            struct NumberVisitor;

            impl<'de> de::Visitor<'de> for NumberVisitor {
                type Value = Number;

                fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                    formatter.write_str("a decimal string representing a U256")
                }

                fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
                where
                    E: de::Error,
                {
                    // Reject "" explicitly: depending on the uint version, from_dec_str
                    // may parse it as 0, which would let a missing value silently become
                    // zero. The schemars pattern for Number also requires >= 1 digit.
                    if v.is_empty() {
                        return Err(E::custom("invalid decimal string for U256"));
                    }
                    U256::from_dec_str(v)
                        .map(Number)
                        .map_err(|_| E::custom("invalid decimal string for U256"))
                }
            }

            deserializer.deserialize_str(NumberVisitor)
        }
    }
}
67
#[cfg(feature = "postcard")]
mod postcard_serde_impl {
    // Format-aware serde representation of `Number`:
    // - human-readable serializers: decimal string (same as the plain `serde` impl);
    // - binary serializers: exactly BYTE_LEN little-endian bytes.
    use super::*;
    use alloc::string::ToString;
    use core::fmt;
    use serde::{de, Deserialize, Deserializer, Serialize, Serializer};

    /// Size of the binary encoding: four 64-bit limbs.
    const BYTE_LEN: usize = 32;

    impl Serialize for Number {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            if serializer.is_human_readable() {
                return serializer.serialize_str(&self.0.to_string());
            }

            let mut bytes = [0u8; BYTE_LEN];
            self.0.write_as_little_endian(&mut bytes);
            serializer.serialize_bytes(&bytes)
        }
    }

    impl<'de> Deserialize<'de> for Number {
        /// Accepts a decimal string (human-readable formats) or exactly 32
        /// little-endian bytes, given either as a byte buffer or as a sequence of `u8`
        /// (binary formats). Empty strings and wrong-length byte input are rejected.
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            struct NumberVisitor;

            impl<'de> de::Visitor<'de> for NumberVisitor {
                type Value = Number;

                fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                    formatter.write_str(
                        "a decimal string representing a U256 or 32 bytes little-endian U256",
                    )
                }

                fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
                where
                    E: de::Error,
                {
                    // Reject "" explicitly so it can never silently decode as zero.
                    if v.is_empty() {
                        return Err(E::custom("invalid decimal string for U256"));
                    }
                    U256::from_dec_str(v)
                        .map(Number)
                        .map_err(|_| E::custom("invalid decimal string for U256"))
                }

                fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
                where
                    E: de::Error,
                {
                    if v.len() != BYTE_LEN {
                        return Err(E::custom("expected exactly 32 bytes for U256"));
                    }
                    Ok(Number(U256::from_little_endian(v)))
                }

                // For formats that deliver the byte string as an array of u8.
                fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
                where
                    A: de::SeqAccess<'de>,
                {
                    let mut bytes = [0u8; BYTE_LEN];
                    for (i, byte) in bytes.iter_mut().enumerate() {
                        *byte = seq
                            .next_element()?
                            .ok_or_else(|| de::Error::invalid_length(i, &self))?;
                    }
                    // Extra elements mean the input is not a 32-byte U256; reject it
                    // rather than silently ignoring the tail.
                    if seq.next_element::<de::IgnoredAny>()?.is_some() {
                        return Err(de::Error::invalid_length(BYTE_LEN + 1, &self));
                    }
                    Ok(Number(U256::from_little_endian(&bytes)))
                }
            }

            if deserializer.is_human_readable() {
                deserializer.deserialize_str(NumberVisitor)
            } else {
                deserializer.deserialize_bytes(NumberVisitor)
            }
        }
    }
}
147
#[cfg(feature = "borsh")]
mod borsh_impl {
    // Borsh encoding of `Number`: a fixed 32-byte little-endian array with no
    // length prefix.
    use super::*;
    use borsh::{self, BorshDeserialize, BorshSerialize};

    impl BorshSerialize for Number {
        fn serialize<W: borsh::io::Write>(&self, writer: &mut W) -> borsh::io::Result<()> {
            let mut le = [0u8; 32];
            self.0.write_as_little_endian(&mut le);
            writer.write_all(&le[..])
        }
    }

    impl BorshDeserialize for Number {
        fn deserialize_reader<R: borsh::io::Read>(reader: &mut R) -> borsh::io::Result<Self> {
            let mut le = [0u8; 32];
            reader
                .read_exact(&mut le)
                .map(|()| Self(U256::from_little_endian(&le)))
        }
    }
}
169
#[cfg(feature = "borsh-schema")]
mod borsh_schema_impl {
    // Borsh schema for `Number`: an opaque 32-byte primitive declared as "Number".
    use super::*;
    use alloc::collections::BTreeMap;
    use borsh::schema::{add_definition, Declaration, Definition};
    use borsh::BorshSchema;

    impl BorshSchema for Number {
        fn add_definitions_recursively(definitions: &mut BTreeMap<Declaration, Definition>) {
            add_definition(Self::declaration(), Definition::Primitive(32), definitions);
        }

        fn declaration() -> Declaration {
            Declaration::from("Number")
        }
    }
}
188
#[cfg(feature = "schemars")]
mod schemars_impl {
    // JSON Schema for `Number`: a string holding an unsigned decimal integer with no
    // leading zeros and 1 to 78 digits.
    use super::*;
    use alloc::string::ToString;
    use schemars::r#gen::SchemaGenerator;
    use schemars::schema::Schema;
    use schemars::JsonSchema;

    impl JsonSchema for Number {
        fn schema_name() -> alloc::string::String {
            "Number".to_string()
        }

        fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
            let mut object = schemars::schema::SchemaObject {
                instance_type: Some(schemars::schema::InstanceType::String.into()),
                ..Default::default()
            };
            object.string().pattern = Some("^(0|[1-9][0-9]{0,77})$".to_string());
            object.metadata().description = Some("256-bit Unsigned Integer".to_string());
            Schema::Object(object)
        }
    }
}
211
212impl Number {
213    pub const ZERO: Self = Number(U256::zero());
214    pub const ONE: Self = Number(U256::one());
215
216    #[inline]
217    #[must_use]
218    pub const fn zero() -> Self {
219        Self::ZERO
220    }
221
222    #[inline]
223    #[must_use]
224    pub const fn one() -> Self {
225        Self::ONE
226    }
227
228    #[inline]
229    #[must_use]
230    pub fn is_zero(&self) -> bool {
231        self.0.is_zero()
232    }
233
234    #[inline]
235    #[must_use]
236    pub fn is_one(&self) -> bool {
237        self.0 == U256::one()
238    }
239
240    #[inline]
241    #[must_use]
242    pub fn as_u128_trunc(self) -> u128 {
243        let mut b32 = [0u8; 32];
244        self.0.write_as_little_endian(&mut b32);
245        let mut b16 = [0u8; 16];
246        b16.copy_from_slice(&b32[..16]);
247        u128::from_le_bytes(b16)
248    }
249
250    #[inline]
251    #[must_use]
252    pub fn as_u128_saturating(self) -> u128 {
253        if self.0 .0[2] != 0 || self.0 .0[3] != 0 {
254            u128::MAX
255        } else {
256            self.0.as_u128()
257        }
258    }
259
260    #[inline]
261    pub(crate) fn as_u256_trunc(q: U512) -> U256 {
262        let U512(ref limbs) = q;
263        U256([limbs[0], limbs[1], limbs[2], limbs[3]])
264    }
265
266    #[inline]
267    pub(crate) fn as_u128_if_fits(value: U256) -> Option<u128> {
268        let U256(ref limbs) = value;
269        if limbs[2] != 0 || limbs[3] != 0 {
270            return None;
271        }
272        Some((u128::from(limbs[1]) << 64) | u128::from(limbs[0]))
273    }
274
275    #[inline]
276    #[must_use]
277    pub fn saturating_add(self, other: Number) -> Number {
278        Number(self.0.saturating_add(other.0))
279    }
280
281    #[inline]
282    #[must_use]
283    pub fn saturating_sub(self, other: Number) -> Number {
284        Number(self.0.saturating_sub(other.0))
285    }
286
287    #[inline(never)]
288    fn mul_div_with_rounding(x: Number, y: Number, denom: Number, round_up: bool) -> Number {
289        // Fast path: zero inputs
290        if x.is_zero() || y.is_zero() {
291            return Number::zero();
292        }
293        if denom.is_zero() {
294            return Number::zero();
295        }
296        // Fast path: denom == 1 (identity division)
297        if denom.is_one() {
298            return Number(x.0.saturating_mul(y.0));
299        }
300        // Fast path: cancellation when one factor equals denom
301        if x.0 == denom.0 {
302            return y;
303        }
304        if y.0 == denom.0 {
305            return x;
306        }
307        if let (Some(x128), Some(y128), Some(denom128)) = (
308            Self::as_u128_if_fits(x.0),
309            Self::as_u128_if_fits(y.0),
310            Self::as_u128_if_fits(denom.0),
311        ) {
312            if let Some(prod) = x128.checked_mul(y128) {
313                let q = prod / denom128;
314                if !round_up {
315                    return Number::from(q);
316                }
317                let r = prod % denom128;
318                return if r == 0 {
319                    Number::from(q)
320                } else {
321                    Number::from(q.saturating_add(1))
322                };
323            }
324        }
325        // General path: use U512 for overflow-safe multiplication
326        let prod = x.0.full_mul(y.0);
327        let d = U512::from(denom.0);
328        let q = prod / d;
329        let base = Number(Self::as_u256_trunc(q));
330        if !round_up {
331            return base;
332        }
333        let r = prod % d;
334        if r.is_zero() {
335            base
336        } else {
337            base.saturating_add(Number::one())
338        }
339    }
340
341    #[inline]
342    #[must_use]
343    pub fn mul_div_floor(x: Number, y: Number, denom: Number) -> Number {
344        Self::mul_div_with_rounding(x, y, denom, false)
345    }
346
347    #[inline]
348    #[must_use]
349    pub fn mul_div_ceil(x: Number, y: Number, denom: Number) -> Number {
350        Self::mul_div_with_rounding(x, y, denom, true)
351    }
352}
353
354impl From<u128> for Number {
355    #[inline]
356    fn from(v: u128) -> Self {
357        Number(U256::from(v))
358    }
359}
360impl From<Number> for u128 {
361    #[inline]
362    fn from(n: Number) -> u128 {
363        n.as_u128_trunc()
364    }
365}
366impl Div<u128> for Number {
367    type Output = Number;
368    #[inline]
369    fn div(self, rhs: u128) -> Number {
370        Number(self.0 / U256::from(rhs))
371    }
372}
373impl Div<U256> for Number {
374    type Output = Number;
375    #[inline]
376    fn div(self, rhs: U256) -> Number {
377        Number(self.0 / rhs)
378    }
379}
380impl Div<Number> for Number {
381    type Output = Number;
382    #[inline]
383    fn div(self, rhs: Number) -> Number {
384        Number(self.0 / rhs.0)
385    }
386}
387impl Add<Number> for Number {
388    type Output = Number;
389    #[inline]
390    fn add(self, rhs: Number) -> Number {
391        Number(self.0 + rhs.0)
392    }
393}
394impl Sub<Number> for Number {
395    type Output = Number;
396    #[inline]
397    fn sub(self, rhs: Number) -> Number {
398        Number(self.0 - rhs.0)
399    }
400}
401
402#[cfg(test)]
403mod tests;