templar_curator_primitives/policy/supply_queue/mod.rs

1//! Supply queue for managing pending allocation requests.
2
3use alloc::vec::Vec;
4use core::num::NonZeroU32;
5use templar_vault_kernel::{TargetId, TimestampNs};
6
7use super::market_lock::MarketLeaseRegistry;
8
9#[templar_vault_macros::vault_derive(borsh, serde)]
10#[derive(Clone, PartialEq, Eq)]
11pub struct SupplyQueueEntry {
12    pub target_id: TargetId,
13    pub amount: u128,
14    pub priority: u8,
15}
16
17impl SupplyQueueEntry {
18    pub fn new(target_id: TargetId, amount: u128) -> Result<Self, SupplyQueueError> {
19        Self::new_with_priority(target_id, amount, 0)
20    }
21
22    pub fn new_with_priority(
23        target_id: TargetId,
24        amount: u128,
25        priority: u8,
26    ) -> Result<Self, SupplyQueueError> {
27        if amount == 0 {
28            return Err(SupplyQueueError::ZeroAmount);
29        }
30
31        Ok(Self {
32            target_id,
33            amount,
34            priority,
35        })
36    }
37
38    fn validate(&self) -> Result<(), SupplyQueueError> {
39        if self.amount == 0 {
40            return Err(SupplyQueueError::ZeroAmount);
41        }
42
43        Ok(())
44    }
45}
46
47impl TryFrom<(TargetId, u128)> for SupplyQueueEntry {
48    type Error = SupplyQueueError;
49
50    fn try_from(value: (TargetId, u128)) -> Result<Self, Self::Error> {
51        Self::new(value.0, value.1)
52    }
53}
54
55#[templar_vault_macros::vault_derive(borsh, serde)]
56#[derive(Clone, PartialEq, Eq)]
57pub struct SupplyQueue {
58    buckets: Vec<Vec<SupplyQueueEntry>>,
59    len: u32,
60    max_length: Option<u32>,
61}
62
63impl Default for SupplyQueue {
64    fn default() -> Self {
65        Self::unbounded()
66    }
67}
68
69impl SupplyQueue {
70    #[must_use]
71    pub fn new(max_length: Option<NonZeroU32>) -> Self {
72        Self {
73            buckets: alloc::vec![Vec::new(); usize::from(u8::MAX) + 1],
74            len: 0,
75            max_length: max_length.map(NonZeroU32::get),
76        }
77    }
78
79    #[must_use]
80    pub fn unbounded() -> Self {
81        Self::new(None)
82    }
83
84    #[must_use]
85    pub fn bounded(max_length: NonZeroU32) -> Self {
86        Self::new(Some(max_length))
87    }
88
89    pub fn try_from_entries(
90        entries: Vec<SupplyQueueEntry>,
91        max_length: Option<NonZeroU32>,
92    ) -> Result<Self, SupplyQueueError> {
93        let mut queue = Self::new(max_length);
94        for entry in entries {
95            queue.enqueue(entry)?;
96        }
97        Ok(queue)
98    }
99
100    pub fn validate(&self) -> Result<(), SupplyQueueError> {
101        let actual_len = self.buckets.iter().try_fold(0u32, |acc, bucket| {
102            let bucket_len =
103                u32::try_from(bucket.len()).map_err(|_| SupplyQueueError::LengthOverflow)?;
104            acc.checked_add(bucket_len)
105                .ok_or(SupplyQueueError::LengthOverflow)
106        })?;
107
108        if self.len != actual_len {
109            return Err(SupplyQueueError::LengthMismatch {
110                recorded_len: self.len,
111                actual_len,
112            });
113        }
114
115        if let Some(max_length) = self.max_length {
116            if self.len > max_length {
117                return Err(SupplyQueueError::QueueTooLong {
118                    len: self.len,
119                    max_length,
120                });
121            }
122        }
123
124        for (priority, bucket) in self.buckets.iter().enumerate() {
125            let expected_priority =
126                u8::try_from(priority).map_err(|_| SupplyQueueError::LengthOverflow)?;
127            for entry in bucket {
128                entry.validate()?;
129                if entry.priority != expected_priority {
130                    return Err(SupplyQueueError::PriorityBucketMismatch {
131                        expected_priority,
132                        actual_priority: entry.priority,
133                    });
134                }
135            }
136        }
137
138        Ok(())
139    }
140
141    #[must_use]
142    pub fn is_empty(&self) -> bool {
143        self.len == 0
144    }
145
146    #[must_use]
147    pub fn len(&self) -> usize {
148        match usize::try_from(self.len) {
149            Ok(len) => len,
150            Err(_) => unreachable!("u32 supply queue length must fit usize"),
151        }
152    }
153
154    #[must_use]
155    pub fn is_full(&self) -> bool {
156        self.max_length
157            .is_some_and(|max_length| self.len >= max_length)
158    }
159
160    #[must_use]
161    pub fn entries(&self) -> Vec<&SupplyQueueEntry> {
162        self.buckets
163            .iter()
164            .rev()
165            .flat_map(|bucket| bucket.iter())
166            .collect()
167    }
168
169    #[must_use]
170    pub fn max_length(&self) -> Option<NonZeroU32> {
171        self.max_length.and_then(NonZeroU32::new)
172    }
173
174    pub fn enqueue(&mut self, entry: SupplyQueueEntry) -> Result<(), SupplyQueueError> {
175        entry.validate()?;
176
177        if self.is_full() {
178            let Some(max_length) = self.max_length else {
179                unreachable!("is_full() guarantees max_length is Some");
180            };
181            return Err(SupplyQueueError::QueueFull { max_length });
182        }
183
184        self.push_validated_entry(entry)
185            .ok_or(SupplyQueueError::LengthOverflow)?;
186        Ok(())
187    }
188
189    fn push_validated_entry(&mut self, entry: SupplyQueueEntry) -> Option<()> {
190        self.buckets[usize::from(entry.priority)].push(entry);
191        self.len = self.len.checked_add(1)?;
192        Some(())
193    }
194
195    pub fn dequeue(&mut self) -> Result<SupplyQueueEntry, SupplyQueueError> {
196        for bucket in self.buckets.iter_mut().rev() {
197            if !bucket.is_empty() {
198                let entry = bucket.remove(0);
199                self.len = self.len.saturating_sub(1);
200                return Ok(entry);
201            }
202        }
203
204        Err(SupplyQueueError::QueueEmpty)
205    }
206
207    #[must_use]
208    pub fn peek(&self) -> Option<&SupplyQueueEntry> {
209        self.buckets.iter().rev().find_map(|bucket| bucket.first())
210    }
211
212    pub fn total(&self) -> Result<u128, SupplyQueueError> {
213        checked_total_amount(self.entries().into_iter().map(|entry| entry.amount))
214    }
215
216    pub fn totals_by_target(&self) -> Result<Vec<(TargetId, u128)>, SupplyQueueError> {
217        let mut totals: Vec<(TargetId, u128)> = Vec::new();
218        for entry in self.entries() {
219            let sum = match totals
220                .iter_mut()
221                .find(|(target_id, _)| *target_id == entry.target_id)
222            {
223                Some((_, total)) => total,
224                None => {
225                    let index = totals.len();
226                    totals.push((entry.target_id, 0));
227                    &mut totals[index].1
228                }
229            };
230            *sum = (*sum)
231                .checked_add(entry.amount)
232                .ok_or(SupplyQueueError::AmountOverflow)?;
233        }
234        Ok(totals)
235    }
236
237    pub fn remove_target(&mut self, target_id: TargetId) {
238        let mut removed = 0u32;
239        for bucket in &mut self.buckets {
240            let before = bucket.len();
241            bucket.retain(|entry| entry.target_id != target_id);
242            let after = bucket.len();
243            let diff = before.saturating_sub(after);
244            removed = removed.saturating_add(u32::try_from(diff).unwrap_or(u32::MAX));
245        }
246        self.len = self.len.saturating_sub(removed);
247    }
248
249    #[must_use]
250    pub fn excluding_leased(&self, leases: &MarketLeaseRegistry, now_ns: TimestampNs) -> Self {
251        let mut filtered = Self::new(self.max_length());
252        for entry in self.entries() {
253            if leases.is_unleased(entry.target_id, now_ns) {
254                let inserted = filtered.push_validated_entry(entry.clone());
255                debug_assert!(inserted.is_some());
256            }
257        }
258        filtered
259    }
260
261    pub fn drain(&mut self) -> Vec<SupplyQueueEntry> {
262        let mut drained = Vec::with_capacity(self.len());
263        for bucket in self.buckets.iter_mut().rev() {
264            drained.append(bucket);
265        }
266        self.len = 0;
267        drained
268    }
269
270    pub fn to_allocation_plan(&self) -> Result<Vec<(TargetId, u128)>, SupplyQueueError> {
271        let mut totals = self.totals_by_target()?;
272        let mut plan = Vec::with_capacity(totals.len());
273
274        for entry in self.entries() {
275            if let Some(index) = totals
276                .iter()
277                .position(|(target_id, _)| *target_id == entry.target_id)
278            {
279                let (_, amount) = totals.remove(index);
280                plan.push((entry.target_id, amount));
281            }
282        }
283
284        Ok(plan)
285    }
286
287    pub fn to_allocation_plan_excluding_leased(
288        &self,
289        leases: &MarketLeaseRegistry,
290        now_ns: TimestampNs,
291    ) -> Result<Vec<(TargetId, u128)>, SupplyQueueError> {
292        self.excluding_leased(leases, now_ns).to_allocation_plan()
293    }
294
295    pub fn total_for_target(&self, target_id: TargetId) -> Result<u128, SupplyQueueError> {
296        self.entries()
297            .into_iter()
298            .filter(|entry| entry.target_id == target_id)
299            .map(|entry| entry.amount)
300            .try_fold(0u128, |acc, amount| {
301                acc.checked_add(amount)
302                    .ok_or(SupplyQueueError::AmountOverflow)
303            })
304    }
305
306    #[must_use]
307    pub fn has_target(&self, target_id: TargetId) -> bool {
308        self.entries()
309            .into_iter()
310            .any(|entry| entry.target_id == target_id)
311    }
312}
313
314#[templar_vault_macros::vault_derive]
315#[derive(Clone, PartialEq, Eq)]
316pub enum SupplyQueueError {
317    QueueFull {
318        max_length: u32,
319    },
320    QueueTooLong {
321        len: u32,
322        max_length: u32,
323    },
324    ZeroAmount,
325    PriorityBucketMismatch {
326        expected_priority: u8,
327        actual_priority: u8,
328    },
329    LengthMismatch {
330        recorded_len: u32,
331        actual_len: u32,
332    },
333    LengthOverflow,
334    AmountOverflow,
335    QueueEmpty,
336}
337
338fn checked_total_amount<I>(amounts: I) -> Result<u128, SupplyQueueError>
339where
340    I: IntoIterator<Item = u128>,
341{
342    amounts.into_iter().try_fold(0u128, |acc, amount| {
343        acc.checked_add(amount)
344            .ok_or(SupplyQueueError::AmountOverflow)
345    })
346}