kabu_defi_uniswap_v3_math/
full_math.rs

1use std::ops::{Add, BitOrAssign, Div, Mul, MulAssign};
2
3use crate::{error::UniswapV3MathError, U256_1};
4// use alloy_primitives::utils::ParseUnits::U256;
5
6use alloy::primitives::{Uint, U256};
7
8pub const ONE: Uint<256, 4> = Uint::<256, 4>::from_limbs([1, 0, 0, 0]);
9pub const TWO: Uint<256, 4> = Uint::<256, 4>::from_limbs([2, 0, 0, 0]);
10pub const THREE: Uint<256, 4> = Uint::<256, 4>::from_limbs([3, 0, 0, 0]);
11
12// returns (uint256 result)
13pub fn mul_div(a: U256, b: U256, mut denominator: U256) -> Result<U256, UniswapV3MathError> {
14    // 512-bit multiply [prod1 prod0] = a * b
15    // Compute the product mod 2**256 and mod 2**256 - 1
16    // then use the Chinese Remainder Theorem to reconstruct
17    // the 512 bit result. The result is stored in two 256
18    // variables such that product = prod1 * 2**256 + prod0
19    let mm = a.mul_mod(b, U256::MAX);
20
21    let mut prod_0 = a.overflowing_mul(b).0; // Least significant 256 bits of the product
22    let mut prod_1 = mm.overflowing_sub(prod_0).0.overflowing_sub(U256::from((mm < prod_0) as u8)).0;
23
24    // Handle non-overflow cases, 256 by 256 division
25    if prod_1 == U256::ZERO {
26        if denominator == U256::ZERO {
27            return Err(UniswapV3MathError::DenominatorIsZero);
28        }
29        return Ok(U256::from_limbs(*prod_0.div(denominator).as_limbs()));
30    }
31
32    // Make sure the result is less than 2**256.
33    // Also prevents denominator == 0
34    if denominator <= prod_1 {
35        return Err(UniswapV3MathError::DenominatorIsLteProdOne);
36    }
37
38    ///////////////////////////////////////////////
39    // 512 by 256 division.
40    ///////////////////////////////////////////////
41    //
42
43    // Make division exact by subtracting the remainder from [prod1 prod0]
44    // Compute remainder using mulmod
45    let remainder = a.mul_mod(b, denominator);
46
47    // Subtract 256 bit number from 512 bit number
48    prod_1 = prod_1.overflowing_sub(U256::from((remainder > prod_0) as u8)).0;
49    prod_0 = prod_0.overflowing_sub(remainder).0;
50
51    // Factor powers of two out of denominator
52    // Compute largest power of two divisor of denominator.
53    // Always >= 1.
54    let mut twos = U256::ZERO.overflowing_sub(denominator).0.bitand(denominator);
55
56    // Divide denominator by power of two
57
58    denominator = denominator.wrapping_div(twos);
59
60    // Divide [prod1 prod0] by the factors of two
61    prod_0 = prod_0.wrapping_div(twos);
62
63    // Shift in bits from prod1 into prod0. For this we need
64    // to flip `twos` such that it is 2**256 / twos.
65    // If twos is zero, then it becomes one
66
67    twos = (U256::ZERO.overflowing_sub(twos).0.wrapping_div(twos)).add(U256_1);
68
69    prod_0.bitor_assign(prod_1 * twos);
70
71    // Invert denominator mod 2**256
72    // Now that denominator is an odd number, it has an inverse
73    // modulo 2**256 such that denominator * inv = 1 mod 2**256.
74    // Compute the inverse by starting with a seed that is correct
75    // correct for four bits. That is, denominator * inv = 1 mod 2**4
76
77    let mut inv = THREE.mul(denominator).bitxor(TWO);
78
79    // Now use Newton-Raphson iteration to improve the precision.
80    // Thanks to Hensel's lifting lemma, this also works in modular
81    // arithmetic, doubling the correct bits in each step.
82
83    inv.mul_assign(TWO - denominator * inv); // inverse mod 2**8
84    inv.mul_assign(TWO - denominator * inv); // inverse mod 2**16
85    inv.mul_assign(TWO - denominator * inv); // inverse mod 2**32
86    inv.mul_assign(TWO - denominator * inv); // inverse mod 2**64
87    inv.mul_assign(TWO - denominator * inv); // inverse mod 2**128
88    inv.mul_assign(TWO - denominator * inv); // inverse mod 2**256
89
90    // Because the division is now exact we can divide by multiplying
91    // with the modular inverse of denominator. This will give us the
92    // correct result modulo 2**256. Since the precoditions guarantee
93    // that the outcome is less than 2**256, this is the final result.
94    // We don't need to compute the high bits of the result and prod1
95    // is no longer required.
96
97    Ok(U256::from_le_slice((prod_0 * inv).as_le_slice()))
98}
99
100pub fn mul_div_rounding_up(a: U256, b: U256, denominator: U256) -> Result<U256, UniswapV3MathError> {
101    let result = mul_div(a, b, denominator)?;
102
103    if a.mul_mod(b, denominator) > U256::ZERO {
104        if result == U256::MAX {
105            Err(UniswapV3MathError::ResultIsU256MAX)
106        } else {
107            Ok(result + U256_1)
108        }
109    } else {
110        Ok(result)
111    }
112}
113
#[cfg(test)]
mod tests {
    // Edge cases and accuracy checks for `mul_div` and `mul_div_rounding_up`.
    use super::*;

    /// 2^128.
    const Q128: U256 = U256::from_limbs([0, 0, 1, 0]);

    const LTE_PROD_ONE_MSG: &str = "Denominator is less than or equal to prod_1";

    #[test]
    fn test_mul_div() {
        // Reverts if the denominator is zero
        let result = mul_div(Q128, U256::from(5), U256::ZERO);
        assert_eq!(result.unwrap_err().to_string(), "Denominator is 0");

        // Reverts if the denominator is zero and numerator overflows
        let result = mul_div(Q128, Q128, U256::ZERO);
        assert_eq!(result.unwrap_err().to_string(), LTE_PROD_ONE_MSG);

        // Reverts if the output overflows uint256
        let result = mul_div(Q128, Q128, U256_1);
        assert_eq!(result.unwrap_err().to_string(), LTE_PROD_ONE_MSG);

        // Reverts on overflow with all max inputs
        let result = mul_div(U256::MAX, U256::MAX, U256::MAX - U256_1);
        assert_eq!(result.unwrap_err().to_string(), LTE_PROD_ONE_MSG);

        // All max inputs
        assert_eq!(mul_div(U256::MAX, U256::MAX, U256::MAX).unwrap(), U256::MAX);

        // Product fits in 256 bits: floor(42 / 4) = 10
        assert_eq!(mul_div(U256::from(6), U256::from(7), U256::from(4)).unwrap(), U256::from(10));

        // Accurate without phantom overflow
        let result = mul_div(
            Q128,
            U256::from(50) * Q128 / U256::from(100),
            U256::from(150) * Q128 / U256::from(100),
        );
        assert_eq!(result.unwrap(), Q128 / U256::from(3));

        // Accurate with phantom overflow
        let result = mul_div(Q128, U256::from(35) * Q128, U256::from(8) * Q128);
        assert_eq!(result.unwrap(), U256::from(4375) * Q128 / U256::from(1000));

        // Accurate with phantom overflow and repeating decimal
        let result = mul_div(Q128, U256::from(1000) * Q128, U256::from(3000) * Q128);
        assert_eq!(result.unwrap(), Q128 / U256::from(3));
    }

    #[test]
    fn test_mul_div_rounding_up() {
        // Reverts if the denominator is zero
        let result = mul_div_rounding_up(Q128, U256::from(5), U256::ZERO);
        assert_eq!(result.unwrap_err().to_string(), "Denominator is 0");

        // Reverts if the denominator is zero and numerator overflows
        let result = mul_div_rounding_up(Q128, Q128, U256::ZERO);
        assert_eq!(result.unwrap_err().to_string(), LTE_PROD_ONE_MSG);

        // Reverts on overflow with all max inputs
        let result = mul_div_rounding_up(U256::MAX, U256::MAX, U256::MAX - U256_1);
        assert_eq!(result.unwrap_err().to_string(), LTE_PROD_ONE_MSG);

        // Reverts if rounding up overflows: a * b = 2^257 - 1, so the floored quotient
        // by 2 is U256::MAX with a remainder of 1.
        let a = U256::from(535006138814359_u64);
        let b: U256 = "432862656469423142931042426214547535783388063929571229938474969"
            .parse()
            .unwrap();
        let result = mul_div_rounding_up(a, b, U256::from(2));
        assert!(matches!(result, Err(UniswapV3MathError::ResultIsU256MAX)));

        // All max inputs (exact division, no rounding)
        assert_eq!(mul_div_rounding_up(U256::MAX, U256::MAX, U256::MAX).unwrap(), U256::MAX);

        // Small inexact product: ceil(42 / 4) = 11
        assert_eq!(
            mul_div_rounding_up(U256::from(6), U256::from(7), U256::from(4)).unwrap(),
            U256::from(11)
        );

        // Accurate without phantom overflow (2^128 / 3 is inexact, so it rounds up)
        let result = mul_div_rounding_up(
            Q128,
            U256::from(50) * Q128 / U256::from(100),
            U256::from(150) * Q128 / U256::from(100),
        );
        assert_eq!(result.unwrap(), Q128 / U256::from(3) + U256_1);

        // Accurate with phantom overflow (exact division, no rounding)
        let result = mul_div_rounding_up(Q128, U256::from(35) * Q128, U256::from(8) * Q128);
        assert_eq!(result.unwrap(), U256::from(4375) * Q128 / U256::from(1000));

        // Accurate with phantom overflow and repeating decimal
        let result = mul_div_rounding_up(Q128, U256::from(1000) * Q128, U256::from(3000) * Q128);
        assert_eq!(result.unwrap(), Q128 / U256::from(3) + U256_1);
    }
}