1use crate::error::KabuDBError;
2use alloy::eips::BlockId;
3use alloy::network::primitives::HeaderResponse;
4use alloy::providers::{network::BlockResponse, Network, Provider};
5use revm::state::{AccountInfo, Bytecode};
6use revm::{
7 primitives::{Address, B256, U256},
8 Database, DatabaseRef,
9};
10use std::future::IntoFuture;
11use tokio::runtime::{Handle, Runtime};
12
13#[derive(Debug)]
14pub(crate) enum HandleOrRuntime {
15 Handle(Handle),
16 Runtime(Runtime),
17}
18
19impl HandleOrRuntime {
20 #[inline]
21 pub(crate) fn block_on<F>(&self, f: F) -> F::Output
22 where
23 F: std::future::Future + Send,
24 F::Output: Send,
25 {
26 match self {
27 Self::Handle(handle) => tokio::task::block_in_place(move || handle.block_on(f)),
28 Self::Runtime(rt) => rt.block_on(f),
29 }
30 }
31}
32
/// Synchronous revm database that reads account, storage and block data from an alloy
/// [`Provider`].
///
/// Every query blocks on an async provider call through the wrapped [`HandleOrRuntime`].
#[derive(Debug)]
pub struct AlloyDB<N: Network, P: Provider<N>> {
    /// Provider used for all RPC requests.
    provider: P,
    /// Block at which account (`basic_ref`) and storage (`storage_ref`) queries are made.
    block_number: BlockId,
    /// Runtime, or handle to one, used to block on provider futures.
    rt: HandleOrRuntime,
    /// Ties the network type `N` to the struct without storing an `N`.
    _marker: std::marker::PhantomData<fn() -> N>,
}
46
47#[allow(dead_code)]
59impl<N: Network, P: Provider<N>> AlloyDB<N, P> {
60 pub fn new(provider: P, block_number: BlockId) -> Option<Self> {
64 let rt = match Handle::try_current() {
65 Ok(handle) => match handle.runtime_flavor() {
66 tokio::runtime::RuntimeFlavor::CurrentThread => return None,
67 _ => HandleOrRuntime::Handle(handle),
68 },
69 Err(_) => return None,
70 };
71 Some(Self { provider, block_number, rt, _marker: std::marker::PhantomData })
72 }
73
74 pub fn with_runtime(provider: P, block_number: BlockId, runtime: Runtime) -> Self {
79 let rt = HandleOrRuntime::Runtime(runtime);
80 Self { provider, block_number, rt, _marker: std::marker::PhantomData }
81 }
82
83 pub fn with_handle(provider: P, block_number: BlockId, handle: Handle) -> Self {
88 let rt = HandleOrRuntime::Handle(handle);
89 Self { provider, block_number, rt, _marker: std::marker::PhantomData }
90 }
91
92 #[inline]
94 fn block_on<F>(&self, f: F) -> F::Output
95 where
96 F: std::future::Future + Send,
97 F::Output: Send,
98 {
99 self.rt.block_on(f)
100 }
101
102 pub fn set_block_number(&mut self, block_number: BlockId) {
104 self.block_number = block_number;
105 }
106}
107
108impl<N: Network, P: Provider<N>> DatabaseRef for AlloyDB<N, P> {
109 type Error = KabuDBError;
110
111 fn basic_ref(&self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
112 let f = async {
113 let nonce = self.provider.get_transaction_count(address).block_id(self.block_number);
114 let balance = self.provider.get_balance(address).block_id(self.block_number);
115 let code = self.provider.get_code_at(address).block_id(self.block_number);
116 tokio::join!(nonce.into_future(), balance.into_future(), code.into_future())
117 };
118
119 let (nonce, balance, code) = self.block_on(f);
120
121 let balance = balance.map_err(|_| KabuDBError::TransportError)?;
122 let code = Bytecode::new_raw(code.map_err(|_| KabuDBError::TransportError)?.0.into());
123 let code_hash = code.hash_slow();
124 let nonce = nonce.map_err(|_| KabuDBError::TransportError)?;
125
126 Ok(Some(AccountInfo::new(balance, nonce, code_hash, code)))
127 }
128
129 fn code_by_hash_ref(&self, _code_hash: B256) -> Result<Bytecode, Self::Error> {
130 panic!("This should not be called, as the code is already loaded");
131 }
133
134 fn storage_ref(&self, address: Address, index: U256) -> Result<U256, Self::Error> {
135 let f = self.provider.get_storage_at(address, index).block_id(self.block_number);
136 let slot_val = self.block_on(f.into_future()).map_err(|_| KabuDBError::TransportError)?;
137 Ok(slot_val)
138 }
139
140 fn block_hash_ref(&self, number: u64) -> Result<B256, Self::Error> {
141 let block = self
142 .block_on(
143 self.provider
144 .get_block_by_number(number.into())
146 .into_future(),
147 )
148 .map_err(|_| KabuDBError::TransportError)?;
149 Ok(B256::new(*block.unwrap().header().hash()))
151 }
152}
153
154impl<N: Network, P: Provider<N>> Database for AlloyDB<N, P> {
155 type Error = KabuDBError;
156
157 #[inline]
158 fn basic(&mut self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
159 <Self as DatabaseRef>::basic_ref(self, address)
160 }
161
162 #[inline]
163 fn code_by_hash(&mut self, code_hash: B256) -> Result<Bytecode, Self::Error> {
164 <Self as DatabaseRef>::code_by_hash_ref(self, code_hash)
165 }
166
167 #[inline]
168 fn storage(&mut self, address: Address, index: U256) -> Result<U256, Self::Error> {
169 <Self as DatabaseRef>::storage_ref(self, address, index)
170 }
171
172 #[inline]
173 fn block_hash(&mut self, number: u64) -> Result<B256, Self::Error> {
174 <Self as DatabaseRef>::block_hash_ref(self, number)
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use alloy::providers::ProviderBuilder;
    use std::env;
    use url::Url;

    /// Fetches a known mainnet account through a live RPC node (`MAINNET_HTTP`).
    #[test]
    #[ignore = "flaky RPC"]
    fn can_get_basic() {
        dotenvy::from_filename(".env.test").ok();
        let node_url =
            env::var("MAINNET_HTTP").expect("MAINNET_HTTP must be set (e.g. in .env.test)");
        let node_url = Url::parse(node_url.as_str()).expect("MAINNET_HTTP must be a valid URL");

        let client = ProviderBuilder::new().connect_http(node_url);
        // A plain `#[test]` runs outside any Tokio runtime, so `AlloyDB::new` would always
        // return `None` here; give the database its own runtime instead.
        let runtime = Runtime::new().expect("failed to build Tokio runtime");
        let alloydb = AlloyDB::with_runtime(client, BlockId::from(16148323), runtime);

        let address: Address = "0x0d4a11d5EEaaC28EC3F61d100daF4d40471f1852".parse().unwrap();

        let acc_info = alloydb.basic_ref(address).unwrap().unwrap();
        assert!(acc_info.exists());
    }
}