1use alloy::primitives::{I256, U256};
2
3use crate::{
4 error::UniswapV3MathError,
5 full_math::{mul_div, mul_div_rounding_up},
6 unsafe_math::div_rounding_up,
7};
8
9pub const MAX_U160: U256 = U256::from_limbs([18446744073709551615, 18446744073709551615, 4294967295, 0]);
10pub const Q96: U256 = U256::from_limbs([0, 4294967296, 0, 0]);
11pub const FIXED_POINT_96_RESOLUTION: U256 = U256::from_limbs([96, 0, 0, 0]);
12
13pub fn get_next_sqrt_price_from_input(
15 sqrt_price: U256,
16 liquidity: u128,
17 amount_in: U256,
18 zero_for_one: bool,
19) -> Result<U256, UniswapV3MathError> {
20 if sqrt_price.is_zero() {
21 return Err(UniswapV3MathError::SqrtPriceIsZero);
22 } else if liquidity == 0 {
23 return Err(UniswapV3MathError::LiquidityIsZero);
24 }
25
26 if zero_for_one {
27 get_next_sqrt_price_from_amount_0_rounding_up(sqrt_price, liquidity, amount_in, true)
28 } else {
29 get_next_sqrt_price_from_amount_1_rounding_down(sqrt_price, liquidity, amount_in, true)
30 }
31}
32
33pub fn get_next_sqrt_price_from_output(
35 sqrt_price: U256,
36 liquidity: u128,
37 amount_out: U256,
38 zero_for_one: bool,
39) -> Result<U256, UniswapV3MathError> {
40 if sqrt_price.is_zero() {
41 return Err(UniswapV3MathError::SqrtPriceIsZero);
42 } else if liquidity == 0 {
43 return Err(UniswapV3MathError::LiquidityIsZero);
44 }
45
46 if zero_for_one {
47 get_next_sqrt_price_from_amount_1_rounding_down(sqrt_price, liquidity, amount_out, false)
48 } else {
49 get_next_sqrt_price_from_amount_0_rounding_up(sqrt_price, liquidity, amount_out, false)
50 }
51}
52
53pub fn get_next_sqrt_price_from_amount_0_rounding_up(
55 sqrt_price_x_96: U256,
56 liquidity: u128,
57 amount: U256,
58 add: bool,
59) -> Result<U256, UniswapV3MathError> {
60 if amount.is_zero() {
61 return Ok(sqrt_price_x_96);
62 }
63
64 let numerator_1: U256 = U256::from(liquidity) << 96;
65
66 if add {
67 let product = amount.wrapping_mul(sqrt_price_x_96);
68
69 if product.wrapping_div(amount) == sqrt_price_x_96 {
70 let denominator = numerator_1.wrapping_add(product);
71
72 if denominator >= numerator_1 {
73 return mul_div_rounding_up(numerator_1, sqrt_price_x_96, denominator);
74 }
75 }
76
77 Ok(div_rounding_up(numerator_1, (numerator_1.wrapping_div(sqrt_price_x_96)).wrapping_add(amount)))
78 } else {
79 let product = amount.wrapping_mul(sqrt_price_x_96);
80 if product.wrapping_div(amount) == sqrt_price_x_96 && numerator_1 > product {
81 let denominator = numerator_1.wrapping_sub(product);
82
83 mul_div_rounding_up(numerator_1, sqrt_price_x_96, denominator)
84 } else {
85 Err(UniswapV3MathError::ProductDivAmount)
86 }
87 }
88}
89
90pub fn get_next_sqrt_price_from_amount_1_rounding_down(
92 sqrt_price_x_96: U256,
93 liquidity: u128,
94 amount: U256,
95 add: bool,
96) -> Result<U256, UniswapV3MathError> {
97 let liquidity = U256::from(liquidity);
98
99 if add {
100 let quotient =
101 if amount <= MAX_U160 { (amount << FIXED_POINT_96_RESOLUTION) / liquidity } else { mul_div(amount, Q96, liquidity)? };
102
103 let next_sqrt_price = sqrt_price_x_96 + quotient;
104
105 if next_sqrt_price > MAX_U160 {
106 Err(UniswapV3MathError::SafeCastToU160Overflow)
107 } else {
108 Ok(next_sqrt_price)
109 }
110 } else {
111 let quotient = if amount <= MAX_U160 {
112 div_rounding_up(amount << FIXED_POINT_96_RESOLUTION, liquidity)
113 } else {
114 mul_div_rounding_up(amount, Q96, liquidity)?
115 };
116
117 if sqrt_price_x_96 <= quotient {
119 return Err(UniswapV3MathError::SqrtPriceIsLteQuotient);
120 }
121
122 Ok(sqrt_price_x_96.overflowing_sub(quotient).0)
123 }
124}
125
126pub fn _get_amount_0_delta(
128 mut sqrt_ratio_a_x_96: U256,
129 mut sqrt_ratio_b_x_96: U256,
130 liquidity: u128,
131 round_up: bool,
132) -> Result<U256, UniswapV3MathError> {
133 if sqrt_ratio_a_x_96 > sqrt_ratio_b_x_96 {
134 (sqrt_ratio_a_x_96, sqrt_ratio_b_x_96) = (sqrt_ratio_b_x_96, sqrt_ratio_a_x_96)
135 };
136
137 let numerator_1 = U256::from(liquidity) << 96;
138 let numerator_2 = sqrt_ratio_b_x_96 - sqrt_ratio_a_x_96;
139
140 if sqrt_ratio_a_x_96.is_zero() {
141 return Err(UniswapV3MathError::SqrtPriceIsZero);
142 }
143
144 if round_up {
145 let numerator_partial = mul_div_rounding_up(numerator_1, numerator_2, sqrt_ratio_b_x_96)?;
146 Ok(div_rounding_up(numerator_partial, sqrt_ratio_a_x_96))
147 } else {
148 Ok(mul_div(numerator_1, numerator_2, sqrt_ratio_b_x_96)? / sqrt_ratio_a_x_96)
149 }
150}
151
152pub fn _get_amount_1_delta(
154 mut sqrt_ratio_a_x_96: U256,
155 mut sqrt_ratio_b_x_96: U256,
156 liquidity: u128,
157 round_up: bool,
158) -> Result<U256, UniswapV3MathError> {
159 if sqrt_ratio_a_x_96 > sqrt_ratio_b_x_96 {
160 (sqrt_ratio_a_x_96, sqrt_ratio_b_x_96) = (sqrt_ratio_b_x_96, sqrt_ratio_a_x_96)
161 };
162
163 let denominator = U256::from_limbs([0, 4294967296, 0, 0]);
164
165 if round_up {
166 mul_div_rounding_up(U256::from(liquidity), sqrt_ratio_b_x_96 - sqrt_ratio_a_x_96, denominator)
167 } else {
168 mul_div(U256::from(liquidity), sqrt_ratio_b_x_96 - sqrt_ratio_a_x_96, denominator)
169 }
170}
171
172pub fn get_amount_0_delta(sqrt_ratio_a_x_96: U256, sqrt_ratio_b_x_96: U256, liquidity: i128) -> Result<I256, UniswapV3MathError> {
173 if liquidity < 0 {
174 Ok(-I256::from_raw(_get_amount_0_delta(sqrt_ratio_a_x_96, sqrt_ratio_b_x_96, -liquidity as u128, false)?))
175 } else {
176 Ok(I256::from_raw(_get_amount_0_delta(sqrt_ratio_a_x_96, sqrt_ratio_b_x_96, liquidity as u128, true)?))
177 }
178}
179
180pub fn get_amount_1_delta(sqrt_ratio_a_x_96: U256, sqrt_ratio_b_x_96: U256, liquidity: i128) -> Result<I256, UniswapV3MathError> {
181 if liquidity < 0 {
182 Ok(-I256::from_raw(_get_amount_1_delta(sqrt_ratio_a_x_96, sqrt_ratio_b_x_96, -liquidity as u128, false)?))
183 } else {
184 Ok(I256::from_raw(_get_amount_1_delta(sqrt_ratio_a_x_96, sqrt_ratio_b_x_96, liquidity as u128, true)?))
185 }
186}
187
#[cfg(test)]
mod test {
    use std::fmt::{Debug, Display};

    use alloy::primitives::U256;

    use super::{
        _get_amount_0_delta, _get_amount_1_delta, get_next_sqrt_price_from_input, get_next_sqrt_price_from_output,
        MAX_U160,
    };
    use crate::{U256_1, U256_2};

    /// Liquidity fixtures, written as exact integer powers instead of float casts.
    const LIQUIDITY_1E17: u128 = 10u128.pow(17);
    const LIQUIDITY_1E18: u128 = 10u128.pow(18);
    const LIQUIDITY_1E19: u128 = 10u128.pow(19);

    /// Message expected when removing token0 would exhaust (or overflow) the virtual reserves.
    const PRODUCT_DIV_AMOUNT_MSG: &str =
        "require((product = amount * sqrtPX96) / amount == sqrtPX96 && numerator1 > product);";
    /// Message expected when removing token1 would drive the price to zero or below.
    const PRICE_LTE_QUOTIENT_MSG: &str = "Sqrt price is less than or equal to quotient";

    /// Parses a decimal literal into a `U256`, panicking if the fixture is malformed.
    fn u256(decimal: &str) -> U256 {
        decimal.parse().expect("test fixture must be a valid decimal U256")
    }

    /// 2^96, used throughout these tests as the starting sqrt price.
    fn price_2_pow_96() -> U256 {
        u256("79228162514264337593543950336")
    }

    /// 2^104, the starting sqrt price for the reserve-exhaustion cases.
    fn price_2_pow_104() -> U256 {
        u256("20282409603651670423947251286016")
    }

    /// Upper sqrt price shared by the amount-delta tests.
    fn upper_price() -> U256 {
        u256("87150978765690771352898345369")
    }

    /// Asserts that `result` is an error whose display text equals `expected`.
    #[track_caller]
    fn assert_fails_with<T: Debug, E: Display>(result: Result<T, E>, expected: &str) {
        assert_eq!(result.unwrap_err().to_string(), expected);
    }

    #[test]
    fn test_get_next_sqrt_price_from_input() {
        // Fails if the price is zero.
        assert_fails_with(
            get_next_sqrt_price_from_input(U256::ZERO, 0, U256::from(100000000000000000_u128), false),
            "Sqrt price is 0",
        );

        // Fails if the liquidity is zero.
        assert_fails_with(
            get_next_sqrt_price_from_input(U256_1, 0, U256::from(100000000000000000_u128), true),
            "Liquidity is 0",
        );

        // Fails if the input amount pushes the price past 160 bits.
        assert_fails_with(
            get_next_sqrt_price_from_input(MAX_U160, 1024, U256::from(1024), false),
            "Overflow when casting to U160",
        );

        // No input amount can push the price below one.
        let huge_amount = u256("57896044618658097711785492504343953926634992332820282019728792003956564819968");
        assert_eq!(get_next_sqrt_price_from_input(U256_1, 1, huge_amount, true).unwrap(), U256_1);

        // A zero input amount leaves the price unchanged, in both swap directions.
        for zero_for_one in [true, false] {
            let unchanged = get_next_sqrt_price_from_input(price_2_pow_96(), LIQUIDITY_1E17, U256::ZERO, zero_for_one);
            assert_eq!(unchanged.unwrap(), price_2_pow_96());
        }

        // Largest amount that does not overflow the fallback denominator returns the minimum price.
        let sqrt_price = MAX_U160;
        let liquidity = u128::MAX;
        let max_amount_no_overflow = U256::MAX - ((U256::from(liquidity) << 96) / sqrt_price);
        assert_eq!(get_next_sqrt_price_from_input(sqrt_price, liquidity, max_amount_no_overflow, true).unwrap(), U256_1);

        // Adding token1 raises the price.
        let raised = get_next_sqrt_price_from_input(price_2_pow_96(), LIQUIDITY_1E18, u256("100000000000000000"), false);
        assert_eq!(raised.unwrap(), u256("87150978765690771352898345369"));

        // Adding token0 lowers the price.
        let lowered = get_next_sqrt_price_from_input(price_2_pow_96(), LIQUIDITY_1E18, u256("100000000000000000"), true);
        assert_eq!(lowered.unwrap(), u256("72025602285694852357767227579"));

        // Input amount larger than 96 bits.
        let large_input =
            get_next_sqrt_price_from_input(price_2_pow_96(), LIQUIDITY_1E19, u256("1267650600228229401496703205376"), true);
        assert_eq!(large_input.unwrap(), u256("624999999995069620"));

        // Enough token0 input drives the price down to one.
        let floor = get_next_sqrt_price_from_input(price_2_pow_96(), 1, U256::MAX / U256_2, true);
        assert_eq!(floor.unwrap(), U256_1);
    }

    #[test]
    fn test_get_next_sqrt_price_from_output() {
        // Fails if the price is zero.
        assert_fails_with(get_next_sqrt_price_from_output(U256::ZERO, 0, U256::from(1000000000), false), "Sqrt price is 0");

        // Fails if the liquidity is zero.
        assert_fails_with(get_next_sqrt_price_from_output(U256_1, 0, U256::from(1000000000), false), "Liquidity is 0");

        // Output equal to the virtual reserves of token0.
        assert_fails_with(
            get_next_sqrt_price_from_output(price_2_pow_104(), 1024, U256::from(4), false),
            PRODUCT_DIV_AMOUNT_MSG,
        );

        // Output greater than the virtual reserves of token0.
        assert_fails_with(
            get_next_sqrt_price_from_output(price_2_pow_104(), 1024, U256::from(5), false),
            PRODUCT_DIV_AMOUNT_MSG,
        );

        // Output greater than the virtual reserves of token1.
        assert_fails_with(
            get_next_sqrt_price_from_output(price_2_pow_104(), 1024, U256::from(262145), true),
            PRICE_LTE_QUOTIENT_MSG,
        );

        // Output equal to the virtual reserves of token1.
        assert_fails_with(
            get_next_sqrt_price_from_output(price_2_pow_104(), 1024, U256::from(262144), true),
            PRICE_LTE_QUOTIENT_MSG,
        );

        // Output just below the virtual reserves of token1 succeeds.
        let just_below = get_next_sqrt_price_from_output(price_2_pow_104(), 1024, U256::from(262143), true);
        assert_eq!(just_below.unwrap(), u256("77371252455336267181195264"));

        // Regression case: same inputs as the "equal to token0 reserves" case above.
        assert_fails_with(
            get_next_sqrt_price_from_output(price_2_pow_104(), 1024, U256::from(4), false),
            PRODUCT_DIV_AMOUNT_MSG,
        );

        // A zero output amount leaves the price unchanged, in both swap directions.
        for zero_for_one in [true, false] {
            let unchanged = get_next_sqrt_price_from_output(price_2_pow_96(), LIQUIDITY_1E17, U256::ZERO, zero_for_one);
            assert_eq!(unchanged.unwrap(), price_2_pow_96());
        }

        // Removing token0 raises the price.
        let raised = get_next_sqrt_price_from_output(price_2_pow_96(), LIQUIDITY_1E18, U256::from(LIQUIDITY_1E17), false);
        assert_eq!(raised.unwrap(), u256("88031291682515930659493278152"));

        // Removing token1 lowers the price.
        let lowered = get_next_sqrt_price_from_output(price_2_pow_96(), LIQUIDITY_1E18, U256::from(LIQUIDITY_1E17), true);
        assert_eq!(lowered.unwrap(), u256("71305346262837903834189555302"));

        // Impossible output amounts fail in both directions.
        assert_fails_with(
            get_next_sqrt_price_from_output(price_2_pow_96(), 1, U256::MAX, true),
            "Denominator is less than or equal to prod_1",
        );
        assert_fails_with(get_next_sqrt_price_from_output(price_2_pow_96(), 1, U256::MAX, false), PRODUCT_DIV_AMOUNT_MSG);
    }

    #[test]
    fn test_get_amount_0_delta() {
        // Zero liquidity yields zero, whether or not the prices differ.
        assert_eq!(_get_amount_0_delta(price_2_pow_96(), price_2_pow_96(), 0, true).unwrap(), U256::ZERO);
        assert_eq!(_get_amount_0_delta(price_2_pow_96(), upper_price(), 0, true).unwrap(), U256::ZERO);

        // Rounded-up delta between the two prices, and rounded-down being exactly one less.
        let amount_0 = _get_amount_0_delta(price_2_pow_96(), upper_price(), LIQUIDITY_1E18, true).unwrap();
        assert_eq!(amount_0, u256("90909090909090910"));

        let amount_0_rounded_down = _get_amount_0_delta(price_2_pow_96(), upper_price(), LIQUIDITY_1E18, false);
        assert_eq!(amount_0_rounded_down.unwrap(), amount_0 - U256_1);

        // Prices large enough to overflow a naive intermediate product.
        let low = u256("2787593149816327892691964784081045188247552");
        let high = u256("22300745198530623141535718272648361505980416");

        let amount_0_up = _get_amount_0_delta(low, high, LIQUIDITY_1E18, true).unwrap();
        let amount_0_down = _get_amount_0_delta(low, high, LIQUIDITY_1E18, false).unwrap();

        assert_eq!(amount_0_up, amount_0_down + U256_1);
    }

    #[test]
    fn test_get_amount_1_delta() {
        // Zero liquidity yields zero, whether or not the prices differ.
        assert_eq!(_get_amount_1_delta(price_2_pow_96(), price_2_pow_96(), 0, true).unwrap(), U256::ZERO);
        assert_eq!(_get_amount_1_delta(price_2_pow_96(), upper_price(), 0, true).unwrap(), U256::ZERO);

        // Rounded-up delta between the two prices, and rounded-down being exactly one less.
        let amount_1 = _get_amount_1_delta(price_2_pow_96(), upper_price(), LIQUIDITY_1E18, true).unwrap();
        assert_eq!(amount_1, u256("100000000000000000"));

        let amount_1_rounded_down = _get_amount_1_delta(price_2_pow_96(), upper_price(), LIQUIDITY_1E18, false);
        assert_eq!(amount_1_rounded_down.unwrap(), amount_1 - U256_1);
    }

    #[test]
    fn test_swap_computation() {
        let sqrt_price = u256("1025574284609383690408304870162715216695788925244");
        let liquidity: u128 = 50015962439936049619261659728067971248;
        let zero_for_one = true;
        let amount_in = U256::from(406);

        let sqrt_q = get_next_sqrt_price_from_input(sqrt_price, liquidity, amount_in, zero_for_one).unwrap();
        assert_eq!(sqrt_q, u256("1025574284609383582644711336373707553698163132913"));

        // Round trip: the token0 needed to move between the two prices equals the input amount.
        let amount_0_delta = _get_amount_0_delta(sqrt_q, sqrt_price, liquidity, true).unwrap();
        assert_eq!(amount_0_delta, U256::from(406));
    }
}