kabu_defi_uniswap_v3_math/
sqrt_price_math.rs

1use alloy::primitives::{I256, U256};
2
3use crate::{
4    error::UniswapV3MathError,
5    full_math::{mul_div, mul_div_rounding_up},
6    unsafe_math::div_rounding_up,
7};
8
9pub const MAX_U160: U256 = U256::from_limbs([18446744073709551615, 18446744073709551615, 4294967295, 0]);
10pub const Q96: U256 = U256::from_limbs([0, 4294967296, 0, 0]);
11pub const FIXED_POINT_96_RESOLUTION: U256 = U256::from_limbs([96, 0, 0, 0]);
12
13// returns (sqrtQX96)
14pub fn get_next_sqrt_price_from_input(
15    sqrt_price: U256,
16    liquidity: u128,
17    amount_in: U256,
18    zero_for_one: bool,
19) -> Result<U256, UniswapV3MathError> {
20    if sqrt_price.is_zero() {
21        return Err(UniswapV3MathError::SqrtPriceIsZero);
22    } else if liquidity == 0 {
23        return Err(UniswapV3MathError::LiquidityIsZero);
24    }
25
26    if zero_for_one {
27        get_next_sqrt_price_from_amount_0_rounding_up(sqrt_price, liquidity, amount_in, true)
28    } else {
29        get_next_sqrt_price_from_amount_1_rounding_down(sqrt_price, liquidity, amount_in, true)
30    }
31}
32
33// returns (sqrtQX96)
34pub fn get_next_sqrt_price_from_output(
35    sqrt_price: U256,
36    liquidity: u128,
37    amount_out: U256,
38    zero_for_one: bool,
39) -> Result<U256, UniswapV3MathError> {
40    if sqrt_price.is_zero() {
41        return Err(UniswapV3MathError::SqrtPriceIsZero);
42    } else if liquidity == 0 {
43        return Err(UniswapV3MathError::LiquidityIsZero);
44    }
45
46    if zero_for_one {
47        get_next_sqrt_price_from_amount_1_rounding_down(sqrt_price, liquidity, amount_out, false)
48    } else {
49        get_next_sqrt_price_from_amount_0_rounding_up(sqrt_price, liquidity, amount_out, false)
50    }
51}
52
53// returns (uint160 sqrtQX96)
54pub fn get_next_sqrt_price_from_amount_0_rounding_up(
55    sqrt_price_x_96: U256,
56    liquidity: u128,
57    amount: U256,
58    add: bool,
59) -> Result<U256, UniswapV3MathError> {
60    if amount.is_zero() {
61        return Ok(sqrt_price_x_96);
62    }
63
64    let numerator_1: U256 = U256::from(liquidity) << 96;
65
66    if add {
67        let product = amount.wrapping_mul(sqrt_price_x_96);
68
69        if product.wrapping_div(amount) == sqrt_price_x_96 {
70            let denominator = numerator_1.wrapping_add(product);
71
72            if denominator >= numerator_1 {
73                return mul_div_rounding_up(numerator_1, sqrt_price_x_96, denominator);
74            }
75        }
76
77        Ok(div_rounding_up(numerator_1, (numerator_1.wrapping_div(sqrt_price_x_96)).wrapping_add(amount)))
78    } else {
79        let product = amount.wrapping_mul(sqrt_price_x_96);
80        if product.wrapping_div(amount) == sqrt_price_x_96 && numerator_1 > product {
81            let denominator = numerator_1.wrapping_sub(product);
82
83            mul_div_rounding_up(numerator_1, sqrt_price_x_96, denominator)
84        } else {
85            Err(UniswapV3MathError::ProductDivAmount)
86        }
87    }
88}
89
90// returns (uint160 sqrtQX96)
91pub fn get_next_sqrt_price_from_amount_1_rounding_down(
92    sqrt_price_x_96: U256,
93    liquidity: u128,
94    amount: U256,
95    add: bool,
96) -> Result<U256, UniswapV3MathError> {
97    let liquidity = U256::from(liquidity);
98
99    if add {
100        let quotient =
101            if amount <= MAX_U160 { (amount << FIXED_POINT_96_RESOLUTION) / liquidity } else { mul_div(amount, Q96, liquidity)? };
102
103        let next_sqrt_price = sqrt_price_x_96 + quotient;
104
105        if next_sqrt_price > MAX_U160 {
106            Err(UniswapV3MathError::SafeCastToU160Overflow)
107        } else {
108            Ok(next_sqrt_price)
109        }
110    } else {
111        let quotient = if amount <= MAX_U160 {
112            div_rounding_up(amount << FIXED_POINT_96_RESOLUTION, liquidity)
113        } else {
114            mul_div_rounding_up(amount, Q96, liquidity)?
115        };
116
117        //require(sqrtPX96 > quotient);
118        if sqrt_price_x_96 <= quotient {
119            return Err(UniswapV3MathError::SqrtPriceIsLteQuotient);
120        }
121
122        Ok(sqrt_price_x_96.overflowing_sub(quotient).0)
123    }
124}
125
126// returns (uint256 amount0)
127pub fn _get_amount_0_delta(
128    mut sqrt_ratio_a_x_96: U256,
129    mut sqrt_ratio_b_x_96: U256,
130    liquidity: u128,
131    round_up: bool,
132) -> Result<U256, UniswapV3MathError> {
133    if sqrt_ratio_a_x_96 > sqrt_ratio_b_x_96 {
134        (sqrt_ratio_a_x_96, sqrt_ratio_b_x_96) = (sqrt_ratio_b_x_96, sqrt_ratio_a_x_96)
135    };
136
137    let numerator_1 = U256::from(liquidity) << 96;
138    let numerator_2 = sqrt_ratio_b_x_96 - sqrt_ratio_a_x_96;
139
140    if sqrt_ratio_a_x_96.is_zero() {
141        return Err(UniswapV3MathError::SqrtPriceIsZero);
142    }
143
144    if round_up {
145        let numerator_partial = mul_div_rounding_up(numerator_1, numerator_2, sqrt_ratio_b_x_96)?;
146        Ok(div_rounding_up(numerator_partial, sqrt_ratio_a_x_96))
147    } else {
148        Ok(mul_div(numerator_1, numerator_2, sqrt_ratio_b_x_96)? / sqrt_ratio_a_x_96)
149    }
150}
151
152// returns (uint256 amount1)
153pub fn _get_amount_1_delta(
154    mut sqrt_ratio_a_x_96: U256,
155    mut sqrt_ratio_b_x_96: U256,
156    liquidity: u128,
157    round_up: bool,
158) -> Result<U256, UniswapV3MathError> {
159    if sqrt_ratio_a_x_96 > sqrt_ratio_b_x_96 {
160        (sqrt_ratio_a_x_96, sqrt_ratio_b_x_96) = (sqrt_ratio_b_x_96, sqrt_ratio_a_x_96)
161    };
162
163    let denominator = U256::from_limbs([0, 4294967296, 0, 0]);
164
165    if round_up {
166        mul_div_rounding_up(U256::from(liquidity), sqrt_ratio_b_x_96 - sqrt_ratio_a_x_96, denominator)
167    } else {
168        mul_div(U256::from(liquidity), sqrt_ratio_b_x_96 - sqrt_ratio_a_x_96, denominator)
169    }
170}
171
172pub fn get_amount_0_delta(sqrt_ratio_a_x_96: U256, sqrt_ratio_b_x_96: U256, liquidity: i128) -> Result<I256, UniswapV3MathError> {
173    if liquidity < 0 {
174        Ok(-I256::from_raw(_get_amount_0_delta(sqrt_ratio_a_x_96, sqrt_ratio_b_x_96, -liquidity as u128, false)?))
175    } else {
176        Ok(I256::from_raw(_get_amount_0_delta(sqrt_ratio_a_x_96, sqrt_ratio_b_x_96, liquidity as u128, true)?))
177    }
178}
179
180pub fn get_amount_1_delta(sqrt_ratio_a_x_96: U256, sqrt_ratio_b_x_96: U256, liquidity: i128) -> Result<I256, UniswapV3MathError> {
181    if liquidity < 0 {
182        Ok(-I256::from_raw(_get_amount_1_delta(sqrt_ratio_a_x_96, sqrt_ratio_b_x_96, -liquidity as u128, false)?))
183    } else {
184        Ok(I256::from_raw(_get_amount_1_delta(sqrt_ratio_a_x_96, sqrt_ratio_b_x_96, liquidity as u128, true)?))
185    }
186}
187
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use alloy::primitives::U256;

    use crate::{
        sqrt_price_math::{_get_amount_1_delta, get_next_sqrt_price_from_output, MAX_U160},
        U256_1, U256_2,
    };

    use super::{_get_amount_0_delta, get_next_sqrt_price_from_input};

    #[test]
    fn test_get_next_sqrt_price_from_input() {
        // fails if price is zero
        let result = get_next_sqrt_price_from_input(U256::ZERO, 0, U256::from(100000000000000000_u128), false);
        assert_eq!(result.unwrap_err().to_string(), "Sqrt price is 0");

        // fails if liquidity is zero
        let result = get_next_sqrt_price_from_input(U256_1, 0, U256::from(100000000000000000_u128), true);
        assert_eq!(result.unwrap_err().to_string(), "Liquidity is 0");

        // fails if input amount overflows the price
        let result = get_next_sqrt_price_from_input(MAX_U160, 1024, U256::from(1024), false);
        assert_eq!(result.unwrap_err().to_string(), "Overflow when casting to U160");

        // any input amount cannot underflow the price
        let result = get_next_sqrt_price_from_input(
            U256_1,
            1,
            U256::from_str("57896044618658097711785492504343953926634992332820282019728792003956564819968").unwrap(),
            true,
        );

        assert_eq!(result.unwrap(), U256_1);

        // returns input price if amount in is zero and zeroForOne = true
        let result =
            get_next_sqrt_price_from_input(U256::from_str("79228162514264337593543950336").unwrap(), 1e17 as u128, U256::ZERO, true);

        assert_eq!(result.unwrap(), U256::from_str("79228162514264337593543950336").unwrap());

        // returns input price if amount in is zero and zeroForOne = false
        let result =
            get_next_sqrt_price_from_input(U256::from_str("79228162514264337593543950336").unwrap(), 1e17 as u128, U256::ZERO, false);

        assert_eq!(result.unwrap(), U256::from_str("79228162514264337593543950336").unwrap());

        // returns the minimum price for max inputs
        let sqrt_price = MAX_U160;
        let liquidity = u128::MAX;
        let max_amount_no_overflow = U256::MAX - ((U256::from(liquidity) << 96) / sqrt_price);
        let result = get_next_sqrt_price_from_input(sqrt_price, liquidity, max_amount_no_overflow, true);
        assert_eq!(result.unwrap(), U256_1);

        // input amount of 0.1 token1
        let result = get_next_sqrt_price_from_input(
            U256::from_str("79228162514264337593543950336").unwrap(),
            1e18 as u128,
            U256::from_str("100000000000000000").unwrap(),
            false,
        );

        assert_eq!(result.unwrap(), U256::from_str("87150978765690771352898345369").unwrap());

        // input amount of 0.1 token0
        let result = get_next_sqrt_price_from_input(
            U256::from_str("79228162514264337593543950336").unwrap(),
            1e18 as u128,
            U256::from_str("100000000000000000").unwrap(),
            true,
        );

        assert_eq!(result.unwrap(), U256::from_str("72025602285694852357767227579").unwrap());

        // amountIn > type(uint96).max and zeroForOne = true
        let result = get_next_sqrt_price_from_input(
            U256::from_str("79228162514264337593543950336").unwrap(),
            1e19 as u128,
            U256::from_str("1267650600228229401496703205376").unwrap(),
            true,
        );
        // perfect answer:
        // https://www.wolframalpha.com/input/?i=624999999995069620+-+%28%281e19+*+1+%2F+%281e19+%2B+2%5E100+*+1%29%29+*+2%5E96%29
        assert_eq!(result.unwrap(), U256::from_str("624999999995069620").unwrap());

        // can return 1 with enough amountIn and zeroForOne = true
        let result = get_next_sqrt_price_from_input(U256::from_str("79228162514264337593543950336").unwrap(), 1, U256::MAX / U256_2, true);

        assert_eq!(result.unwrap(), U256_1);
    }

    #[test]
    fn test_get_next_sqrt_price_from_output() {
        // fails if price is zero
        let result = get_next_sqrt_price_from_output(U256::ZERO, 0, U256::from(1000000000), false);
        assert_eq!(result.unwrap_err().to_string(), "Sqrt price is 0");

        // fails if liquidity is zero
        let result = get_next_sqrt_price_from_output(U256_1, 0, U256::from(1000000000), false);
        assert_eq!(result.unwrap_err().to_string(), "Liquidity is 0");

        // fails if output amount is exactly the virtual reserves of token0
        let result =
            get_next_sqrt_price_from_output(U256::from_str("20282409603651670423947251286016").unwrap(), 1024, U256::from(4), false);
        assert_eq!(result.unwrap_err().to_string(), "require((product = amount * sqrtPX96) / amount == sqrtPX96 && numerator1 > product);");

        // fails if output amount is greater than virtual reserves of token0
        let result =
            get_next_sqrt_price_from_output(U256::from_str("20282409603651670423947251286016").unwrap(), 1024, U256::from(5), false);
        assert_eq!(result.unwrap_err().to_string(), "require((product = amount * sqrtPX96) / amount == sqrtPX96 && numerator1 > product);");

        // fails if output amount is greater than virtual reserves of token1
        let result =
            get_next_sqrt_price_from_output(U256::from_str("20282409603651670423947251286016").unwrap(), 1024, U256::from(262145), true);
        assert_eq!(result.unwrap_err().to_string(), "Sqrt price is less than or equal to quotient");

        // fails if output amount is exactly the virtual reserves of token1
        let result =
            get_next_sqrt_price_from_output(U256::from_str("20282409603651670423947251286016").unwrap(), 1024, U256::from(262144), true);
        assert_eq!(result.unwrap_err().to_string(), "Sqrt price is less than or equal to quotient");

        // succeeds if output amount is just less than the virtual reserves of token1
        let result =
            get_next_sqrt_price_from_output(U256::from_str("20282409603651670423947251286016").unwrap(), 1024, U256::from(262143), true);
        assert_eq!(result.unwrap(), U256::from_str("77371252455336267181195264").unwrap());

        // puzzling echidna test
        let result =
            get_next_sqrt_price_from_output(U256::from_str("20282409603651670423947251286016").unwrap(), 1024, U256::from(4), false);
        assert_eq!(result.unwrap_err().to_string(), "require((product = amount * sqrtPX96) / amount == sqrtPX96 && numerator1 > product);");

        // returns input price if amount out is zero and zeroForOne = true
        let result =
            get_next_sqrt_price_from_output(U256::from_str("79228162514264337593543950336").unwrap(), 1e17 as u128, U256::ZERO, true);
        assert_eq!(result.unwrap(), U256::from_str("79228162514264337593543950336").unwrap());

        // returns input price if amount out is zero and zeroForOne = false
        let result =
            get_next_sqrt_price_from_output(U256::from_str("79228162514264337593543950336").unwrap(), 1e17 as u128, U256::ZERO, false);
        assert_eq!(result.unwrap(), U256::from_str("79228162514264337593543950336").unwrap());

        // output amount of 0.1 token0 (zeroForOne = false)
        let result = get_next_sqrt_price_from_output(
            U256::from_str("79228162514264337593543950336").unwrap(),
            1e18 as u128,
            U256::from(1e17 as u128),
            false,
        );
        assert_eq!(result.unwrap(), U256::from_str("88031291682515930659493278152").unwrap());

        // output amount of 0.1 token1 (zeroForOne = true)
        let result = get_next_sqrt_price_from_output(
            U256::from_str("79228162514264337593543950336").unwrap(),
            1e18 as u128,
            U256::from(1e17 as u128),
            true,
        );
        assert_eq!(result.unwrap(), U256::from_str("71305346262837903834189555302").unwrap());

        // reverts if amountOut is impossible in zero for one direction
        let result = get_next_sqrt_price_from_output(U256::from_str("79228162514264337593543950336").unwrap(), 1, U256::MAX, true);
        assert_eq!(result.unwrap_err().to_string(), "Denominator is less than or equal to prod_1");

        // reverts if amountOut is impossible in one for zero direction
        let result = get_next_sqrt_price_from_output(U256::from_str("79228162514264337593543950336").unwrap(), 1, U256::MAX, false);
        assert_eq!(result.unwrap_err().to_string(), "require((product = amount * sqrtPX96) / amount == sqrtPX96 && numerator1 > product);");
    }

    #[test]
    fn test_get_amount_0_delta() {
        // returns 0 if liquidity is 0
        let amount_0 = _get_amount_0_delta(
            U256::from_str("79228162514264337593543950336").unwrap(),
            U256::from_str("87150978765690771352898345369").unwrap(),
            0,
            true,
        );

        assert_eq!(amount_0.unwrap(), U256::ZERO);

        // returns 0 if prices are equal
        let amount_0 = _get_amount_0_delta(
            U256::from_str("79228162514264337593543950336").unwrap(),
            U256::from_str("79228162514264337593543950336").unwrap(),
            1e18 as u128,
            true,
        );

        assert_eq!(amount_0.unwrap(), U256::ZERO);

        // returns ~0.0909 amount0 for price of 1 to 1.21
        let amount_0 = _get_amount_0_delta(
            U256::from_str("79228162514264337593543950336").unwrap(),
            U256::from_str("87150978765690771352898345369").unwrap(),
            1e18 as u128,
            true,
        )
        .unwrap();

        assert_eq!(amount_0, U256::from_str("90909090909090910").unwrap());

        let amount_0_rounded_down = _get_amount_0_delta(
            U256::from_str("79228162514264337593543950336").unwrap(),
            U256::from_str("87150978765690771352898345369").unwrap(),
            1e18 as u128,
            false,
        );

        assert_eq!(amount_0_rounded_down.unwrap(), amount_0 - U256_1);

        // works for prices that overflow
        let amount_0_up = _get_amount_0_delta(
            U256::from_str("2787593149816327892691964784081045188247552").unwrap(),
            U256::from_str("22300745198530623141535718272648361505980416").unwrap(),
            1e18 as u128,
            true,
        )
        .unwrap();

        let amount_0_down = _get_amount_0_delta(
            U256::from_str("2787593149816327892691964784081045188247552").unwrap(),
            U256::from_str("22300745198530623141535718272648361505980416").unwrap(),
            1e18 as u128,
            false,
        )
        .unwrap();

        assert_eq!(amount_0_up, amount_0_down + U256_1);
    }

    #[test]
    fn test_get_amount_1_delta() {
        // returns 0 if liquidity is 0
        let amount_1 = _get_amount_1_delta(
            U256::from_str("79228162514264337593543950336").unwrap(),
            U256::from_str("87150978765690771352898345369").unwrap(),
            0,
            true,
        );

        assert_eq!(amount_1.unwrap(), U256::ZERO);

        // returns 0 if prices are equal
        let amount_1 = _get_amount_1_delta(
            U256::from_str("79228162514264337593543950336").unwrap(),
            U256::from_str("79228162514264337593543950336").unwrap(),
            1e18 as u128,
            true,
        );

        assert_eq!(amount_1.unwrap(), U256::ZERO);

        // returns 0.1 amount1 for price of 1 to 1.21
        let amount_1 = _get_amount_1_delta(
            U256::from_str("79228162514264337593543950336").unwrap(),
            U256::from_str("87150978765690771352898345369").unwrap(),
            1e18 as u128,
            true,
        )
        .unwrap();

        assert_eq!(amount_1, U256::from_str("100000000000000000").unwrap());

        let amount_1_rounded_down = _get_amount_1_delta(
            U256::from_str("79228162514264337593543950336").unwrap(),
            U256::from_str("87150978765690771352898345369").unwrap(),
            1e18 as u128,
            false,
        );

        assert_eq!(amount_1_rounded_down.unwrap(), amount_1 - U256_1);
    }

    #[test]
    fn test_swap_computation() {
        let sqrt_price = U256::from_str("1025574284609383690408304870162715216695788925244").unwrap();
        let liquidity = 50015962439936049619261659728067971248;
        let zero_for_one = true;
        let amount_in = U256::from(406);

        let sqrt_q = get_next_sqrt_price_from_input(sqrt_price, liquidity, amount_in, zero_for_one).unwrap();

        assert_eq!(sqrt_q, U256::from_str("1025574284609383582644711336373707553698163132913").unwrap());

        let amount_0_delta = _get_amount_0_delta(sqrt_q, sqrt_price, liquidity, true).unwrap();

        assert_eq!(amount_0_delta, U256::from(406));
    }
}