1use std::ops::Deref;
2
3use near_sdk::{near, require};
4
5use crate::number::Decimal;
6
/// A curve that maps a usage ratio to a value.
///
/// [`InterestRateStrategy`] dereferences to `dyn UsageCurve`, so the value is used as an
/// interest rate by strategies in this module.
pub trait UsageCurve {
    /// Evaluates the curve at `usage_ratio`.
    ///
    /// NOTE(review): `Piecewise` and `Exponential2` reject (via `require!`) any
    /// `usage_ratio` above `Decimal::ONE`. `Linear` performs no such check.
    fn at(&self, usage_ratio: Decimal) -> Decimal;
}
10
/// The interest-rate curve selected for a market.
///
/// Each variant wraps a concrete [`UsageCurve`]. The enum dereferences to `dyn UsageCurve`,
/// so callers can evaluate it directly with `.at(usage_ratio)`.
#[derive(Clone, Debug, PartialEq, Eq)]
#[near(serializers = [json, borsh])]
pub enum InterestRateStrategy {
    /// Straight line from `base` (usage 0) to `top` (usage 1).
    Linear(Linear),
    /// Two linear segments joined at an `optimal` usage ratio.
    Piecewise(Piecewise),
    /// Base-2 exponential curve from `base` (usage 0) to `top` (usage 1).
    Exponential2(Exponential2),
}
18
19impl InterestRateStrategy {
20 pub const fn zero() -> Self {
21 Self::Linear(Linear {
22 base: Decimal::ZERO,
23 top: Decimal::ZERO,
24 })
25 }
26
27 #[must_use]
28 pub fn linear(base: Decimal, top: Decimal) -> Option<Self> {
29 Some(Self::Linear(Linear::new(base, top)?))
30 }
31
32 #[must_use]
33 pub fn piecewise(
34 base: Decimal,
35 optimal: Decimal,
36 rate_1: Decimal,
37 rate_2: Decimal,
38 ) -> Option<Self> {
39 Some(Self::Piecewise(Piecewise::new(
40 base, optimal, rate_1, rate_2,
41 )?))
42 }
43
44 #[must_use]
45 pub fn exponential2(base: Decimal, top: Decimal, eccentricity: Decimal) -> Option<Self> {
46 Some(Self::Exponential2(Exponential2::new(
47 base,
48 top,
49 eccentricity,
50 )?))
51 }
52}
53
54impl Deref for InterestRateStrategy {
55 type Target = dyn UsageCurve;
56
57 fn deref(&self) -> &Self::Target {
58 match self {
59 Self::Linear(linear) => linear as &dyn UsageCurve,
60 Self::Piecewise(piecewise) => piecewise as &dyn UsageCurve,
61 Self::Exponential2(exponential2) => exponential2 as &dyn UsageCurve,
62 }
63 }
64}
65
66#[derive(Debug, Clone, PartialEq, Eq)]
70#[near(serializers = [borsh, json])]
71pub struct Linear {
72 base: Decimal,
73 top: Decimal,
74}
75
76impl Linear {
77 pub fn new(base: Decimal, top: Decimal) -> Option<Self> {
78 (base <= top).then_some(Self { base, top })
79 }
80}
81
82impl UsageCurve for Linear {
83 fn at(&self, usage_ratio: Decimal) -> Decimal {
84 usage_ratio * (self.top - self.base) + self.base
85 }
86}
87
/// Two-segment ("kinked") usage curve that changes slope at `optimal`.
///
/// JSON (de)serialization goes through [`PiecewiseParams`]. Deserialized parameters are
/// validated by `Piecewise::new` via `TryFrom`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[near(serializers = [borsh, json])]
#[serde(try_from = "PiecewiseParams", into = "PiecewiseParams")]
pub struct Piecewise {
    // Parameters the curve was built from (this is the JSON representation).
    params: PiecewiseParams,
    // Cached negation of the second segment's intercept:
    // `optimal * (rate_2 - rate_1) - base`, computed in `Piecewise::new`.
    i_negative_rate_2_b: Decimal,
}
101
102impl Piecewise {
103 pub fn new(base: Decimal, optimal: Decimal, rate_1: Decimal, rate_2: Decimal) -> Option<Self> {
104 if optimal > 1u32 {
105 return None;
106 }
107
108 if rate_1 > rate_2 {
109 return None;
110 }
111
112 Some(Self {
113 i_negative_rate_2_b: optimal * (rate_2 - rate_1) - base,
114 params: PiecewiseParams {
115 base,
116 optimal,
117 rate_1,
118 rate_2,
119 },
120 })
121 }
122}
123
124impl UsageCurve for Piecewise {
125 fn at(&self, usage_ratio: Decimal) -> Decimal {
126 require!(
127 usage_ratio <= Decimal::ONE,
128 "Invariant violation: Usage ratio cannot be over 100%.",
129 );
130
131 if usage_ratio < self.params.optimal {
132 self.params.rate_1 * usage_ratio + self.params.base
133 } else {
134 self.params.rate_2 * usage_ratio - self.i_negative_rate_2_b
135 }
136 }
137}
138
/// Serialized parameters of a [`Piecewise`] curve.
///
/// Converting into a [`Piecewise`] (via `TryFrom`) validates them with `Piecewise::new`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[near(serializers = [json, borsh])]
pub struct PiecewiseParams {
    // Value of the curve at a usage ratio of 0.
    base: Decimal,
    // Usage ratio where the slope switches from `rate_1` to `rate_2`; must be <= 1.
    optimal: Decimal,
    // Slope below `optimal`.
    rate_1: Decimal,
    // Slope at and above `optimal`; must be >= `rate_1`.
    rate_2: Decimal,
}
147
148impl TryFrom<PiecewiseParams> for Piecewise {
149 type Error = &'static str;
150
151 fn try_from(
152 PiecewiseParams {
153 base,
154 optimal,
155 rate_1,
156 rate_2,
157 }: PiecewiseParams,
158 ) -> Result<Self, Self::Error> {
159 Self::new(base, optimal, rate_1, rate_2).ok_or("Invalid Piecewise parameters")
160 }
161}
162
163impl From<Piecewise> for PiecewiseParams {
164 fn from(value: Piecewise) -> Self {
165 value.params
166 }
167}
168
/// Base-2 exponential usage curve rising from `base` (usage 0) to `top` (usage 1).
///
/// JSON (de)serialization goes through [`Exponential2Params`]. Deserialized parameters are
/// validated by `Exponential2::new` via `TryFrom`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[near(serializers = [borsh, json])]
#[serde(try_from = "Exponential2Params", into = "Exponential2Params")]
pub struct Exponential2 {
    // Parameters the curve was built from (this is the JSON representation).
    params: Exponential2Params,
    // Cached scale factor `(top - base) / (2^eccentricity - 1)`, computed in `Exponential2::new`.
    i_factor: Decimal,
}
179
180impl Exponential2 {
181 pub fn new(base: Decimal, top: Decimal, eccentricity: Decimal) -> Option<Self> {
184 if base > top {
185 return None;
186 }
187
188 if eccentricity > 24u32 || eccentricity.is_zero() {
189 return None;
190 }
191
192 #[allow(clippy::unwrap_used, reason = "Invariant checked above")]
193 Some(Self {
194 i_factor: (top - base) / (eccentricity.pow2().unwrap() - 1u32),
195 params: Exponential2Params {
196 base,
197 top,
198 eccentricity,
199 },
200 })
201 }
202}
203
204impl UsageCurve for Exponential2 {
205 fn at(&self, usage_ratio: Decimal) -> Decimal {
206 require!(
207 usage_ratio <= Decimal::ONE,
208 "Invariant violation: Usage ratio cannot be over 100%.",
209 );
210
211 #[allow(clippy::unwrap_used, reason = "Invariant checked above")]
212 (self.params.base
213 + self.i_factor * ((self.params.eccentricity * usage_ratio).pow2().unwrap() - 1u32))
214 }
215}
216
/// Serialized parameters of an [`Exponential2`] curve.
///
/// Converting into an [`Exponential2`] (via `TryFrom`) validates them with `Exponential2::new`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[near(serializers = [json, borsh])]
pub struct Exponential2Params {
    // Value of the curve at a usage ratio of 0; must be <= `top`.
    base: Decimal,
    // Value of the curve at a usage ratio of 1.
    top: Decimal,
    // Steepness exponent; must be non-zero and at most 24.
    eccentricity: Decimal,
}
224
225impl TryFrom<Exponential2Params> for Exponential2 {
226 type Error = &'static str;
227
228 fn try_from(
229 Exponential2Params {
230 base,
231 top,
232 eccentricity,
233 }: Exponential2Params,
234 ) -> Result<Self, Self::Error> {
235 Self::new(base, top, eccentricity).ok_or("Invalid Exponential2 parameters")
236 }
237}
238
239impl From<Exponential2> for Exponential2Params {
240 fn from(value: Exponential2) -> Self {
241 value.params
242 }
243}
244
#[cfg(test)]
mod tests {
    use std::ops::Div;

    use crate::dec;

    use super::*;

    /// Spot-checks both segments of a kinked curve with a zero base rate.
    #[test]
    fn piecewise() {
        let curve =
            Piecewise::new(Decimal::ZERO, dec!("0.9"), dec!("0.035"), dec!("0.6")).unwrap();

        // First segment (usage < 0.9): 0.035 * usage.
        assert!(curve.at(Decimal::ZERO).near_equal(Decimal::ZERO));
        assert!(curve.at(dec!("0.1")).near_equal(dec!("0.0035")));
        assert!(curve.at(dec!("0.5")).near_equal(dec!("0.0175")));
        assert!(curve.at(dec!("0.6")).near_equal(dec!("0.021")));

        // Second segment (usage >= 0.9): 0.6 * usage - 0.5085, continuous at the kink.
        assert!(curve.at(dec!("0.9")).near_equal(dec!("0.0315")));
        assert!(curve.at(dec!("0.95")).near_equal(dec!("0.0615")));
        assert!(curve.at(Decimal::ONE).near_equal(dec!("0.0915")));
    }

    /// Checks the exponential curve at its start, a quarter, and its midpoint.
    #[test]
    fn exponential2() {
        let curve = Exponential2::new(dec!("0.005"), dec!("0.08"), dec!("6")).unwrap();

        assert!(curve.at(Decimal::ZERO).near_equal(dec!("0.005")));
        // 0.005 + 0.075 * (2^1.5 - 1) / 63
        assert!(curve.at(dec!("0.25")).near_equal(dec!(
            "0.00717669895803117868762306839097547161564207589375463826946828509045412494"
        )));
        // 0.005 + 0.075 * (2^3 - 1) / 63 == 1/75
        assert!(curve.at(Decimal::ONE_HALF).near_equal(Decimal::ONE.div(75u32)));
    }
}