templar_curator_primitives/policy/supply_queue/
mod.rs1use alloc::vec::Vec;
4use core::num::NonZeroU32;
5use templar_vault_kernel::{TargetId, TimestampNs};
6
7use super::market_lock::MarketLeaseRegistry;
8
9#[templar_vault_macros::vault_derive(borsh, serde)]
10#[derive(Clone, PartialEq, Eq)]
11pub struct SupplyQueueEntry {
12 pub target_id: TargetId,
13 pub amount: u128,
14 pub priority: u8,
15}
16
17impl SupplyQueueEntry {
18 pub fn new(target_id: TargetId, amount: u128) -> Result<Self, SupplyQueueError> {
19 Self::new_with_priority(target_id, amount, 0)
20 }
21
22 pub fn new_with_priority(
23 target_id: TargetId,
24 amount: u128,
25 priority: u8,
26 ) -> Result<Self, SupplyQueueError> {
27 if amount == 0 {
28 return Err(SupplyQueueError::ZeroAmount);
29 }
30
31 Ok(Self {
32 target_id,
33 amount,
34 priority,
35 })
36 }
37
38 fn validate(&self) -> Result<(), SupplyQueueError> {
39 if self.amount == 0 {
40 return Err(SupplyQueueError::ZeroAmount);
41 }
42
43 Ok(())
44 }
45}
46
47impl TryFrom<(TargetId, u128)> for SupplyQueueEntry {
48 type Error = SupplyQueueError;
49
50 fn try_from(value: (TargetId, u128)) -> Result<Self, Self::Error> {
51 Self::new(value.0, value.1)
52 }
53}
54
55#[templar_vault_macros::vault_derive(borsh, serde)]
56#[derive(Clone, PartialEq, Eq)]
57pub struct SupplyQueue {
58 buckets: Vec<Vec<SupplyQueueEntry>>,
59 len: u32,
60 max_length: Option<u32>,
61}
62
63impl Default for SupplyQueue {
64 fn default() -> Self {
65 Self::unbounded()
66 }
67}
68
69impl SupplyQueue {
70 #[must_use]
71 pub fn new(max_length: Option<NonZeroU32>) -> Self {
72 Self {
73 buckets: alloc::vec![Vec::new(); usize::from(u8::MAX) + 1],
74 len: 0,
75 max_length: max_length.map(NonZeroU32::get),
76 }
77 }
78
79 #[must_use]
80 pub fn unbounded() -> Self {
81 Self::new(None)
82 }
83
84 #[must_use]
85 pub fn bounded(max_length: NonZeroU32) -> Self {
86 Self::new(Some(max_length))
87 }
88
89 pub fn try_from_entries(
90 entries: Vec<SupplyQueueEntry>,
91 max_length: Option<NonZeroU32>,
92 ) -> Result<Self, SupplyQueueError> {
93 let mut queue = Self::new(max_length);
94 for entry in entries {
95 queue.enqueue(entry)?;
96 }
97 Ok(queue)
98 }
99
100 pub fn validate(&self) -> Result<(), SupplyQueueError> {
101 let actual_len = self.buckets.iter().try_fold(0u32, |acc, bucket| {
102 let bucket_len =
103 u32::try_from(bucket.len()).map_err(|_| SupplyQueueError::LengthOverflow)?;
104 acc.checked_add(bucket_len)
105 .ok_or(SupplyQueueError::LengthOverflow)
106 })?;
107
108 if self.len != actual_len {
109 return Err(SupplyQueueError::LengthMismatch {
110 recorded_len: self.len,
111 actual_len,
112 });
113 }
114
115 if let Some(max_length) = self.max_length {
116 if self.len > max_length {
117 return Err(SupplyQueueError::QueueTooLong {
118 len: self.len,
119 max_length,
120 });
121 }
122 }
123
124 for (priority, bucket) in self.buckets.iter().enumerate() {
125 let expected_priority =
126 u8::try_from(priority).map_err(|_| SupplyQueueError::LengthOverflow)?;
127 for entry in bucket {
128 entry.validate()?;
129 if entry.priority != expected_priority {
130 return Err(SupplyQueueError::PriorityBucketMismatch {
131 expected_priority,
132 actual_priority: entry.priority,
133 });
134 }
135 }
136 }
137
138 Ok(())
139 }
140
141 #[must_use]
142 pub fn is_empty(&self) -> bool {
143 self.len == 0
144 }
145
146 #[must_use]
147 pub fn len(&self) -> usize {
148 match usize::try_from(self.len) {
149 Ok(len) => len,
150 Err(_) => unreachable!("u32 supply queue length must fit usize"),
151 }
152 }
153
154 #[must_use]
155 pub fn is_full(&self) -> bool {
156 self.max_length
157 .is_some_and(|max_length| self.len >= max_length)
158 }
159
160 #[must_use]
161 pub fn entries(&self) -> Vec<&SupplyQueueEntry> {
162 self.buckets
163 .iter()
164 .rev()
165 .flat_map(|bucket| bucket.iter())
166 .collect()
167 }
168
169 #[must_use]
170 pub fn max_length(&self) -> Option<NonZeroU32> {
171 self.max_length.and_then(NonZeroU32::new)
172 }
173
174 pub fn enqueue(&mut self, entry: SupplyQueueEntry) -> Result<(), SupplyQueueError> {
175 entry.validate()?;
176
177 if self.is_full() {
178 let Some(max_length) = self.max_length else {
179 unreachable!("is_full() guarantees max_length is Some");
180 };
181 return Err(SupplyQueueError::QueueFull { max_length });
182 }
183
184 self.push_validated_entry(entry)
185 .ok_or(SupplyQueueError::LengthOverflow)?;
186 Ok(())
187 }
188
189 fn push_validated_entry(&mut self, entry: SupplyQueueEntry) -> Option<()> {
190 self.buckets[usize::from(entry.priority)].push(entry);
191 self.len = self.len.checked_add(1)?;
192 Some(())
193 }
194
195 pub fn dequeue(&mut self) -> Result<SupplyQueueEntry, SupplyQueueError> {
196 for bucket in self.buckets.iter_mut().rev() {
197 if !bucket.is_empty() {
198 let entry = bucket.remove(0);
199 self.len = self.len.saturating_sub(1);
200 return Ok(entry);
201 }
202 }
203
204 Err(SupplyQueueError::QueueEmpty)
205 }
206
207 #[must_use]
208 pub fn peek(&self) -> Option<&SupplyQueueEntry> {
209 self.buckets.iter().rev().find_map(|bucket| bucket.first())
210 }
211
212 pub fn total(&self) -> Result<u128, SupplyQueueError> {
213 checked_total_amount(self.entries().into_iter().map(|entry| entry.amount))
214 }
215
216 pub fn totals_by_target(&self) -> Result<Vec<(TargetId, u128)>, SupplyQueueError> {
217 let mut totals: Vec<(TargetId, u128)> = Vec::new();
218 for entry in self.entries() {
219 let sum = match totals
220 .iter_mut()
221 .find(|(target_id, _)| *target_id == entry.target_id)
222 {
223 Some((_, total)) => total,
224 None => {
225 let index = totals.len();
226 totals.push((entry.target_id, 0));
227 &mut totals[index].1
228 }
229 };
230 *sum = (*sum)
231 .checked_add(entry.amount)
232 .ok_or(SupplyQueueError::AmountOverflow)?;
233 }
234 Ok(totals)
235 }
236
237 pub fn remove_target(&mut self, target_id: TargetId) {
238 let mut removed = 0u32;
239 for bucket in &mut self.buckets {
240 let before = bucket.len();
241 bucket.retain(|entry| entry.target_id != target_id);
242 let after = bucket.len();
243 let diff = before.saturating_sub(after);
244 removed = removed.saturating_add(u32::try_from(diff).unwrap_or(u32::MAX));
245 }
246 self.len = self.len.saturating_sub(removed);
247 }
248
249 #[must_use]
250 pub fn excluding_leased(&self, leases: &MarketLeaseRegistry, now_ns: TimestampNs) -> Self {
251 let mut filtered = Self::new(self.max_length());
252 for entry in self.entries() {
253 if leases.is_unleased(entry.target_id, now_ns) {
254 let inserted = filtered.push_validated_entry(entry.clone());
255 debug_assert!(inserted.is_some());
256 }
257 }
258 filtered
259 }
260
261 pub fn drain(&mut self) -> Vec<SupplyQueueEntry> {
262 let mut drained = Vec::with_capacity(self.len());
263 for bucket in self.buckets.iter_mut().rev() {
264 drained.append(bucket);
265 }
266 self.len = 0;
267 drained
268 }
269
270 pub fn to_allocation_plan(&self) -> Result<Vec<(TargetId, u128)>, SupplyQueueError> {
271 let mut totals = self.totals_by_target()?;
272 let mut plan = Vec::with_capacity(totals.len());
273
274 for entry in self.entries() {
275 if let Some(index) = totals
276 .iter()
277 .position(|(target_id, _)| *target_id == entry.target_id)
278 {
279 let (_, amount) = totals.remove(index);
280 plan.push((entry.target_id, amount));
281 }
282 }
283
284 Ok(plan)
285 }
286
287 pub fn to_allocation_plan_excluding_leased(
288 &self,
289 leases: &MarketLeaseRegistry,
290 now_ns: TimestampNs,
291 ) -> Result<Vec<(TargetId, u128)>, SupplyQueueError> {
292 self.excluding_leased(leases, now_ns).to_allocation_plan()
293 }
294
295 pub fn total_for_target(&self, target_id: TargetId) -> Result<u128, SupplyQueueError> {
296 self.entries()
297 .into_iter()
298 .filter(|entry| entry.target_id == target_id)
299 .map(|entry| entry.amount)
300 .try_fold(0u128, |acc, amount| {
301 acc.checked_add(amount)
302 .ok_or(SupplyQueueError::AmountOverflow)
303 })
304 }
305
306 #[must_use]
307 pub fn has_target(&self, target_id: TargetId) -> bool {
308 self.entries()
309 .into_iter()
310 .any(|entry| entry.target_id == target_id)
311 }
312}
313
314#[templar_vault_macros::vault_derive]
315#[derive(Clone, PartialEq, Eq)]
316pub enum SupplyQueueError {
317 QueueFull {
318 max_length: u32,
319 },
320 QueueTooLong {
321 len: u32,
322 max_length: u32,
323 },
324 ZeroAmount,
325 PriorityBucketMismatch {
326 expected_priority: u8,
327 actual_priority: u8,
328 },
329 LengthMismatch {
330 recorded_len: u32,
331 actual_len: u32,
332 },
333 LengthOverflow,
334 AmountOverflow,
335 QueueEmpty,
336}
337
338fn checked_total_amount<I>(amounts: I) -> Result<u128, SupplyQueueError>
339where
340 I: IntoIterator<Item = u128>,
341{
342 amounts.into_iter().try_fold(0u128, |acc, amount| {
343 acc.checked_add(amount)
344 .ok_or(SupplyQueueError::AmountOverflow)
345 })
346}