templar_vault_kernel/math/
number.rs

1//! Chain-agnostic Number type for precise 256-bit arithmetic.
2//!
3//! Provides a U256-backed wrapper for overflow-safe calculations.
4
5use core::ops::{Add, Div, Sub};
6
7use derive_more::{From, Into};
8use primitive_types::{U256, U512};
9
10/// Wider type for intermediate calculations.
11pub type WIDE = U512;
12
13/// A 256-bit unsigned integer wrapper for precise arithmetic.
14///
15/// When the `serde` feature is enabled, serializes to/from a decimal string
16/// for compatibility with JSON-based APIs.
17#[cfg_attr(not(target_arch = "wasm32"), derive(Debug))]
18#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, From, Into)]
19pub struct Number(pub U256);
20
#[cfg(all(feature = "serde", not(feature = "postcard")))]
mod serde_impl {
    //! Text serde support: `Number` round-trips as a base-10 string for
    //! compatibility with JSON-based APIs.
    use super::*;
    use core::fmt;
    use serde::{de, Deserialize, Deserializer, Serialize, Serializer};

    impl Serialize for Number {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            // Emit the decimal `Display` form of the U256. `collect_str` lets
            // serializers that support it write the digits directly instead of
            // first allocating a `String`; the produced text is the same.
            serializer.collect_str(&self.0)
        }
    }

    impl<'de> Deserialize<'de> for Number {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            /// Visitor that accepts only string input holding a decimal U256.
            struct NumberVisitor;

            impl de::Visitor<'_> for NumberVisitor {
                type Value = Number;

                fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                    formatter.write_str("a decimal string representing a U256")
                }

                /// Parses `v` as base-10.
                ///
                /// Rejects the empty string. Every other validation (non-digit
                /// characters, values that do not fit in 256 bits) is delegated
                /// to `U256::from_dec_str`.
                fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
                where
                    E: de::Error,
                {
                    // `U256::from_dec_str("")` yields Ok(0); reject it explicitly
                    // so a blank field is never silently read as zero.
                    if v.is_empty() {
                        return Err(E::custom("invalid decimal string for U256"));
                    }
                    U256::from_dec_str(v)
                        .map(Number)
                        .map_err(|_| E::custom("invalid decimal string for U256"))
                }
            }

            deserializer.deserialize_str(NumberVisitor)
        }
    }
}
67
#[cfg(feature = "postcard")]
mod postcard_serde_impl {
    //! Compact binary serde support used when the `postcard` feature is on.
    //!
    //! With `soroban`, values travel as a `u128`; otherwise as a fixed
    //! 32-byte little-endian image of the `U256`.
    use super::*;
    #[cfg(not(feature = "soroban"))]
    use serde::de;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    /// Byte length of the little-endian `U256` encoding.
    #[cfg(not(feature = "soroban"))]
    const ENCODED_LEN: usize = 32;

    impl Serialize for Number {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            // Soroban builds carry at most 128 bits; larger values are reported
            // as a serialization error instead of being truncated.
            #[cfg(feature = "soroban")]
            {
                match Number::as_u128_if_fits(self.0) {
                    Some(narrow) => serializer.serialize_u128(narrow),
                    None => Err(serde::ser::Error::custom(
                        "Number exceeds u128 for Soroban postcard",
                    )),
                }
            }

            #[cfg(not(feature = "soroban"))]
            {
                let mut le_image = [0u8; ENCODED_LEN];
                self.0.write_as_little_endian(&mut le_image);
                serializer.serialize_bytes(&le_image)
            }
        }
    }

    impl<'de> Deserialize<'de> for Number {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            #[cfg(feature = "soroban")]
            {
                let narrow = u128::deserialize(deserializer)?;
                Ok(Number::from(narrow))
            }

            #[cfg(not(feature = "soroban"))]
            {
                /// Accepts a byte string of exactly `ENCODED_LEN` bytes.
                struct LeBytesVisitor;

                impl de::Visitor<'_> for LeBytesVisitor {
                    type Value = Number;

                    fn expecting(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
                        f.write_str("exactly 32 bytes for little-endian U256")
                    }

                    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
                    where
                        E: de::Error,
                    {
                        if v.len() == ENCODED_LEN {
                            Ok(Number(U256::from_little_endian(v)))
                        } else {
                            Err(E::custom("expected exactly 32 bytes for U256"))
                        }
                    }
                }

                deserializer.deserialize_bytes(LeBytesVisitor)
            }
        }
    }
}
135
#[cfg(feature = "borsh")]
mod borsh_impl {
    //! Borsh encoding: a fixed-width 32-byte little-endian image of the `U256`,
    //! written and read without any length prefix.
    use super::*;
    use borsh::{self, BorshDeserialize, BorshSerialize};

    /// Number of bytes in the borsh encoding of a `Number`.
    const WIDTH: usize = 32;

    impl BorshSerialize for Number {
        fn serialize<W: borsh::io::Write>(&self, writer: &mut W) -> borsh::io::Result<()> {
            let mut image = [0u8; WIDTH];
            self.0.write_as_little_endian(&mut image);
            writer.write_all(&image)
        }
    }

    impl BorshDeserialize for Number {
        fn deserialize_reader<R: borsh::io::Read>(reader: &mut R) -> borsh::io::Result<Self> {
            let mut image = [0u8; WIDTH];
            // Short input surfaces as the reader's own `read_exact` error.
            reader.read_exact(&mut image)?;
            let value = U256::from_little_endian(&image);
            Ok(Number(value))
        }
    }
}
157
#[cfg(feature = "borsh-schema")]
mod borsh_schema_impl {
    //! Borsh schema: `Number` is declared as an opaque 32-byte primitive named
    //! "Number".
    use super::*;
    use alloc::collections::BTreeMap;
    use borsh::schema::{add_definition, Declaration, Definition};
    use borsh::BorshSchema;

    impl BorshSchema for Number {
        fn add_definitions_recursively(definitions: &mut BTreeMap<Declaration, Definition>) {
            // 32 bytes mirrors the fixed-width little-endian layout of the
            // borsh encoding.
            add_definition(Self::declaration(), Definition::Primitive(32), definitions);
        }

        fn declaration() -> Declaration {
            Declaration::from("Number")
        }
    }
}
176
#[cfg(feature = "schemars")]
mod schemars_impl {
    //! JSON Schema describing `Number` as a canonical decimal string.
    use super::*;
    use alloc::string::ToString;
    use schemars::r#gen::SchemaGenerator;
    use schemars::schema::Schema;
    use schemars::JsonSchema;

    impl JsonSchema for Number {
        fn schema_name() -> alloc::string::String {
            "Number".to_string()
        }

        fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
            use schemars::schema::{InstanceType, SchemaObject};

            let mut object = SchemaObject {
                instance_type: Some(InstanceType::String.into()),
                ..SchemaObject::default()
            };
            // Either "0" or 1..=78 digits without a leading zero (U256::MAX has
            // 78 decimal digits). The regex cannot bound the value itself, so a
            // 78-digit string above U256::MAX still matches the pattern.
            object.string().pattern = Some("^(0|[1-9][0-9]{0,77})$".to_string());
            object.metadata().description = Some("256-bit Unsigned Integer".to_string());
            Schema::Object(object)
        }
    }
}
199
200impl Number {
201    pub const ZERO: Self = Number(U256::zero());
202    pub const ONE: Self = Number(U256::one());
203
204    #[inline]
205    #[must_use]
206    pub const fn zero() -> Self {
207        Self::ZERO
208    }
209
210    #[inline]
211    #[must_use]
212    pub const fn one() -> Self {
213        Self::ONE
214    }
215
216    #[inline]
217    #[must_use]
218    pub fn is_zero(&self) -> bool {
219        self.0.is_zero()
220    }
221
222    #[inline]
223    #[must_use]
224    pub fn is_one(&self) -> bool {
225        self.0 == U256::one()
226    }
227
228    #[inline]
229    #[must_use]
230    pub fn as_u128_trunc(self) -> u128 {
231        let mut b32 = [0u8; 32];
232        self.0.write_as_little_endian(&mut b32);
233        let mut b16 = [0u8; 16];
234        b16.copy_from_slice(&b32[..16]);
235        u128::from_le_bytes(b16)
236    }
237
238    #[inline]
239    #[must_use]
240    pub fn as_u128_saturating(self) -> u128 {
241        if self.0 .0[2] != 0 || self.0 .0[3] != 0 {
242            u128::MAX
243        } else {
244            self.0.as_u128()
245        }
246    }
247
248    #[inline]
249    pub(crate) fn as_u256_trunc(q: U512) -> U256 {
250        let U512(ref limbs) = q;
251        U256([limbs[0], limbs[1], limbs[2], limbs[3]])
252    }
253
254    #[inline]
255    pub(crate) fn as_u128_if_fits(value: U256) -> Option<u128> {
256        let U256(ref limbs) = value;
257        if limbs[2] != 0 || limbs[3] != 0 {
258            return None;
259        }
260        Some((u128::from(limbs[1]) << 64) | u128::from(limbs[0]))
261    }
262
263    #[inline]
264    #[must_use]
265    pub fn saturating_add(self, other: Number) -> Number {
266        Number(self.0.saturating_add(other.0))
267    }
268
269    #[inline]
270    #[must_use]
271    pub fn saturating_sub(self, other: Number) -> Number {
272        Number(self.0.saturating_sub(other.0))
273    }
274
275    #[inline(never)]
276    fn mul_div_with_rounding(x: Number, y: Number, denom: Number, round_up: bool) -> Number {
277        // Fast path: zero inputs
278        if x.is_zero() || y.is_zero() {
279            return Number::zero();
280        }
281        if denom.is_zero() {
282            return Number::zero();
283        }
284        // Fast path: denom == 1 (identity division)
285        if denom.is_one() {
286            return Number(x.0.saturating_mul(y.0));
287        }
288        // Fast path: cancellation when one factor equals denom
289        if x.0 == denom.0 {
290            return y;
291        }
292        if y.0 == denom.0 {
293            return x;
294        }
295        if let (Some(x128), Some(y128), Some(denom128)) = (
296            Self::as_u128_if_fits(x.0),
297            Self::as_u128_if_fits(y.0),
298            Self::as_u128_if_fits(denom.0),
299        ) {
300            if let Some(prod) = x128.checked_mul(y128) {
301                let q = prod / denom128;
302                if !round_up {
303                    return Number::from(q);
304                }
305                let r = prod % denom128;
306                return if r == 0 {
307                    Number::from(q)
308                } else {
309                    Number::from(q.saturating_add(1))
310                };
311            }
312        }
313        // General path: use U512 for overflow-safe multiplication
314        let prod = x.0.full_mul(y.0);
315        let d = U512::from(denom.0);
316        let q = prod / d;
317        let base = Number(Self::as_u256_trunc(q));
318        if !round_up {
319            return base;
320        }
321        let r = prod % d;
322        if r.is_zero() {
323            base
324        } else {
325            base.saturating_add(Number::one())
326        }
327    }
328
329    #[inline]
330    #[must_use]
331    pub fn mul_div_floor(x: Number, y: Number, denom: Number) -> Number {
332        Self::mul_div_with_rounding(x, y, denom, false)
333    }
334
335    #[inline]
336    #[must_use]
337    pub fn mul_div_ceil(x: Number, y: Number, denom: Number) -> Number {
338        Self::mul_div_with_rounding(x, y, denom, true)
339    }
340}
341
342impl From<u128> for Number {
343    #[inline]
344    fn from(v: u128) -> Self {
345        Number(U256::from(v))
346    }
347}
348impl From<Number> for u128 {
349    #[inline]
350    fn from(n: Number) -> u128 {
351        n.as_u128_trunc()
352    }
353}
354impl Div<u128> for Number {
355    type Output = Number;
356    #[inline]
357    fn div(self, rhs: u128) -> Number {
358        Number(self.0 / U256::from(rhs))
359    }
360}
361impl Div<U256> for Number {
362    type Output = Number;
363    #[inline]
364    fn div(self, rhs: U256) -> Number {
365        Number(self.0 / rhs)
366    }
367}
368impl Div<Number> for Number {
369    type Output = Number;
370    #[inline]
371    fn div(self, rhs: Number) -> Number {
372        Number(self.0 / rhs.0)
373    }
374}
375impl Add<Number> for Number {
376    type Output = Number;
377    #[inline]
378    fn add(self, rhs: Number) -> Number {
379        Number(self.0 + rhs.0)
380    }
381}
382impl Sub<Number> for Number {
383    type Output = Number;
384    #[inline]
385    fn sub(self, rhs: Number) -> Number {
386        Number(self.0 - rhs.0)
387    }
388}
389
390#[cfg(test)]
391mod tests;