kabu_defi_uniswap_v3_math/
full_math.rs1use std::ops::{Add, BitOrAssign, Div, Mul, MulAssign};
2
3use crate::{error::UniswapV3MathError, U256_1};
4use alloy::primitives::{Uint, U256};
7
8pub const ONE: Uint<256, 4> = Uint::<256, 4>::from_limbs([1, 0, 0, 0]);
9pub const TWO: Uint<256, 4> = Uint::<256, 4>::from_limbs([2, 0, 0, 0]);
10pub const THREE: Uint<256, 4> = Uint::<256, 4>::from_limbs([3, 0, 0, 0]);
11
12pub fn mul_div(a: U256, b: U256, mut denominator: U256) -> Result<U256, UniswapV3MathError> {
14 let mm = a.mul_mod(b, U256::MAX);
20
21 let mut prod_0 = a.overflowing_mul(b).0; let mut prod_1 = mm.overflowing_sub(prod_0).0.overflowing_sub(U256::from((mm < prod_0) as u8)).0;
23
24 if prod_1 == U256::ZERO {
26 if denominator == U256::ZERO {
27 return Err(UniswapV3MathError::DenominatorIsZero);
28 }
29 return Ok(U256::from_limbs(*prod_0.div(denominator).as_limbs()));
30 }
31
32 if denominator <= prod_1 {
35 return Err(UniswapV3MathError::DenominatorIsLteProdOne);
36 }
37
38 let remainder = a.mul_mod(b, denominator);
46
47 prod_1 = prod_1.overflowing_sub(U256::from((remainder > prod_0) as u8)).0;
49 prod_0 = prod_0.overflowing_sub(remainder).0;
50
51 let mut twos = U256::ZERO.overflowing_sub(denominator).0.bitand(denominator);
55
56 denominator = denominator.wrapping_div(twos);
59
60 prod_0 = prod_0.wrapping_div(twos);
62
63 twos = (U256::ZERO.overflowing_sub(twos).0.wrapping_div(twos)).add(U256_1);
68
69 prod_0.bitor_assign(prod_1 * twos);
70
71 let mut inv = THREE.mul(denominator).bitxor(TWO);
78
79 inv.mul_assign(TWO - denominator * inv); inv.mul_assign(TWO - denominator * inv); inv.mul_assign(TWO - denominator * inv); inv.mul_assign(TWO - denominator * inv); inv.mul_assign(TWO - denominator * inv); inv.mul_assign(TWO - denominator * inv); Ok(U256::from_le_slice((prod_0 * inv).as_le_slice()))
98}
99
100pub fn mul_div_rounding_up(a: U256, b: U256, denominator: U256) -> Result<U256, UniswapV3MathError> {
101 let result = mul_div(a, b, denominator)?;
102
103 if a.mul_mod(b, denominator) > U256::ZERO {
104 if result == U256::MAX {
105 Err(UniswapV3MathError::ResultIsU256MAX)
106 } else {
107 Ok(result + U256_1)
108 }
109 } else {
110 Ok(result)
111 }
112}
113
#[cfg(test)]
mod tests {
    //! Tests for `mul_div_rounding_up`; `mul_div` itself is covered by `mod test`.
    use super::*;

    const Q128: U256 = U256::from_limbs([0, 0, 1, 0]);

    #[test]
    fn test_mul_div_rounding_up() {
        // Errors from `mul_div` propagate unchanged.
        let result = mul_div_rounding_up(Q128, U256::from(5), U256::ZERO);
        assert_eq!(result.err().unwrap().to_string(), "Denominator is 0");

        let result = mul_div_rounding_up(Q128, Q128, U256::ZERO);
        assert_eq!(result.err().unwrap().to_string(), "Denominator is less than or equal to prod_1");

        let result = mul_div_rounding_up(U256::MAX, U256::MAX, U256::MAX - U256_1);
        assert_eq!(result.err().unwrap().to_string(), "Denominator is less than or equal to prod_1");

        // Exact divisions are not rounded.
        assert_eq!(mul_div_rounding_up(U256::MAX, U256::MAX, U256::MAX).unwrap(), U256::MAX);
        let result = mul_div_rounding_up(Q128, U256::from(35) * Q128, U256::from(8) * Q128);
        assert_eq!(result.unwrap(), U256::from(4375) * Q128 / U256::from(1000));

        // Inexact divisions round up by one.
        let result = mul_div_rounding_up(
            Q128,
            U256::from(50) * Q128 / U256::from(100),
            U256::from(150) * Q128 / U256::from(100),
        );
        assert_eq!(result.unwrap(), Q128 / U256::from(3) + U256_1);

        let result = mul_div_rounding_up(Q128, U256::from(1000) * Q128, U256::from(3000) * Q128);
        assert_eq!(result.unwrap(), Q128 / U256::from(3) + U256_1);

        // (2^129 - 1) * (2^129 + 1) = 2^258 - 1; divided by 4 that is U256::MAX remainder 3,
        // so rounding up must overflow.
        let two_pow_129 = U256::from_limbs([0, 0, 2, 0]);
        let (a, b, d) = (two_pow_129 - U256_1, two_pow_129 + U256_1, U256::from(4));
        assert_eq!(mul_div(a, b, d).unwrap(), U256::MAX);
        assert!(matches!(mul_div_rounding_up(a, b, d), Err(UniswapV3MathError::ResultIsU256MAX)));
    }
}

#[cfg(test)]
mod test {
    use alloy::primitives::U256;

    use super::mul_div;
    use crate::U256_1;

    const Q128: U256 = U256::from_limbs([0, 0, 1, 0]);

    /// Error text expected whenever the quotient cannot fit in 256 bits.
    const LTE_PROD_ONE: &str = "Denominator is less than or equal to prod_1";

    /// Takes the error side of a `mul_div` outcome and renders it as text.
    fn error_text<E: std::fmt::Display>(outcome: Result<U256, E>) -> String {
        outcome.err().unwrap().to_string()
    }

    #[test]
    fn test_mul_div() {
        // Zero denominator while the product still fits in 256 bits.
        assert_eq!(error_text(mul_div(Q128, U256::from(5), U256::ZERO)), "Denominator is 0");

        // Products spilling into the high word with a denominator too small to bring
        // the quotient back under 2^256.
        let rejected = [
            (Q128, Q128, U256::ZERO),
            (Q128, Q128, U256_1),
            (U256::MAX, U256::MAX, U256::MAX - U256_1),
        ];
        for &(a, b, denominator) in &rejected {
            assert_eq!(error_text(mul_div(a, b, denominator)), LTE_PROD_ONE);
        }

        // (a, b, denominator, expected floor(a * b / denominator)).
        let one_third_q128 = Q128 / U256::from(3);
        let accepted = [
            (U256::MAX, U256::MAX, U256::MAX, U256::MAX),
            (
                Q128,
                U256::from(50) * Q128 / U256::from(100),
                U256::from(150) * Q128 / U256::from(100),
                one_third_q128,
            ),
            (
                Q128,
                U256::from(35) * Q128,
                U256::from(8) * Q128,
                U256::from(4375) * Q128 / U256::from(1000),
            ),
            (Q128, U256::from(1000) * Q128, U256::from(3000) * Q128, one_third_q128),
        ];
        for &(a, b, denominator, expected) in &accepted {
            assert_eq!(mul_div(a, b, denominator).unwrap(), expected);
        }
    }
}