// templar_common/oracle/proxy/aggregator.rs

1use near_sdk::near;
2
3use crate::{
4    oracle::pyth::{self, PythTimestamp},
5    time::Nanoseconds,
6};
7
/// Returns the index of the weighted low median of a sorted list of weighted items.
///
/// `sorted_weighted_items` must already be sorted by item; the weights play no part in the
/// ordering. The search is a two-pointer sweep. Weight is accumulated from the low end and
/// spent from the high end, and the position where the pointers meet is the median. When the
/// weight splits exactly evenly between two candidates, the lower one is chosen.
///
/// If all of the weights are zero, returns the first item (index `0`).
///
/// The running weight is tracked in a `u64`. The result therefore stays exact even when the
/// weights sum past `u32::MAX`, as long as the list has fewer than 2^32 items.
///
/// # Panics
///
/// Panics if `sorted_weighted_items` is empty, since there is no median to index.
fn weighted_median_low<T>(sorted_weighted_items: &[(T, u32)]) -> usize {
    assert!(
        !sorted_weighted_items.is_empty(),
        "weighted_median_low called with an empty list"
    );

    let mut lo = 0;
    let mut hi = sorted_weighted_items.len() - 1;
    // Weight taken from the low side that has not yet been matched by weight removed from the
    // high side. A `u64` holds the sum of any realistic number of `u32` weights, unlike the
    // previous saturating `u32`, which silently lost weight and could pick the wrong index.
    let mut acc: u64 = 0;

    while lo < hi {
        acc = acc.saturating_add(u64::from(sorted_weighted_items[lo].1));
        lo += 1;

        // Retire high-side items while the low side outweighs them. The guard on `acc`
        // makes the subtraction below exact, and `hi` never drops below zero.
        while hi != 0 && acc >= u64::from(sorted_weighted_items[hi].1) {
            acc -= u64::from(sorted_weighted_items[hi].1);
            hi -= 1;
        }
    }

    lo.min(hi)
}
34
/// Combines prices from several sources into a single [`SpecificPrice`].
///
/// Prices are first screened by [`Filter`] and then reduced to one value by the
/// configured [`AggregationMethod`]; see [`Aggregator::aggregate`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[near(serializers = [json, borsh])]
pub struct Aggregator {
    /// How the prices that pass the filter are reduced to a single value.
    pub method: AggregationMethod,
    /// Which prices are eligible to take part in the aggregation.
    pub filter: Filter,
}
41
42impl Aggregator {
43    pub fn median_low(filter: Filter) -> Self {
44        Self {
45            method: AggregationMethod::MedianLow,
46            filter,
47        }
48    }
49
50    pub fn priority(filter: Filter) -> Self {
51        Self {
52            method: AggregationMethod::Priority,
53            filter,
54        }
55    }
56
57    pub fn aggregate(
58        &self,
59        prices: &[(pyth::Price, u32)],
60        now: Nanoseconds,
61    ) -> Option<SpecificPrice> {
62        let prices_filtered = prices
63            .iter()
64            .filter(|p| {
65                let Some(published) = Nanoseconds::try_from_pyth(p.0.publish_time) else {
66                    return false;
67                };
68
69                if now >= published {
70                    self.filter
71                        .max_age
72                        .is_none_or(|max| now.saturating_sub(published) <= max)
73                } else {
74                    self.filter
75                        .max_clock_drift
76                        .is_none_or(|max| published.saturating_sub(now) <= max)
77                }
78            })
79            .collect::<Vec<_>>();
80
81        if prices_filtered.len() < self.filter.min_sources.unwrap_or(1).max(1) as usize {
82            return None;
83        }
84
85        let mut values = prices_filtered
86            .into_iter()
87            .flat_map(|(price, weight)| {
88                // Split apart prices so that we don't need to worry about confidence when sorting.
89                let [lower, upper] = SpecificPrice::split(price);
90                [(lower, *weight), (upper, *weight)]
91            })
92            .collect::<Vec<_>>();
93
94        if values.is_empty() {
95            return None;
96        }
97
98        match &self.method {
99            AggregationMethod::MedianLow => {
100                values.sort_unstable();
101                Some(values.swap_remove(weighted_median_low(&values)).0)
102            }
103            AggregationMethod::Priority => {
104                let mut highest_weighted_ix = 0;
105                for (i, (_p, w)) in values.iter().enumerate().skip(1) {
106                    if *w > values[highest_weighted_ix].1 {
107                        highest_weighted_ix = i;
108                    }
109                }
110                Some(values.swap_remove(highest_weighted_ix).0)
111            }
112        }
113    }
114}
115
/// A single point price with no confidence interval, representing `value * 10^exponent`.
///
/// Equality and ordering are implemented by hand. They compare only the represented value:
/// `publish_time` is ignored, and mantissa/exponent pairs that denote the same number
/// (e.g. `1e-3` and `10e-4`) are equal.
#[derive(Debug, Clone, Eq)]
pub struct SpecificPrice {
    /// Mantissa of the price.
    pub value: i64,
    /// Base-10 exponent applied to `value`.
    pub exponent: i32,
    /// Publish time carried over from the source Pyth price.
    pub publish_time: PythTimestamp,
}
122
123impl From<SpecificPrice> for pyth::Price {
124    fn from(s: SpecificPrice) -> Self {
125        Self {
126            price: s.value.into(),
127            conf: 0.into(),
128            expo: s.exponent,
129            publish_time: s.publish_time,
130        }
131    }
132}
133
134impl SpecificPrice {
135    pub fn split(price: &pyth::Price) -> [Self; 2] {
136        let conf = i64::try_from(price.conf.0).unwrap_or(i64::MAX);
137        [
138            Self {
139                value: price.price.0.saturating_sub(conf),
140                exponent: price.expo,
141                publish_time: price.publish_time,
142            },
143            Self {
144                value: price.price.0.saturating_add(conf),
145                exponent: price.expo,
146                publish_time: price.publish_time,
147            },
148        ]
149    }
150}
151
152impl PartialEq for SpecificPrice {
153    fn eq(&self, other: &Self) -> bool {
154        self.cmp(other) == std::cmp::Ordering::Equal
155    }
156}
157
158impl PartialOrd for SpecificPrice {
159    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
160        Some(self.cmp(other))
161    }
162}
163
164impl Ord for SpecificPrice {
165    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
166        let expo_diff = self.exponent.abs_diff(other.exponent);
167        let scale = 10i128.saturating_pow(expo_diff);
168        let (lhs, rhs) = if self.exponent >= other.exponent {
169            (
170                i128::from(self.value).saturating_mul(scale),
171                i128::from(other.value),
172            )
173        } else {
174            (
175                i128::from(self.value),
176                i128::from(other.value).saturating_mul(scale),
177            )
178        };
179        lhs.cmp(&rhs)
180    }
181}
182
/// Aggregation method for the price oracle.
///
/// Before either method runs, `Aggregator::aggregate` splits every price that passes the
/// [`Filter`] into the two ends of its confidence interval (`price - conf` and
/// `price + conf`). Each end carries the source's weight. The method then chooses one of
/// those endpoints.
#[derive(Debug, Clone, PartialEq, Eq)]
#[near(serializers = [json, borsh])]
pub enum AggregationMethod {
    /// Selects the median value from the sources, selecting the lower value
    /// in case of an even number of sources.
    ///
    /// The median is weighted by each source's weight and is taken over the sorted
    /// confidence-interval endpoints.
    MedianLow,
    /// Selects the value of the source with the highest weight.
    ///
    /// Ties go to the source that appears first. The value returned is the lower end
    /// (`price - conf`) of that source's confidence interval.
    Priority,
}
193
/// Filter configuration for the aggregation.
///
/// A price whose publish time cannot be converted to [`Nanoseconds`] is always excluded,
/// regardless of these settings.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
#[near(serializers = [json, borsh])]
pub struct Filter {
    /// Maximum age of a price in nanoseconds. If a price is older than this, it will be excluded from the aggregation.
    ///
    /// Applies to prices published at or before the current time. The bound is inclusive:
    /// a price exactly `max_age` old is kept. `None` disables the check.
    pub max_age: Option<Nanoseconds>,
    /// Maximum clock drift in nanoseconds. This is the future-analog of `max_age`.
    ///
    /// Applies to prices whose publish time is after the current time. A price at most
    /// `max_clock_drift` in the future is kept (inclusive bound). `None` disables the check.
    pub max_clock_drift: Option<Nanoseconds>,
    /// Minimum number of sources required for the aggregation to produce a result.
    ///
    /// Counted after the time-based checks above have removed stale or future prices.
    /// `None` and `Some(0)` both behave like `Some(1)`: at least one price is always required.
    ///
    /// For example, if the proxy has a Pyth source and a RedStone source, and `min_sources` is set to `Some(2)`,
    /// the aggregation will only produce a result if both oracles provide a price.
    pub min_sources: Option<u32>,
}
208
// Unit tests for `SpecificPrice` ordering, time/source filtering, both aggregation methods,
// and `weighted_median_low`.
//
// The clippy allowances cover the `as` casts between signed and unsigned timestamps used in
// the boundary tests (e.g. `now_s as u64`).
#[allow(clippy::cast_possible_wrap, clippy::cast_sign_loss)]
#[cfg(test)]
mod tests {
    use near_sdk::json_types::{I64, U64};

    use super::*;

    /// Builds a `SpecificPrice` with the given mantissa and exponent and a publish time of
    /// `PythTimestamp::from_secs(0)`.
    fn sp(value: i64, exponent: i32) -> SpecificPrice {
        SpecificPrice {
            value,
            exponent,
            publish_time: PythTimestamp::from_secs(0),
        }
    }

    /// Shorthand for `PythTimestamp::from_secs(s)`.
    fn secs(s: i64) -> PythTimestamp {
        PythTimestamp::from_secs(s)
    }

    // --- SpecificPrice::cmp ---

    #[rstest::rstest]
    #[test]
    // Same exponent: direct comparison.
    #[case(sp(100, -4), sp(200, -4), std::cmp::Ordering::Less)]
    #[case(sp(200, -4), sp(200, -4), std::cmp::Ordering::Equal)]
    #[case(sp(300, -4), sp(200, -4), std::cmp::Ordering::Greater)]
    // Different exponents, equal real values: 1e-3 == 10e-4.
    #[case(sp(1, -3), sp(10, -4), std::cmp::Ordering::Equal)]
    #[case(sp(10, -4), sp(1, -3), std::cmp::Ordering::Equal)]
    // Different exponents, unequal: 1e-3 vs 9e-4 and 11e-4.
    #[case(sp(1, -3), sp(9, -4), std::cmp::Ordering::Greater)]
    #[case(sp(1, -3), sp(11, -4), std::cmp::Ordering::Less)]
    // Negative values.
    #[case(sp(-100, -4), sp(-200, -4), std::cmp::Ordering::Greater)]
    #[case(sp(-1, -3), sp(-10, -4), std::cmp::Ordering::Equal)]
    #[case(sp(-1, -3), sp(-9, -4), std::cmp::Ordering::Less)]
    // Zero.
    #[case(sp(0, -4), sp(0, 4), std::cmp::Ordering::Equal)]
    #[case(sp(0, -4), sp(1, -4), std::cmp::Ordering::Less)]
    // Large expo_diff (>= 39): 10^39 exceeds i128::MAX, so the scale factor itself saturates
    // to i128::MAX. A positive mantissa on the rescaled side then saturates to i128::MAX and
    // exceeds any i64 on the other side; a zero mantissa stays zero.
    #[case(sp(1, 39), sp(1, 0), std::cmp::Ordering::Greater)]
    #[case(sp(0, 39), sp(1, 0), std::cmp::Ordering::Less)]
    #[case(sp(1, 0), sp(1, 39), std::cmp::Ordering::Less)]
    // expo_diff = 38 is the largest gap whose scale factor is exact (10^38 < i128::MAX).
    #[case(sp(1, 38), sp(1, 0), std::cmp::Ordering::Greater)]
    fn specific_price_cmp(
        #[case] a: SpecificPrice,
        #[case] b: SpecificPrice,
        #[case] expected: std::cmp::Ordering,
    ) {
        assert_eq!(a.cmp(&b), expected);
    }

    /// Builds a Pyth price with the given value, confidence and publish time, always at
    /// exponent `-6`.
    fn price(value: i64, conf: u64, publish_time: PythTimestamp) -> pyth::Price {
        pyth::Price {
            price: I64(value),
            conf: U64(conf),
            expo: -6,
            publish_time,
        }
    }

    // --- Median-low aggregation ---

    #[test]
    fn aggregate_empty_returns_none() {
        assert!(Aggregator::median_low(Filter::default())
            .aggregate(&[], Nanoseconds::zero())
            .is_none());
    }

    #[test]
    fn aggregate_single_price_no_conf() {
        // conf=0 means lower==upper==value, so the median is exactly the price value.
        let result = Aggregator::median_low(Filter::default())
            .aggregate(&[(price(1_000_000, 0, secs(0)), 1)], Nanoseconds::zero());
        assert_eq!(result.unwrap().value, 1_000_000);
    }

    #[test]
    fn aggregate_median_of_three() {
        // Three equal-weight prices: median should be the middle value.
        let prices = [
            (price(1_000_000, 0, secs(0)), 1),
            (price(2_000_000, 0, secs(0)), 1),
            (price(3_000_000, 0, secs(0)), 1),
        ];
        let result =
            Aggregator::median_low(Filter::default()).aggregate(&prices, Nanoseconds::zero());
        assert_eq!(result.unwrap().value, 2_000_000);
    }

    #[test]
    fn aggregate_min_sources_not_met_returns_none() {
        let filter = Filter {
            min_sources: Some(3),
            ..Default::default()
        };
        let prices = [
            (price(1_000_000, 0, secs(0)), 1),
            (price(2_000_000, 0, secs(0)), 1),
        ];
        assert!(Aggregator::median_low(filter)
            .aggregate(&prices, Nanoseconds::zero())
            .is_none());
    }

    #[test]
    fn aggregate_min_sources_exactly_met() {
        let filter = Filter {
            min_sources: Some(2),
            ..Default::default()
        };
        let prices = [
            (price(1_000_000, 0, secs(0)), 1),
            (price(2_000_000, 0, secs(0)), 1),
        ];
        assert!(Aggregator::median_low(filter)
            .aggregate(&prices, Nanoseconds::zero())
            .is_some());
    }

    #[test]
    fn aggregate_min_sources_applies_after_time_filtering() {
        // At now=1000s the second price is 900s old, beyond max_age=500s. Only one price
        // survives filtering, so min_sources=2 is not met.
        let filter = Filter {
            min_sources: Some(2),
            max_age: Some(Nanoseconds::from_secs(500)),
            ..Default::default()
        };
        let prices = [
            (price(1_000_000, 0, secs(1_000)), 1),
            (price(2_000_000, 0, secs(100)), 1),
        ];
        assert!(Aggregator::median_low(filter)
            .aggregate(&prices, Nanoseconds::from_secs(1_000))
            .is_none());
    }

    #[rstest::rstest]
    #[test]
    #[case::one_under_included(501, 1000, 500, true)]
    #[case::exactly_at_limit_included(500, 1000, 500, true)]
    #[case::one_over_excluded(499, 1000, 500, false)]
    fn aggregate_max_age_boundary(
        #[case] publish_time_s: i64,
        #[case] now_s: i64,
        #[case] max_age_s: u64,
        #[case] included: bool,
    ) {
        // Use two prices: the one under test plus a fresh anchor so aggregate never returns None.
        let anchor = (price(9_999_999, 0, secs(now_s)), 1);
        let under_test = (price(1_000_000, 0, secs(publish_time_s)), 1);
        let filter = Filter {
            max_age: Some(Nanoseconds::from_secs(max_age_s)),
            ..Default::default()
        };
        let result = Aggregator::median_low(filter)
            .aggregate(&[under_test, anchor], Nanoseconds::from_secs(now_s as u64))
            .unwrap();
        if included {
            // Median of [1_000_000, 9_999_999] — the lower value wins median_low.
            assert_eq!(result.value, 1_000_000);
        } else {
            // Only the anchor survives filtering.
            assert_eq!(result.value, 9_999_999);
        }
    }

    // Mirror of the max-age boundary test for prices published after `now`.
    #[rstest::rstest]
    #[test]
    #[case::exactly_at_limit_included(1500, 1000, 500, true)]
    #[case::one_over_excluded(1501, 1000, 500, false)]
    fn aggregate_max_clock_drift_boundary(
        #[case] publish_time_s: i64,
        #[case] now_s: i64,
        #[case] max_clock_drift_s: u64,
        #[case] included: bool,
    ) {
        let anchor = (price(9_999_999, 0, secs(now_s)), 1);
        let under_test = (price(1_000_000, 0, secs(publish_time_s)), 1);
        let filter = Filter {
            max_clock_drift: Some(Nanoseconds::from_secs(max_clock_drift_s)),
            ..Default::default()
        };
        let result = Aggregator::median_low(filter)
            .aggregate(&[under_test, anchor], Nanoseconds::from_secs(now_s as u64))
            .unwrap();
        if included {
            assert_eq!(result.value, 1_000_000);
        } else {
            assert_eq!(result.value, 9_999_999);
        }
    }

    #[test]
    fn aggregate_negative_publish_time_excluded() {
        // Negative publish_time can't be converted to u64, so the price is filtered out.
        let anchor = (price(9_999_999, 0, secs(1000)), 1);
        let negative_time = (price(1_000_000, 0, secs(-1)), 1);
        let filter = Filter {
            max_age: Some(Nanoseconds::from_ms(500)),
            ..Default::default()
        };
        let result = Aggregator::median_low(filter)
            .aggregate(&[negative_time, anchor], Nanoseconds::from_ms(1000))
            .unwrap();
        assert_eq!(result.value, 9_999_999);
    }

    // --- Priority aggregation ---

    #[test]
    fn priority_empty_returns_none() {
        assert!(Aggregator::priority(Filter::default())
            .aggregate(&[], Nanoseconds::zero())
            .is_none());
    }

    #[test]
    fn priority_single_price() {
        let result = Aggregator::priority(Filter::default())
            .aggregate(&[(price(1_000_000, 0, secs(0)), 1)], Nanoseconds::zero());
        assert_eq!(result.unwrap().value, 1_000_000);
    }

    #[test]
    fn priority_selects_highest_weight() {
        let prices = [
            (price(1_000_000, 0, secs(0)), 1),
            (price(2_000_000, 0, secs(0)), 10),
            (price(3_000_000, 0, secs(0)), 5),
        ];
        let result = Aggregator::priority(Filter::default())
            .aggregate(&prices, Nanoseconds::zero())
            .unwrap();
        // Highest weight is 10 → price 2_000_000 (lower bound with conf=0).
        assert_eq!(result.value, 2_000_000);
    }

    #[test]
    fn priority_equal_weights_selects_first() {
        let prices = [
            (price(1_000_000, 0, secs(0)), 5),
            (price(2_000_000, 0, secs(0)), 5),
            (price(3_000_000, 0, secs(0)), 5),
        ];
        let result = Aggregator::priority(Filter::default())
            .aggregate(&prices, Nanoseconds::zero())
            .unwrap();
        // All weights equal → first entry wins (lower bound of first price).
        assert_eq!(result.value, 1_000_000);
    }

    #[test]
    fn priority_with_confidence_returns_lower_bound() {
        // conf=100 splits into lower (900) and upper (1100), both weight 10.
        // The lower bound comes first in iteration, so it's selected.
        let prices = [
            (price(1_000, 100, secs(0)), 10),
            (price(2_000, 0, secs(0)), 1),
        ];
        let result = Aggregator::priority(Filter::default())
            .aggregate(&prices, Nanoseconds::zero())
            .unwrap();
        assert_eq!(result.value, 1_000 - 100);
    }

    #[test]
    fn priority_respects_max_age_filter() {
        let filter = Filter {
            max_age: Some(Nanoseconds::from_secs(500)),
            ..Default::default()
        };
        // High-weight price is stale, low-weight price is fresh.
        let prices = [
            (price(1_000_000, 0, secs(0)), 100), // stale at now=1000, max_age=500
            (price(2_000_000, 0, secs(900)), 1), // fresh
        ];
        let result = Aggregator::priority(filter)
            .aggregate(&prices, Nanoseconds::from_secs(1000))
            .unwrap();
        // Stale price filtered out, only fresh one remains.
        assert_eq!(result.value, 2_000_000);
    }

    #[test]
    fn priority_min_sources_not_met_returns_none() {
        let filter = Filter {
            min_sources: Some(3),
            ..Default::default()
        };
        let prices = [
            (price(1_000_000, 0, secs(0)), 10),
            (price(2_000_000, 0, secs(0)), 1),
        ];
        assert!(Aggregator::priority(filter)
            .aggregate(&prices, Nanoseconds::zero())
            .is_none());
    }

    // --- weighted_median_low ---

    // Each case lists (item, weight) pairs, already sorted by item, followed by the item
    // expected at the returned median index.
    #[rstest::rstest]
    #[test]
    #[case(&[("a", 1)], "a")]
    #[case(&[("a", 1), ("b", 1), ("c", 1)], "b")]
    #[case(&[("a", 1), ("b", 1), ("c", 1), ("d", 1)], "b")]
    #[case(&[("a", 2), ("b", 1), ("c", 1), ("d", 1)], "b")]
    #[case(&[("a", 1), ("b", 1), ("c", 1), ("d", 2)], "c")]
    #[case(&[("a", 10), ("b", 2), ("c", 6), ("d", 2)], "a")]
    #[case(&[("a", 1), ("b", 10000), ("c", 1)], "b")]
    #[case(&[("a", 2), ("b", 1), ("c", 1)], "a")]
    #[case(&[("a", u32::MAX), ("b", u32::MAX), ("c", u32::MAX)], "b")]
    #[case(&[("a", u32::MAX), ("b", 0), ("c", u32::MAX)], "a")]
    #[case(&[("a", 0), ("b", 0), ("c", 0), ("d", 0)], "a")]
    #[case(&[("a", 0), ("b", 0), ("c", 0), ("d", 1)], "d")]
    #[case(&[("a", 0), ("b", 1), ("c", 0), ("d", 1)], "b")]
    fn test_weighted_median(#[case] list: &[(&str, u32)], #[case] expected: &str) {
        let item = list[weighted_median_low(list)].0;
        assert_eq!(item, expected);
    }
}