1use crate::{PoolId, PoolWrapper, SwapDirection, Token};
2use alloy_primitives::{map::HashMap, Address};
3use eyre::Result;
4use std::fmt;
5use std::fmt::Display;
6use std::hash::{DefaultHasher, Hash, Hasher};
7use std::sync::Arc;
8use tracing::debug;
9
/// An ordered swap route: `tokens[i]` is swapped into `tokens[i + 1]` through `pools[i]`.
///
/// Equality and hashing for this type consider only `tokens` and `pools`; the remaining
/// fields are bookkeeping.
#[derive(Clone, Debug)]
pub struct SwapPath {
    /// Tokens visited by the route, in order; the constructors build one more entry than `pools`.
    pub tokens: Vec<Arc<Token>>,
    /// Pool used for each hop, in order.
    pub pools: Vec<PoolWrapper>,
    /// Whether the route is currently disabled; written by `SwapPaths::disable_path` and
    /// `SwapPaths::disable_pool_paths`.
    pub disabled: bool,
    /// Pool ids recorded by `SwapPaths::disable_pool_paths` as blocking this route.
    pub disabled_pool: Vec<PoolId>,
    /// Optional ranking score. Initialized to `None` by the constructors and not read in this
    /// file — presumably assigned by an external scoring step (TODO confirm).
    pub score: Option<f64>,
}
18
19impl Display for SwapPath {
20 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21 let tokens = self.tokens.iter().map(|token| token.get_symbol()).collect::<Vec<String>>().join(", ");
22 let pools =
23 self.pools.iter().map(|pool| format!("{}@{}", pool.get_protocol(), pool.get_pool_id())).collect::<Vec<String>>().join(", ");
24
25 write!(f, "SwapPath [tokens=[{}], pools=[{}] disabled={}]", tokens, pools, self.disabled)
26 }
27}
28
29impl Default for SwapPath {
30 #[inline]
31 fn default() -> Self {
32 SwapPath { tokens: Vec::new(), pools: Vec::new(), disabled: false, disabled_pool: Default::default(), score: None }
33 }
34}
35
36impl PartialEq for SwapPath {
37 fn eq(&self, other: &Self) -> bool {
38 self.tokens == other.tokens && self.pools == other.pools
39 }
40}
41
42impl Eq for SwapPath {}
43
44impl Hash for SwapPath {
45 fn hash<H: Hasher>(&self, state: &mut H) {
46 self.tokens.hash(state);
47 self.pools.hash(state);
48 }
49}
50
51impl SwapPath {
52 #[inline]
53 pub fn new<T: Into<Arc<Token>>, P: Into<PoolWrapper>>(tokens: Vec<T>, pools: Vec<P>) -> Self {
54 SwapPath {
55 tokens: tokens.into_iter().map(|i| i.into()).collect(),
56 pools: pools.into_iter().map(|i| i.into()).collect(),
57 disabled: false,
58 disabled_pool: Default::default(),
59 score: None,
60 }
61 }
62
63 #[inline]
64 pub fn is_emply(&self) -> bool {
65 self.tokens.is_empty() && self.pools.is_empty()
66 }
67
68 #[inline]
69 pub fn tokens_count(&self) -> usize {
70 self.tokens.len()
71 }
72
73 #[inline]
74 pub fn pool_count(&self) -> usize {
75 self.pools.len()
76 }
77
78 #[inline]
79 pub fn new_swap(token_from: Arc<Token>, token_to: Arc<Token>, pool: PoolWrapper) -> Self {
80 SwapPath { tokens: vec![token_from, token_to], pools: vec![pool], disabled: false, disabled_pool: Default::default(), score: None }
81 }
82
83 #[inline]
84 pub fn push_swap_hope(&mut self, token_from: Arc<Token>, token_to: Arc<Token>, pool: PoolWrapper) -> Result<&mut Self> {
85 if self.is_emply() {
86 self.tokens = vec![token_from, token_to];
87 self.pools = vec![pool];
88 } else {
89 if token_from.as_ref() != self.tokens.last().map_or(&Token::zero(), |t| t.as_ref()) {
90 return Err(eyre::eyre!("NEW_SWAP_NOT_CONNECTED"));
91 }
92 self.tokens.push(token_to);
93 self.pools.push(pool);
94 }
95 Ok(self)
96 }
97
98 #[inline]
99 pub fn insert_swap_hope(&mut self, token_from: Arc<Token>, token_to: Arc<Token>, pool: PoolWrapper) -> Result<&mut Self> {
100 if self.is_emply() {
101 self.tokens = vec![token_from, token_to];
102 self.pools = vec![pool];
103 } else {
104 if token_to.as_ref() != self.tokens.first().map_or(&Token::zero(), |t| t.as_ref()) {
105 return Err(eyre::eyre!("NEW_SWAP_NOT_CONNECTED"));
106 }
107 self.tokens.insert(0, token_from);
108 self.pools.insert(0, pool);
109 }
110
111 Ok(self)
112 }
113
114 #[inline]
115 pub fn contains_pool(&self, pool: &PoolWrapper) -> bool {
116 for p in self.pools.iter() {
117 if p == pool {
118 return true;
119 }
120 }
121 false
122 }
123
124 #[inline]
125 pub fn get_hash(&self) -> u64 {
126 let mut h = DefaultHasher::new();
127 self.hash(&mut h);
128 h.finish()
129 }
130}
131
/// Deduplicated collection of [`SwapPath`]s with lookup indices by pool id and by path hash.
#[derive(Clone, Debug, Default)]
pub struct SwapPaths {
    /// Stored paths. `add` only appends and nothing in this file removes entries, so the
    /// indices kept in the maps below stay valid.
    pub paths: Vec<SwapPath>,
    /// Pool id -> indices (into `paths`) of every path that uses that pool.
    pub pool_paths: HashMap<PoolId, Vec<usize>>,
    /// `SwapPath::get_hash` value -> index into `paths`; used to reject duplicate paths.
    pub path_hash_map: HashMap<u64, usize>,
    /// `SwapDirection::get_hash_with_pool` value -> the last `disabled` flag passed to
    /// `disable_pool_paths` for that direction and pool.
    pub disabled_directions: HashMap<u64, bool>,
}
139
140impl SwapPaths {
141 pub fn new() -> SwapPaths {
142 SwapPaths {
143 paths: Vec::new(),
144 pool_paths: HashMap::default(),
145 path_hash_map: HashMap::default(),
146 disabled_directions: HashMap::default(),
147 }
148 }
149 pub fn from(paths: Vec<SwapPath>) -> Self {
150 let mut swap_paths_ret = SwapPaths::new();
151 for p in paths {
152 swap_paths_ret.add(p);
153 }
154 swap_paths_ret
155 }
156
157 pub fn len(&self) -> usize {
158 self.paths.len()
159 }
160
161 pub fn disabled_len(&self) -> usize {
162 self.paths.iter().filter(|p| p.disabled).count()
163 }
164
165 pub fn is_empty(&self) -> bool {
166 self.paths.is_empty()
167 }
168
169 pub fn len_max(&self) -> usize {
170 self.pool_paths.values().map(|item| item.len()).max().unwrap_or_default()
171 }
172
173 #[inline]
174 pub fn add(&mut self, path: SwapPath) -> Option<usize> {
175 let path_hash = path.get_hash();
176 let path_idx = self.paths.len();
177
178 match self.path_hash_map.entry(path_hash) {
179 std::collections::hash_map::Entry::Occupied(_) => {
180 None
182 }
183 std::collections::hash_map::Entry::Vacant(e) => {
184 e.insert(path_idx);
186
187 for pool in &path.pools {
188 self.pool_paths.entry(pool.get_pool_id()).or_default().push(path_idx);
189 }
190
191 self.paths.push(path);
192 Some(path_idx)
193 }
194 }
195 }
196
197 pub fn disable_path(&mut self, swap_path: &SwapPath, disable: bool) -> bool {
198 if let Some(swap_path_idx) = self.path_hash_map.get(&swap_path.get_hash()) {
199 if let Some(swap_path) = self.paths.get_mut(*swap_path_idx) {
200 debug!("Path disabled hash={}, path={}", swap_path.get_hash(), swap_path);
201 swap_path.disabled = disable;
202 return true;
203 }
204 }
205 debug!("Path not disabled hash={}, path={}", swap_path.get_hash(), swap_path);
206 false
207 }
208
209 pub fn disable_pool_paths(&mut self, pool_id: &PoolId, token_from_address: &Address, token_to_address: &Address, disabled: bool) {
210 let Some(pool_paths) = self.pool_paths.get(pool_id).cloned() else { return };
211
212 for path_idx in pool_paths.iter() {
213 if let Some(entry) = self.paths.get_mut(*path_idx) {
214 if let Some(idx) = entry.pools.iter().position(|item| item.get_pool_id().eq(pool_id)) {
215 if let (Some(token_from), Some(token_to)) = (entry.tokens.get(idx), entry.tokens.get(idx + 1)) {
216 if token_from.get_address().eq(token_from_address) && token_to.get_address().eq(token_to_address) {
217 entry.disabled = disabled;
218 if !entry.disabled_pool.contains(pool_id) {
219 entry.disabled_pool.push(*pool_id);
220 }
221 self.disabled_directions
222 .insert(SwapDirection::new(*token_from_address, *token_to_address).get_hash_with_pool(pool_id), disabled);
223 }
224 }
225 } else {
226 entry.disabled = disabled;
228 if !entry.disabled_pool.contains(pool_id) {
229 entry.disabled_pool.push(*pool_id);
230 }
231 }
232 }
233 }
234 }
235 #[inline]
236 pub fn get_pool_paths_enabled_vec(&self, pool_id: &PoolId) -> Option<Vec<SwapPath>> {
237 let paths = self.pool_paths.get(pool_id)?;
238 let paths_vec_ret: Vec<SwapPath> = paths
239 .iter()
240 .filter_map(|a| {
241 self.paths
242 .get(*a)
243 .filter(|a| a.disabled_pool.is_empty() || (a.disabled_pool.len() == 1 && a.disabled_pool.contains(pool_id)))
244 })
245 .cloned()
246 .collect();
247 (!paths_vec_ret.is_empty()).then_some(paths_vec_ret)
248 }
249
250 #[inline]
251 pub fn get_path_by_idx(&self, idx: usize) -> Option<&SwapPath> {
252 self.paths.get(idx)
253 }
254
255 #[inline]
256 pub fn get_path_by_idx_mut(&mut self, idx: usize) -> Option<&mut SwapPath> {
257 self.paths.get_mut(idx)
258 }
259
260 #[inline]
261 pub fn get_path_by_hash(&self, idx: u64) -> Option<&SwapPath> {
262 self.path_hash_map.get(&idx).and_then(|i| self.paths.get(*i))
263 }
264}
265
#[cfg(test)]
mod test {
    use super::*;
    use crate::pool::DefaultAbiSwapEncoder;
    use crate::required_state::RequiredState;
    use crate::{Pool, PoolAbiEncoder, PoolClass, PoolProtocol, PreswapRequirement, SwapDirection};
    use alloy_primitives::{Address, U256};
    use eyre::{eyre, ErrReport};
    use kabu_evm_db::KabuDBError;
    use revm::DatabaseRef;
    use std::any::Any;
    use tokio::task::JoinHandle;
    use tracing::error;

    /// Minimal `Pool` whose only meaningful property is its address-based pool id; all pricing
    /// methods return `NOT_IMPLEMENTED`.
    #[derive(Clone)]
    pub struct EmptyPool {
        address: Address,
    }

    impl EmptyPool {
        pub fn new(address: Address) -> Self {
            EmptyPool { address }
        }
    }

    impl Pool for EmptyPool {
        fn as_any<'a>(&self) -> &dyn Any {
            self
        }

        fn is_native(&self) -> bool {
            false
        }
        fn get_address(&self) -> PoolId {
            PoolId::Address(self.address)
        }

        fn get_pool_id(&self) -> PoolId {
            PoolId::Address(self.address)
        }

        fn calculate_out_amount(
            &self,
            _db: &dyn DatabaseRef<Error = KabuDBError>,
            _token_address_from: &Address,
            _token_address_to: &Address,
            _in_amount: U256,
        ) -> Result<(U256, u64), ErrReport> {
            Err(eyre!("NOT_IMPLEMENTED"))
        }

        fn calculate_in_amount(
            &self,
            _db: &dyn DatabaseRef<Error = KabuDBError>,
            _token_address_from: &Address,
            _token_address_to: &Address,
            _out_amount: U256,
        ) -> eyre::Result<(U256, u64), ErrReport> {
            Err(eyre!("NOT_IMPLEMENTED"))
        }

        fn can_flash_swap(&self) -> bool {
            false
        }

        fn get_abi_encoder(&self) -> Option<&dyn PoolAbiEncoder> {
            Some(&DefaultAbiSwapEncoder {})
        }

        fn get_state_required(&self) -> Result<RequiredState> {
            Ok(RequiredState::new())
        }

        fn get_class(&self) -> PoolClass {
            PoolClass::Unknown
        }

        fn get_protocol(&self) -> PoolProtocol {
            PoolProtocol::Unknown
        }

        fn get_fee(&self) -> U256 {
            U256::ZERO
        }

        fn get_tokens(&self) -> Vec<Address> {
            vec![]
        }

        fn get_swap_directions(&self) -> Vec<SwapDirection> {
            vec![]
        }

        fn can_calculate_in_amount(&self) -> bool {
            true
        }

        fn get_read_only_cell_vec(&self) -> Vec<U256> {
            vec![]
        }

        fn preswap_requirement(&self) -> PreswapRequirement {
            PreswapRequirement::Base
        }
    }

    /// Wraps an `EmptyPool` whose pool id is `PoolId::Address(Address::repeat_byte(byte))`.
    fn empty_pool(byte: u8) -> PoolWrapper {
        PoolWrapper::new(Arc::new(EmptyPool::new(Address::repeat_byte(byte))))
    }

    #[test]
    fn test_add_path() {
        let basic_token = Token::new(Address::repeat_byte(0x11));

        // Path `i`: basic -> token(i) -> basic through pools `i + 1` and `i + 2`.
        let paths_vec: Vec<SwapPath> = (0..10)
            .map(|i| {
                SwapPath::new(
                    vec![basic_token.clone(), Token::new(Address::repeat_byte(i)), basic_token.clone()],
                    vec![empty_pool(i + 1), empty_pool(i + 2)],
                )
            })
            .collect();
        let paths = SwapPaths::from(paths_vec);

        println!("{paths:?}");

        // Every path has a distinct intermediate token, so none is deduplicated.
        assert_eq!(paths.len(), 10);
        assert_eq!(paths.disabled_len(), 0);
        // Pool `k` (2..=10) is the second hop of path `k - 2` and the first hop of path `k - 1`.
        assert_eq!(paths.len_max(), 2);
    }

    #[test]
    fn test_push_and_insert_swap_hop() {
        let token_a: Arc<Token> = Token::new(Address::repeat_byte(0x01)).into();
        let token_b: Arc<Token> = Token::new(Address::repeat_byte(0x02)).into();
        let token_c: Arc<Token> = Token::new(Address::repeat_byte(0x03)).into();
        let pool_ab = empty_pool(0xab);
        let pool_bc = empty_pool(0xbc);

        let mut path = SwapPath::default();
        assert!(path.is_emply());
        path.push_swap_hope(token_a.clone(), token_b.clone(), pool_ab.clone()).unwrap();
        path.push_swap_hope(token_b.clone(), token_c.clone(), pool_bc.clone()).unwrap();
        assert_eq!(path.tokens_count(), 3);
        assert_eq!(path.pool_count(), 2);
        assert!(path.contains_pool(&pool_bc));

        // A hop that does not start at the current last token is rejected and changes nothing.
        assert!(path.push_swap_hope(token_a.clone(), token_b.clone(), pool_ab.clone()).is_err());
        assert_eq!(path.tokens_count(), 3);

        // Building the same route from the other end yields an equal path with the same hash.
        let mut prepended = SwapPath::new_swap(token_b.clone(), token_c.clone(), pool_bc.clone());
        prepended.insert_swap_hope(token_a.clone(), token_b.clone(), pool_ab.clone()).unwrap();
        assert!(prepended.insert_swap_hope(token_c.clone(), token_b.clone(), pool_bc.clone()).is_err());
        assert_eq!(prepended, path);
        assert_eq!(prepended.get_hash(), path.get_hash());
    }

    #[tokio::test]
    async fn async_test() {
        let basic_token = Token::new(Address::repeat_byte(0x11));

        const PATHS_COUNT: usize = 10;

        // Path `i` routes through pools `i` and `i + 1`, so consecutive paths share one pool.
        let pool_address_vec: Vec<(PoolWrapper, PoolWrapper)> =
            (0..PATHS_COUNT).map(|i| (empty_pool(i as u8), empty_pool((i + 1) as u8))).collect();

        let paths_vec: Vec<SwapPath> = pool_address_vec
            .iter()
            .map(|p| {
                SwapPath::new(
                    vec![basic_token.clone(), Token::new(Address::repeat_byte(1)), basic_token.clone()],
                    vec![p.0.clone(), p.1.clone()],
                )
            })
            .collect();

        let mut paths = SwapPaths::from(paths_vec.clone());
        // Re-adding identical paths must be rejected by hash-based deduplication.
        for path in paths_vec {
            assert!(paths.add(path).is_none());
        }
        assert_eq!(paths.len(), PATHS_COUNT);

        let paths_shared = Arc::new(tokio::sync::RwLock::new(paths));

        let mut tasks: Vec<JoinHandle<_>> = Vec::new();

        for (i, pools) in pool_address_vec.into_iter().enumerate() {
            let pool_id = pools.0.get_pool_id();
            let paths_shared_clone = paths_shared.clone();
            tasks.push(tokio::task::spawn(async move {
                let path_guard = paths_shared_clone.read().await;
                let pool_paths = path_guard.get_pool_paths_enabled_vec(&pool_id);
                println!("{i} {pool_id}: {pool_paths:?}");
                // Pool `i` is the first hop of path `i` and, for i > 0, the second hop of path `i - 1`.
                let expected = if i == 0 { 1 } else { 2 };
                assert_eq!(pool_paths.map(|v| v.len()), Some(expected));
            }));
        }

        for t in tasks {
            if let Err(e) = t.await {
                error!("{}", e);
                // A panicking task (failed assertion) must fail the test, not just be logged.
                panic!("pool path lookup task failed: {e}");
            }
        }
    }

    #[test]
    fn test_disable_path() {
        let basic_token = Token::new(Address::repeat_byte(0x11));

        // All paths share pool 0x01 as the first hop and differ in token and second pool.
        let paths_vec: Vec<SwapPath> = (0..10)
            .map(|i| {
                SwapPath::new(
                    vec![basic_token.clone(), Token::new(Address::repeat_byte(i)), basic_token.clone()],
                    vec![empty_pool(1), empty_pool(i + 2)],
                )
            })
            .collect();
        let disabled_paths = paths_vec[0].clone();
        let mut paths = SwapPaths::from(paths_vec);
        println!("Paths : {paths:?}");

        assert!(paths.disable_path(&disabled_paths, true));
        assert_eq!(paths.disabled_len(), 1);
        assert_eq!(paths.get_path_by_hash(disabled_paths.get_hash()).map(|p| p.disabled), Some(true));

        let pool_paths = paths.get_pool_paths_enabled_vec(&disabled_paths.pools[0].get_pool_id());
        println!("Pool paths : {pool_paths:?}");
        assert!(pool_paths.is_some());

        // Re-enabling clears the flag again.
        assert!(paths.disable_path(&disabled_paths, false));
        assert_eq!(paths.disabled_len(), 0);

        // A path that was never added cannot be disabled.
        let unknown = SwapPath::new(vec![basic_token.clone(), basic_token.clone()], vec![empty_pool(0xee)]);
        assert!(!paths.disable_path(&unknown, true));
    }
}