1use std::fmt;
2use std::hash::{Hash, Hasher};
3use std::sync::Arc;
4
5use crate::swap_path::SwapPath;
6use crate::{CalculationResult, PoolId, PoolWrapper, SwapError, SwapStep, Token};
7use alloy_primitives::{Address, I256, U256};
8use eyre::{eyre, Result};
9use kabu_evm_db::KabuDBError;
10use revm::DatabaseRef;
11use tracing::debug;
12
13#[derive(Debug, Clone, Default)]
14pub enum SwapAmountType {
15 #[default]
16 NotSet,
17 Set(U256),
18 Stack0,
19 RelativeStack(u32),
20 Balance(Address),
21}
22
23impl Copy for SwapAmountType {}
24
25impl SwapAmountType {
26 #[inline]
27 pub fn unwrap(&self) -> U256 {
28 match &self {
29 Self::Set(x) => *x,
30 _ => panic!("called `InAmountType::unwrap()` on a unknown value"),
31 }
32 }
33 #[inline]
34 pub fn unwrap_or_default(&self) -> U256 {
35 match &self {
36 Self::Set(x) => *x,
37 _ => U256::ZERO,
38 }
39 }
40
41 #[inline]
42 pub fn is_set(&self) -> bool {
43 matches!(self, Self::Set(_))
44 }
45 #[inline]
46 pub fn is_not_set(&self) -> bool {
47 !matches!(self, Self::Set(_))
48 }
49}
50
/// A swap route ([`SwapPath`]) together with the amounts and per-hop results computed for it.
#[derive(Clone, Debug, Default)]
pub struct SwapLine {
    /// Tokens and pools of the route; `tokens()` and `pools()` borrow from here.
    pub path: SwapPath,
    /// Amount entering the first pool. `optimize_with_in_amount` stores `Set` values here;
    /// `to_swap_steps` may set it to `Stack0` on the tail half of a split.
    pub amount_in: SwapAmountType,
    /// Amount leaving the last pool.
    pub amount_out: SwapAmountType,
    /// Per-hop results of the evaluation last stored by `optimize_with_in_amount`.
    pub calculation_results: Vec<CalculationResult>,
    /// Not read in this file; presumably the swap recipient (TODO confirm).
    pub swap_to: Option<Address>,
    /// Gas summed over all pools by the evaluation last stored on this line, if any.
    pub gas_used: Option<u64>,
}
65
66impl fmt::Display for SwapLine {
67 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68 let token_in = self.tokens().first();
69 let token_out = self.tokens().last();
70
71 let profit: String = if token_in == token_out {
72 match token_in {
73 Some(t) => format!("profit={}", t.to_float_sign(self.profit().unwrap_or(I256::ZERO))),
74 _ => format!("profit={}", self.profit().unwrap_or(I256::ZERO)),
75 }
76 } else {
77 "-".to_string()
78 };
79
80 let tokens = self.tokens().iter().map(|token| token.get_symbol()).collect::<Vec<String>>().join(", ");
81 let pools =
82 self.pools().iter().map(|pool| format!("{}@{}", pool.get_protocol(), pool.get_pool_id())).collect::<Vec<String>>().join(", ");
83 let amount_in = match self.amount_in {
84 SwapAmountType::Set(x) => match token_in {
85 Some(t) => format!("{:?}", t.to_float(x)),
86 _ => format!("{x}"),
87 },
88 _ => {
89 format!("{:?}", self.amount_in)
90 }
91 };
92 let amount_out = match self.amount_out {
93 SwapAmountType::Set(x) => match token_out {
94 Some(t) => format!("{:?}", t.to_float(x)),
95 _ => format!("{x}"),
96 },
97 _ => {
98 format!("{:?}", self.amount_out)
99 }
100 };
101
102 let calculation_results =
103 self.calculation_results.iter().map(|calculation_result| format!("{calculation_result}")).collect::<Vec<String>>().join(", ");
104
105 write!(
106 f,
107 "SwapLine [{}, tokens=[{}], pools=[{}], amount_in={}, amount_out={}, calculation_results=[{}], gas_used={:?}]",
108 profit, tokens, pools, amount_in, amount_out, calculation_results, self.gas_used
109 )
110 }
111}
112
113impl Hash for SwapLine {
114 fn hash<H: Hasher>(&self, state: &mut H) {
115 self.tokens().hash(state);
116 self.pools().hash(state);
117 }
118}
119
120impl PartialEq for SwapLine {
121 fn eq(&self, other: &Self) -> bool {
122 self.tokens() == other.tokens() && self.pools() == other.pools()
123 }
124}
125
126impl From<SwapPath> for SwapLine {
127 fn from(value: SwapPath) -> Self {
128 Self { path: value, ..Default::default() }
129 }
130}
131
132impl SwapLine {
133 pub fn to_error(&self, msg: String) -> SwapError {
134 SwapError {
135 msg,
136 pool: self.get_first_pool().map(|x| x.get_pool_id()).unwrap_or(PoolId::Address(Address::default())),
137 token_from: self.get_first_token().map_or(Address::default(), |x| x.get_address()),
138 token_to: self.get_last_token().map_or(Address::default(), |x| x.get_address()),
139 is_in_amount: true,
140 amount: self.amount_in.unwrap_or_default(),
141 }
142 }
143
144 pub fn new() -> Self {
145 Default::default()
146 }
147
148 pub fn contains_pool(&self, pool: &PoolWrapper) -> bool {
150 self.path.contains_pool(pool)
151 }
152
153 pub fn tokens(&self) -> &Vec<Arc<Token>> {
155 &self.path.tokens
156 }
157
158 pub fn pools(&self) -> &Vec<PoolWrapper> {
160 &self.path.pools
161 }
162
163 pub fn get_first_token(&self) -> Option<&Arc<Token>> {
165 self.tokens().first()
166 }
167
168 pub fn get_last_token(&self) -> Option<&Arc<Token>> {
170 self.tokens().last()
171 }
172
173 pub fn get_first_pool(&self) -> Option<&PoolWrapper> {
175 self.pools().first()
176 }
177
178 pub fn get_last_pool(&self) -> Option<&PoolWrapper> {
180 self.pools().last()
181 }
182
183 pub fn to_swap_steps(&self, multicaller: Address) -> Option<(SwapStep, SwapStep)> {
185 let mut sp0: Option<SwapLine> = None;
186 let mut sp1: Option<SwapLine> = None;
187
188 for i in 1..self.path.pool_count() {
189 let (head_path, mut tail_path) = self.split(i).unwrap();
190 if head_path.can_flash_swap() || tail_path.can_flash_swap() {
191 if head_path.can_flash_swap() {
192 tail_path.amount_in = SwapAmountType::Stack0;
193 }
194 sp0 = Some(head_path);
195 sp1 = Some(tail_path);
196 break;
197 }
198 }
199
200 if sp0.is_none() || sp1.is_none() {
201 let (head_path, tail_path) = self.split(1).unwrap();
202 sp0 = Some(head_path);
203 sp1 = Some(tail_path);
204 }
205
206 let mut step_0 = SwapStep::new(multicaller);
207 step_0.add(sp0.unwrap());
208
209 let mut step_1 = SwapStep::new(multicaller);
210 let sp1 = sp1.unwrap();
211 step_1.add(sp1);
212
213 Some((step_0, step_1))
214 }
215
216 pub fn split(&self, pool_index: usize) -> Result<(SwapLine, SwapLine)> {
218 let first = SwapLine {
219 path: SwapPath::new(self.tokens()[0..pool_index + 1].to_vec(), self.pools()[0..pool_index].to_vec()),
220 amount_in: self.amount_in,
221 amount_out: SwapAmountType::NotSet,
222 calculation_results: vec![],
223 swap_to: None,
224 gas_used: None,
225 };
226 let second = SwapLine {
227 path: SwapPath::new(self.tokens()[pool_index..].to_vec(), self.pools()[pool_index..].to_vec()),
228 amount_in: SwapAmountType::NotSet,
229 amount_out: self.amount_out,
230 calculation_results: vec![],
231 swap_to: None,
232 gas_used: None,
233 };
234 Ok((first, second))
235 }
236
237 pub fn can_flash_swap(&self) -> bool {
239 for pool in self.pools().iter() {
240 if !pool.can_flash_swap() {
241 return false;
242 }
243 }
244 true
245 }
246
247 pub fn abs_profit(&self) -> U256 {
249 let Some(token_in) = self.tokens().first() else {
250 return U256::ZERO;
251 };
252 let Some(token_out) = self.tokens().last() else {
253 return U256::ZERO;
254 };
255 if token_in != token_out {
256 return U256::ZERO;
257 }
258 let SwapAmountType::Set(amount_in) = self.amount_in else {
259 return U256::ZERO;
260 };
261 let SwapAmountType::Set(amount_out) = self.amount_out else {
262 return U256::ZERO;
263 };
264 if amount_out > amount_in {
265 return amount_out - amount_in;
266 }
267
268 U256::ZERO
269 }
270
271 pub fn abs_profit_eth(&self) -> U256 {
273 let profit = self.abs_profit();
274 let Some(first_token) = self.get_first_token() else {
275 return U256::ZERO;
276 };
277 first_token.calc_eth_value(profit).unwrap_or(U256::ZERO)
278 }
279
280 pub fn profit(&self) -> Result<I256> {
281 if self.tokens().len() < 3 {
282 return Err(eyre!("NOT_ARB_PATH"));
283 }
284 if let Some(token_in) = self.tokens().first() {
285 if let Some(token_out) = self.tokens().last() {
286 return if token_in == token_out {
287 if let SwapAmountType::Set(amount_in) = self.amount_in {
288 if let SwapAmountType::Set(amount_out) = self.amount_out {
289 return Ok(I256::from_raw(amount_out) - I256::from_raw(amount_in));
290 }
291 }
292 Err(eyre!("AMOUNTS_NOT_SET"))
293 } else {
294 Err(eyre!("TOKENS_DONT_MATCH"))
295 };
296 }
297 }
298 Err(eyre!("CANNOT_CALCULATE"))
299 }
300
301 const MIN_VALID_OUT_AMOUNT: U256 = U256::from_limbs([0x100, 0, 0, 0]);
302
303 #[allow(clippy::result_large_err)]
305 pub fn calculate_with_in_amount<DB: DatabaseRef<Error = KabuDBError>>(
306 &self,
307 db: &DB,
308 in_amount: U256,
309 ) -> Result<(U256, u64, Vec<CalculationResult>), SwapError> {
310 let mut current_in_amount = in_amount;
311 let mut final_out_amount = U256::ZERO;
312 let mut gas_used = 0;
313 let mut calculation_results = vec![];
314
315 for (i, pool) in self.pools().iter().enumerate() {
316 let token_from = &self.tokens()[i];
317 let token_to = &self.tokens()[i + 1];
318 match pool.calculate_out_amount(db, &token_from.get_address(), &token_to.get_address(), current_in_amount) {
319 Ok((out_amount_result, gas_result)) => {
320 if out_amount_result.is_zero() {
321 return Err(SwapError {
322 msg: "ZERO_OUT_AMOUNT".to_string(),
323 pool: pool.get_pool_id(),
324 token_from: token_from.get_address(),
325 token_to: token_to.get_address(),
326 is_in_amount: true,
327 amount: current_in_amount,
328 });
329 }
330 if out_amount_result.lt(&Self::MIN_VALID_OUT_AMOUNT) {
331 return Err(SwapError {
332 msg: "ALMOST_ZERO_OUT_AMOUNT".to_string(),
333 pool: pool.get_pool_id(),
334 token_from: token_from.get_address(),
335 token_to: token_to.get_address(),
336 is_in_amount: true,
337 amount: current_in_amount,
338 });
339 }
340
341 calculation_results.push(CalculationResult::new(current_in_amount, out_amount_result));
342 current_in_amount = out_amount_result;
343 final_out_amount = out_amount_result;
344 gas_used += gas_result
345 }
346 Err(e) => {
347 return Err(SwapError {
349 msg: e.to_string(),
350 pool: pool.get_pool_id(),
351 token_from: token_from.get_address(),
352 token_to: token_to.get_address(),
353 is_in_amount: true,
354 amount: current_in_amount,
355 });
356 }
357 }
358 }
359 Ok((final_out_amount, gas_used, calculation_results))
360 }
361
362 #[allow(clippy::result_large_err)]
364 pub fn calculate_with_out_amount<DB: DatabaseRef<Error = KabuDBError>>(
365 &self,
366 db: &DB,
367 out_amount: U256,
368 ) -> Result<(U256, u64, Vec<CalculationResult>), SwapError> {
369 let mut current_out_amount = out_amount;
370 let mut final_in_amount = U256::ZERO;
371 let mut gas_used = 0;
372 let mut calculation_results = vec![];
373
374 let mut pool_reverse = self.pools().clone();
376 pool_reverse.reverse();
377 let mut tokens_reverse = self.tokens().clone();
378 tokens_reverse.reverse();
379
380 for (i, pool) in pool_reverse.iter().enumerate() {
381 let token_from = &tokens_reverse[i + 1];
382 let token_to = &tokens_reverse[i];
383 match pool.calculate_in_amount(db, &token_from.get_address(), &token_to.get_address(), current_out_amount) {
384 Ok((in_amount_result, gas_result)) => {
385 if in_amount_result == U256::MAX || in_amount_result == U256::ZERO {
386 return Err(SwapError {
387 msg: "ZERO_AMOUNT".to_string(),
388 pool: pool.get_pool_id(),
389 token_from: token_from.get_address(),
390 token_to: token_to.get_address(),
391 is_in_amount: false,
392 amount: current_out_amount,
393 });
394 }
395 calculation_results.push(CalculationResult::new(current_out_amount, in_amount_result));
396 current_out_amount = in_amount_result;
397 final_in_amount = in_amount_result;
398 gas_used += gas_result;
399 }
400 Err(e) => {
401 return Err(SwapError {
404 msg: e.to_string(),
405 pool: pool.get_pool_id(),
406 token_from: token_from.get_address(),
407 token_to: token_to.get_address(),
408 is_in_amount: false,
409 amount: current_out_amount,
410 });
411 }
412 }
413 }
414 Ok((final_in_amount, gas_used, calculation_results))
415 }
416
417 #[allow(clippy::result_large_err)]
419 pub fn optimize_with_in_amount<DB: DatabaseRef<Error = KabuDBError>>(
420 &mut self,
421 db: &DB,
422 in_amount: U256,
423 ) -> Result<&mut Self, SwapError> {
424 let mut current_in_amount = in_amount;
425 let mut best_profit: Option<I256> = None;
426 let mut current_step = U256::from(10000);
427 let mut inc_direction = true;
428 let mut first_step_change = false;
429 let mut next_amount = current_in_amount;
430 let mut prev_in_amount = U256::ZERO;
431 let mut counter = 0;
432 let denominator = U256::from(1000);
433
434 loop {
435 counter += 1;
436 if counter > 30 {
439 debug!("optimize_swap_path_in_amount iterations exceeded : {self} {current_in_amount} {current_step}");
440 return Ok(self);
441 }
442
443 let (current_out_amount, current_gas_used, calculation_results) = match self.calculate_with_in_amount(db, next_amount) {
444 Ok(ret) => ret,
445 Err(e) => {
446 if counter == 1 {
447 return Err(e);
449 }
450 (U256::ZERO, 0, vec![])
451 }
452 };
453
454 let current_profit = I256::from_raw(current_out_amount) - I256::from_raw(next_amount);
455
456 if best_profit.is_none() {
457 best_profit = Some(current_profit);
458 self.amount_in = SwapAmountType::Set(next_amount);
459 self.amount_out = SwapAmountType::Set(current_out_amount);
460 self.gas_used = Some(current_gas_used);
461 self.calculation_results = calculation_results;
462 current_in_amount = next_amount;
463 if current_out_amount.is_zero() || current_profit.is_negative() {
464 return Ok(self);
465 }
466 } else if best_profit.unwrap() > current_profit || current_out_amount.is_zero()
467 {
469 if first_step_change && inc_direction && current_step < denominator {
470 inc_direction = false;
471 next_amount = prev_in_amount;
473 current_in_amount = prev_in_amount;
474 first_step_change = true;
475 } else if first_step_change && !inc_direction {
477 inc_direction = true;
479 current_step /= U256::from(10);
480 best_profit = Some(current_profit);
481 first_step_change = true;
482 if current_step == U256::from(1) {
485 break;
486 }
487 } else {
488 current_step /= U256::from(10);
489 first_step_change = true;
490 if current_step == U256::from(1) {
491 break;
492 }
493 }
494 } else {
495 best_profit = Some(current_profit);
496 self.amount_in = SwapAmountType::Set(next_amount);
497 self.amount_out = SwapAmountType::Set(current_out_amount);
498 self.gas_used = Some(current_gas_used);
499 self.calculation_results = calculation_results;
500 current_in_amount = next_amount;
501 first_step_change = false;
502 }
503
504 prev_in_amount = current_in_amount;
505 if inc_direction {
506 next_amount = current_in_amount + (current_in_amount * current_step / denominator);
507 } else {
508 next_amount = current_in_amount - (current_in_amount * current_step / denominator);
509 }
510 }
512
513 Ok(self)
514 }
515}
516
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock_pool::MockPool;
    use alloy_primitives::utils::parse_units;
    use kabu_defi_address_book::{TokenAddressEth, UniswapV2PoolAddress, UniswapV3PoolAddress};
    use std::sync::Arc;

    /// Builds a WETH -> USDT -> USDT -> WETH line over two mock pools.
    ///
    /// Amounts are `amount_in = 0.01` and `amount_out = 0.03` (parsed with "ether" units),
    /// with `gas_used = 10000`.
    /// NOTE(review): the fixture lists four tokens for two pools; it is kept as-is because
    /// `test_swapline_fmt` pins the rendered token list.
    fn default_swap_line() -> (MockPool, MockPool, SwapLine) {
        let token0 = Arc::new(Token::new_with_data(TokenAddressEth::WETH, Some("WETH".to_string()), None, Some(18), true, false));
        let token1 = Arc::new(Token::new_with_data(TokenAddressEth::USDT, Some("USDT".to_string()), None, Some(6), true, false));
        let pool1 =
            MockPool { token0: TokenAddressEth::WETH, token1: TokenAddressEth::USDT, address: UniswapV3PoolAddress::WETH_USDT_3000 };
        let pool2 = MockPool { token0: TokenAddressEth::WETH, token1: TokenAddressEth::USDT, address: UniswapV2PoolAddress::WETH_USDT };

        let swap_path =
            SwapPath::new(vec![token0.clone(), token1.clone(), token1.clone(), token0.clone()], vec![pool1.clone(), pool2.clone()]);

        let swap_line = SwapLine {
            path: swap_path,
            amount_in: SwapAmountType::Set(parse_units("0.01", "ether").unwrap().get_absolute()),
            amount_out: SwapAmountType::Set(parse_units("0.03", "ether").unwrap().get_absolute()),
            calculation_results: vec![],
            swap_to: Some(Default::default()),
            gas_used: Some(10000),
        };

        (pool1, pool2, swap_line)
    }

    #[test]
    fn test_swapline_fmt() {
        let (_, _, swap_line) = default_swap_line();

        let formatted = format!("{}", swap_line);
        assert_eq!(
            formatted,
            "SwapLine [profit=0.02, tokens=[WETH, USDT, USDT, WETH], \
            pools=[UniswapV2@0x4e68Ccd3E89f51C3074ca5072bbAC773960dFa36, UniswapV2@0x0d4a11d5EEaaC28EC3F61d100daF4d40471f1852], \
            amount_in=0.01, amount_out=0.03, calculation_results=[], gas_used=Some(10000)]"
        )
    }

    #[test]
    fn test_contains_pool() {
        let (pool1, pool2, swap_line) = default_swap_line();

        assert!(swap_line.contains_pool(&PoolWrapper::from(pool1)));
        assert!(swap_line.contains_pool(&PoolWrapper::from(pool2)));
    }

    #[test]
    fn test_tokens() {
        let (_, _, swap_line) = default_swap_line();

        let tokens = swap_line.tokens();
        assert_eq!(tokens.first().unwrap().get_address(), TokenAddressEth::WETH);
        assert_eq!(tokens.get(1).unwrap().get_address(), TokenAddressEth::USDT);
    }

    #[test]
    fn test_pools() {
        let (pool1, pool2, swap_line) = default_swap_line();

        let pools = swap_line.pools();
        assert_eq!(pools.first().unwrap().get_address(), PoolId::Address(pool1.address));
        assert_eq!(pools.get(1).unwrap().get_address(), PoolId::Address(pool2.address));
    }

    #[test]
    fn test_get_first_token() {
        let (_, _, swap_line) = default_swap_line();

        let token = swap_line.get_first_token();
        assert_eq!(token.unwrap().get_address(), TokenAddressEth::WETH);
    }

    #[test]
    fn test_get_last_token() {
        let (_, _, swap_line) = default_swap_line();

        let token = swap_line.get_last_token();
        assert_eq!(token.unwrap().get_address(), TokenAddressEth::WETH);
    }

    #[test]
    fn test_get_first_pool() {
        let (pool1, _, swap_line) = default_swap_line();

        let pool = swap_line.get_first_pool();
        assert_eq!(pool.unwrap().get_address(), PoolId::Address(pool1.address));
    }

    #[test]
    fn test_get_last_pool() {
        let (_, pool2, swap_line) = default_swap_line();

        let pool = swap_line.get_last_pool();
        assert_eq!(pool.unwrap().get_address(), PoolId::Address(pool2.address));
    }

    #[test]
    fn test_abs_profit() {
        let (_, _, swap_line) = default_swap_line();

        let expected = parse_units("0.02", "ether").unwrap().get_absolute();
        assert_eq!(swap_line.abs_profit(), expected);
    }

    #[test]
    fn test_profit() {
        let (_, _, swap_line) = default_swap_line();

        let expected = parse_units("0.02", "ether").unwrap().get_absolute();
        assert_eq!(swap_line.profit().unwrap(), I256::from_raw(expected));
    }

    #[test]
    fn test_split_keeps_edge_amounts() {
        let (pool1, pool2, swap_line) = default_swap_line();

        let (head, tail) = swap_line.split(1).unwrap();
        assert_eq!(head.pools().len(), 1);
        assert_eq!(tail.pools().len(), 1);
        assert_eq!(head.get_first_pool().unwrap().get_address(), PoolId::Address(pool1.address));
        assert_eq!(tail.get_first_pool().unwrap().get_address(), PoolId::Address(pool2.address));
        assert!(head.amount_in.is_set());
        assert!(head.amount_out.is_not_set());
        assert!(tail.amount_in.is_not_set());
        assert!(tail.amount_out.is_set());
    }
}