// kabu_types_entities/swap_path.rs

1use crate::{PoolId, PoolWrapper, SwapDirection, Token};
2use alloy_primitives::{map::HashMap, Address};
3use eyre::Result;
4use std::fmt;
5use std::fmt::Display;
6use std::hash::{DefaultHasher, Hash, Hasher};
7use std::sync::Arc;
8use tracing::debug;
9
10#[derive(Clone, Debug)]
11pub struct SwapPath {
12    pub tokens: Vec<Arc<Token>>,
13    pub pools: Vec<PoolWrapper>,
14    pub disabled: bool,
15    pub disabled_pool: Vec<PoolId>,
16    pub score: Option<f64>,
17}
18
19impl Display for SwapPath {
20    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21        let tokens = self.tokens.iter().map(|token| token.get_symbol()).collect::<Vec<String>>().join(", ");
22        let pools =
23            self.pools.iter().map(|pool| format!("{}@{}", pool.get_protocol(), pool.get_pool_id())).collect::<Vec<String>>().join(", ");
24
25        write!(f, "SwapPath [tokens=[{}], pools=[{}] disabled={}]", tokens, pools, self.disabled)
26    }
27}
28
29impl Default for SwapPath {
30    #[inline]
31    fn default() -> Self {
32        SwapPath { tokens: Vec::new(), pools: Vec::new(), disabled: false, disabled_pool: Default::default(), score: None }
33    }
34}
35
36impl PartialEq for SwapPath {
37    fn eq(&self, other: &Self) -> bool {
38        self.tokens == other.tokens && self.pools == other.pools
39    }
40}
41
42impl Eq for SwapPath {}
43
44impl Hash for SwapPath {
45    fn hash<H: Hasher>(&self, state: &mut H) {
46        self.tokens.hash(state);
47        self.pools.hash(state);
48    }
49}
50
51impl SwapPath {
52    #[inline]
53    pub fn new<T: Into<Arc<Token>>, P: Into<PoolWrapper>>(tokens: Vec<T>, pools: Vec<P>) -> Self {
54        SwapPath {
55            tokens: tokens.into_iter().map(|i| i.into()).collect(),
56            pools: pools.into_iter().map(|i| i.into()).collect(),
57            disabled: false,
58            disabled_pool: Default::default(),
59            score: None,
60        }
61    }
62
63    #[inline]
64    pub fn is_emply(&self) -> bool {
65        self.tokens.is_empty() && self.pools.is_empty()
66    }
67
68    #[inline]
69    pub fn tokens_count(&self) -> usize {
70        self.tokens.len()
71    }
72
73    #[inline]
74    pub fn pool_count(&self) -> usize {
75        self.pools.len()
76    }
77
78    #[inline]
79    pub fn new_swap(token_from: Arc<Token>, token_to: Arc<Token>, pool: PoolWrapper) -> Self {
80        SwapPath { tokens: vec![token_from, token_to], pools: vec![pool], disabled: false, disabled_pool: Default::default(), score: None }
81    }
82
83    #[inline]
84    pub fn push_swap_hope(&mut self, token_from: Arc<Token>, token_to: Arc<Token>, pool: PoolWrapper) -> Result<&mut Self> {
85        if self.is_emply() {
86            self.tokens = vec![token_from, token_to];
87            self.pools = vec![pool];
88        } else {
89            if token_from.as_ref() != self.tokens.last().map_or(&Token::zero(), |t| t.as_ref()) {
90                return Err(eyre::eyre!("NEW_SWAP_NOT_CONNECTED"));
91            }
92            self.tokens.push(token_to);
93            self.pools.push(pool);
94        }
95        Ok(self)
96    }
97
98    #[inline]
99    pub fn insert_swap_hope(&mut self, token_from: Arc<Token>, token_to: Arc<Token>, pool: PoolWrapper) -> Result<&mut Self> {
100        if self.is_emply() {
101            self.tokens = vec![token_from, token_to];
102            self.pools = vec![pool];
103        } else {
104            if token_to.as_ref() != self.tokens.first().map_or(&Token::zero(), |t| t.as_ref()) {
105                return Err(eyre::eyre!("NEW_SWAP_NOT_CONNECTED"));
106            }
107            self.tokens.insert(0, token_from);
108            self.pools.insert(0, pool);
109        }
110
111        Ok(self)
112    }
113
114    #[inline]
115    pub fn contains_pool(&self, pool: &PoolWrapper) -> bool {
116        for p in self.pools.iter() {
117            if p == pool {
118                return true;
119            }
120        }
121        false
122    }
123
124    #[inline]
125    pub fn get_hash(&self) -> u64 {
126        let mut h = DefaultHasher::new();
127        self.hash(&mut h);
128        h.finish()
129    }
130}
131
132#[derive(Clone, Debug, Default)]
133pub struct SwapPaths {
134    pub paths: Vec<SwapPath>,
135    pub pool_paths: HashMap<PoolId, Vec<usize>>,
136    pub path_hash_map: HashMap<u64, usize>,
137    pub disabled_directions: HashMap<u64, bool>,
138}
139
140impl SwapPaths {
141    pub fn new() -> SwapPaths {
142        SwapPaths {
143            paths: Vec::new(),
144            pool_paths: HashMap::default(),
145            path_hash_map: HashMap::default(),
146            disabled_directions: HashMap::default(),
147        }
148    }
149    pub fn from(paths: Vec<SwapPath>) -> Self {
150        let mut swap_paths_ret = SwapPaths::new();
151        for p in paths {
152            swap_paths_ret.add(p);
153        }
154        swap_paths_ret
155    }
156
157    pub fn len(&self) -> usize {
158        self.paths.len()
159    }
160
161    pub fn disabled_len(&self) -> usize {
162        self.paths.iter().filter(|p| p.disabled).count()
163    }
164
165    pub fn is_empty(&self) -> bool {
166        self.paths.is_empty()
167    }
168
169    pub fn len_max(&self) -> usize {
170        self.pool_paths.values().map(|item| item.len()).max().unwrap_or_default()
171    }
172
173    #[inline]
174    pub fn add(&mut self, path: SwapPath) -> Option<usize> {
175        let path_hash = path.get_hash();
176        let path_idx = self.paths.len();
177
178        match self.path_hash_map.entry(path_hash) {
179            std::collections::hash_map::Entry::Occupied(_) => {
180                //debug!("Path already exists hash={}, path={}", path.get_hash(), path);
181                None
182            }
183            std::collections::hash_map::Entry::Vacant(e) => {
184                //debug!("Path added hash={}, path={}", path.get_hash(), path);
185                e.insert(path_idx);
186
187                for pool in &path.pools {
188                    self.pool_paths.entry(pool.get_pool_id()).or_default().push(path_idx);
189                }
190
191                self.paths.push(path);
192                Some(path_idx)
193            }
194        }
195    }
196
197    pub fn disable_path(&mut self, swap_path: &SwapPath, disable: bool) -> bool {
198        if let Some(swap_path_idx) = self.path_hash_map.get(&swap_path.get_hash()) {
199            if let Some(swap_path) = self.paths.get_mut(*swap_path_idx) {
200                debug!("Path disabled hash={}, path={}", swap_path.get_hash(), swap_path);
201                swap_path.disabled = disable;
202                return true;
203            }
204        }
205        debug!("Path not disabled hash={}, path={}", swap_path.get_hash(), swap_path);
206        false
207    }
208
209    pub fn disable_pool_paths(&mut self, pool_id: &PoolId, token_from_address: &Address, token_to_address: &Address, disabled: bool) {
210        let Some(pool_paths) = self.pool_paths.get(pool_id).cloned() else { return };
211
212        for path_idx in pool_paths.iter() {
213            if let Some(entry) = self.paths.get_mut(*path_idx) {
214                if let Some(idx) = entry.pools.iter().position(|item| item.get_pool_id().eq(pool_id)) {
215                    if let (Some(token_from), Some(token_to)) = (entry.tokens.get(idx), entry.tokens.get(idx + 1)) {
216                        if token_from.get_address().eq(token_from_address) && token_to.get_address().eq(token_to_address) {
217                            entry.disabled = disabled;
218                            if !entry.disabled_pool.contains(pool_id) {
219                                entry.disabled_pool.push(*pool_id);
220                            }
221                            self.disabled_directions
222                                .insert(SwapDirection::new(*token_from_address, *token_to_address).get_hash_with_pool(pool_id), disabled);
223                        }
224                    }
225                } else {
226                    //debug!("All path disabled by pool hash={}, path={}", entry.get_hash(), entry);
227                    entry.disabled = disabled;
228                    if !entry.disabled_pool.contains(pool_id) {
229                        entry.disabled_pool.push(*pool_id);
230                    }
231                }
232            }
233        }
234    }
235    #[inline]
236    pub fn get_pool_paths_enabled_vec(&self, pool_id: &PoolId) -> Option<Vec<SwapPath>> {
237        let paths = self.pool_paths.get(pool_id)?;
238        let paths_vec_ret: Vec<SwapPath> = paths
239            .iter()
240            .filter_map(|a| {
241                self.paths
242                    .get(*a)
243                    .filter(|a| a.disabled_pool.is_empty() || (a.disabled_pool.len() == 1 && a.disabled_pool.contains(pool_id)))
244            })
245            .cloned()
246            .collect();
247        (!paths_vec_ret.is_empty()).then_some(paths_vec_ret)
248    }
249
250    #[inline]
251    pub fn get_path_by_idx(&self, idx: usize) -> Option<&SwapPath> {
252        self.paths.get(idx)
253    }
254
255    #[inline]
256    pub fn get_path_by_idx_mut(&mut self, idx: usize) -> Option<&mut SwapPath> {
257        self.paths.get_mut(idx)
258    }
259
260    #[inline]
261    pub fn get_path_by_hash(&self, idx: u64) -> Option<&SwapPath> {
262        self.path_hash_map.get(&idx).and_then(|i| self.paths.get(*i))
263    }
264}
265
#[cfg(test)]
mod test {
    use super::*;
    use crate::pool::DefaultAbiSwapEncoder;
    use crate::required_state::RequiredState;
    use crate::{Pool, PoolAbiEncoder, PoolClass, PoolProtocol, PreswapRequirement, SwapDirection};
    use alloy_primitives::{Address, U256};
    use eyre::{eyre, ErrReport};
    use kabu_evm_db::KabuDBError;
    use revm::DatabaseRef;
    use std::any::Any;
    use tokio::task::JoinHandle;

    /// Minimal `Pool` stub identified only by its address; amount calculations always fail.
    #[derive(Clone)]
    pub struct EmptyPool {
        address: Address,
    }

    impl EmptyPool {
        pub fn new(address: Address) -> Self {
            EmptyPool { address }
        }
    }

    impl Pool for EmptyPool {
        fn as_any<'a>(&self) -> &dyn Any {
            self
        }

        fn is_native(&self) -> bool {
            false
        }
        fn get_address(&self) -> PoolId {
            PoolId::Address(self.address)
        }

        fn get_pool_id(&self) -> PoolId {
            PoolId::Address(self.address)
        }

        fn calculate_out_amount(
            &self,
            _db: &dyn DatabaseRef<Error = KabuDBError>,
            _token_address_from: &Address,
            _token_address_to: &Address,
            _in_amount: U256,
        ) -> Result<(U256, u64), ErrReport> {
            Err(eyre!("NOT_IMPLEMENTED"))
        }

        fn calculate_in_amount(
            &self,
            _db: &dyn DatabaseRef<Error = KabuDBError>,
            _token_address_from: &Address,
            _token_address_to: &Address,
            _out_amount: U256,
        ) -> eyre::Result<(U256, u64), ErrReport> {
            Err(eyre!("NOT_IMPLEMENTED"))
        }

        fn can_flash_swap(&self) -> bool {
            false
        }

        fn get_abi_encoder(&self) -> Option<&dyn PoolAbiEncoder> {
            Some(&DefaultAbiSwapEncoder {})
        }

        fn get_state_required(&self) -> Result<RequiredState> {
            Ok(RequiredState::new())
        }

        fn get_class(&self) -> PoolClass {
            PoolClass::Unknown
        }

        fn get_protocol(&self) -> PoolProtocol {
            PoolProtocol::Unknown
        }

        fn get_fee(&self) -> U256 {
            U256::ZERO
        }

        fn get_tokens(&self) -> Vec<Address> {
            vec![]
        }

        fn get_swap_directions(&self) -> Vec<SwapDirection> {
            vec![]
        }

        fn can_calculate_in_amount(&self) -> bool {
            true
        }

        fn get_read_only_cell_vec(&self) -> Vec<U256> {
            vec![]
        }

        fn preswap_requirement(&self) -> PreswapRequirement {
            PreswapRequirement::Base
        }
    }

    /// Wraps an `EmptyPool` whose address is `byte` repeated.
    fn empty_pool(byte: u8) -> PoolWrapper {
        PoolWrapper::new(Arc::new(EmptyPool::new(Address::repeat_byte(byte))))
    }

    #[test]
    fn test_add_path() {
        let basic_token = Token::new(Address::repeat_byte(0x11));

        let paths_vec: Vec<SwapPath> = (0..10u8)
            .map(|i| {
                SwapPath::new(
                    vec![basic_token.clone(), Token::new(Address::repeat_byte(i)), basic_token.clone()],
                    vec![empty_pool(i + 1), empty_pool(i + 2)],
                )
            })
            .collect();
        let paths = SwapPaths::from(paths_vec);

        println!("{paths:?}");

        // All ten paths use distinct pool pairs, so every one of them is stored.
        assert_eq!(paths.len(), 10);
        assert_eq!(paths.disabled_len(), 0);
        // Pools 0x02..=0x0a are shared by two neighbouring paths; 0x01 and 0x0b by one.
        assert_eq!(paths.len_max(), 2);
        assert_eq!(paths.pool_paths.get(&PoolId::Address(Address::repeat_byte(1))), Some(&vec![0]));
        assert_eq!(paths.pool_paths.get(&PoolId::Address(Address::repeat_byte(2))), Some(&vec![0, 1]));
    }

    #[tokio::test]
    async fn async_test() {
        let basic_token = Token::new(Address::repeat_byte(0x11));

        const PATHS_COUNT: usize = 10;

        let pool_address_vec: Vec<(PoolWrapper, PoolWrapper)> =
            (0..PATHS_COUNT).map(|i| (empty_pool(i as u8), empty_pool((i + 1) as u8))).collect();

        let paths_vec: Vec<SwapPath> = pool_address_vec
            .iter()
            .map(|p| {
                SwapPath::new(
                    vec![basic_token.clone(), Token::new(Address::repeat_byte(1)), basic_token.clone()],
                    vec![p.0.clone(), p.1.clone()],
                )
            })
            .collect();

        let mut paths = SwapPaths::from(paths_vec.clone());
        // Re-adding identical paths must be rejected as duplicates.
        for path in paths_vec {
            assert!(paths.add(path).is_none());
        }
        assert_eq!(paths.len(), PATHS_COUNT);

        let paths_shared = Arc::new(tokio::sync::RwLock::new(paths));

        let mut tasks: Vec<JoinHandle<_>> = Vec::new();

        for (i, pools) in pool_address_vec.into_iter().enumerate() {
            let pool_id = pools.0.get_pool_id();
            let paths_shared_clone = paths_shared.clone();
            tasks.push(tokio::task::spawn(async move {
                let address = match pool_id {
                    PoolId::Address(addr) => addr,
                    PoolId::B256(_) => Address::default(),
                };
                let pool = PoolWrapper::new(Arc::new(EmptyPool::new(address)));
                let path_guard = paths_shared_clone.read().await;
                let pool_paths = path_guard.get_pool_paths_enabled_vec(&pool.get_pool_id());
                println!("{i} {pool_id}: {pool_paths:?}");
                // Nothing is disabled, so every pool resolves to at least one path.
                assert!(pool_paths.is_some(), "no enabled paths for pool {pool_id}");
            }));
        }

        // Propagate task panics so a failed in-task assertion fails the test.
        for t in tasks {
            t.await.expect("pool paths task panicked");
        }
    }

    #[test]
    fn test_disable_path() {
        let basic_token = Token::new(Address::repeat_byte(0x11));

        let paths_vec: Vec<SwapPath> = (0..10u8)
            .map(|i| {
                SwapPath::new(
                    vec![basic_token.clone(), Token::new(Address::repeat_byte(i)), basic_token.clone()],
                    vec![empty_pool(1), empty_pool(i + 2)],
                )
            })
            .collect();
        let disabled_paths = paths_vec[0].clone();
        let mut paths = SwapPaths::from(paths_vec);
        println!("Paths : {paths:?}");

        assert!(paths.disable_path(&disabled_paths, true));
        assert_eq!(paths.get_path_by_hash(disabled_paths.get_hash()).map(|p| p.disabled), Some(true));
        assert_eq!(paths.disabled_len(), 1);

        let pool_paths = paths.get_pool_paths_enabled_vec(&disabled_paths.pools[0].get_pool_id());

        println!("Pool paths : {pool_paths:?}");
        // Pool 0x01 is shared by every path.
        assert!(pool_paths.is_some());
    }

    #[test]
    fn test_disable_pool_paths() {
        let basic_address = Address::repeat_byte(0x11);
        let basic_token = Token::new(basic_address);

        let paths_vec: Vec<SwapPath> = (0..10u8)
            .map(|i| {
                SwapPath::new(
                    vec![basic_token.clone(), Token::new(Address::repeat_byte(i)), basic_token.clone()],
                    vec![empty_pool(1), empty_pool(i + 2)],
                )
            })
            .collect();
        let mut paths = SwapPaths::from(paths_vec);

        let shared_pool_id = PoolId::Address(Address::repeat_byte(1));
        paths.disable_pool_paths(&shared_pool_id, &basic_address, &Address::repeat_byte(3), true);

        // Only path 3 swaps basic -> token(0x03) through the shared pool.
        assert_eq!(paths.disabled_len(), 1);
        assert_eq!(paths.get_path_by_idx(3).map(|p| p.disabled_pool.clone()), Some(vec![shared_pool_id]));
        // The path is still reported for the pool that disabled it ...
        assert_eq!(paths.get_pool_paths_enabled_vec(&shared_pool_id).map(|v| v.len()), Some(10));
        // ... but hidden from its other pool (0x05), which no other path uses.
        assert!(paths.get_pool_paths_enabled_vec(&PoolId::Address(Address::repeat_byte(5))).is_none());
    }
}