1use alloy_primitives::{Address, TxHash};
2use alloy_provider::ext::DebugApi;
3use alloy_provider::{Network, Provider};
4use alloy_rpc_types::{BlockId, TransactionRequest};
5use alloy_rpc_types_trace::common::TraceResult;
6use alloy_rpc_types_trace::geth::GethDebugBuiltInTracerType::PreStateTracer;
7use alloy_rpc_types_trace::geth::GethDebugTracerType::BuiltInTracer;
8use alloy_rpc_types_trace::geth::{
9 AccountState, GethDebugBuiltInTracerType, GethDebugTracerConfig, GethDebugTracerType, GethDebugTracingCallOptions,
10 GethDebugTracingOptions, GethDefaultTracingOptions, GethTrace, PreStateConfig, PreStateFrame,
11};
12use eyre::Result;
13use lazy_static::lazy_static;
14use std::collections::BTreeMap;
15use tracing::{debug, trace};
16
17use kabu_node_debug_provider::DebugProviderExt;
18
19pub type GethStateUpdate = BTreeMap<Address, AccountState>;
20
21pub type GethStateUpdateVec = Vec<BTreeMap<Address, AccountState>>;
22
23lazy_static! {
24 pub static ref TRACING_OPTS: GethDebugTracingOptions = GethDebugTracingOptions {
25 tracer: Some(GethDebugTracerType::BuiltInTracer(GethDebugBuiltInTracerType::PreStateTracer,)),
26 tracer_config: GethDebugTracerConfig::default(),
27 config: GethDefaultTracingOptions::default().disable_storage().disable_stack().disable_memory().disable_return_data(),
28 timeout: None,
29 };
30 pub static ref TRACING_CALL_OPTS: GethDebugTracingCallOptions =
31 GethDebugTracingCallOptions { tracing_options: TRACING_OPTS.clone(), state_overrides: None, block_overrides: None };
32}
33
34pub fn get_touched_addresses(state_update: &GethStateUpdate) -> Vec<Address> {
35 let mut ret: Vec<Address> = Vec::new();
36
37 for (address, state) in state_update.iter() {
38 if !state.storage.is_empty() {
39 ret.push(*address)
40 }
41 }
42
43 ret
44}
45
46pub fn debug_log_geth_state_update(state_update: &GethStateUpdate) {
47 for (address, state) in state_update {
48 debug!("{} nonce {:?} balance {:?} is_code {}", address, state.nonce, state.balance, state.code.is_some())
49 }
50}
51
52pub async fn debug_trace_block<N: Network, P: Provider<N> + DebugProviderExt<N>>(
53 client: P,
54 block_id: BlockId,
55 diff_mode: bool,
56) -> eyre::Result<(GethStateUpdateVec, GethStateUpdateVec)> {
57 let tracer_opts = GethDebugTracingOptions { config: GethDefaultTracingOptions::default(), ..GethDebugTracingOptions::default() }
58 .with_tracer(BuiltInTracer(PreStateTracer))
59 .with_prestate_config(PreStateConfig { diff_mode: Some(diff_mode), disable_code: Some(false), disable_storage: Some(false) });
60
61 let trace_result_vec = match block_id {
62 BlockId::Number(block_number) => client.geth_debug_trace_block_by_number(block_number, tracer_opts).await?,
63 BlockId::Hash(rpc_block_hash) => {
64 client.geth_debug_trace_block_by_hash(rpc_block_hash.block_hash, tracer_opts).await?
66 }
67 };
68
69 trace!("block trace {}", trace_result_vec.len());
70
71 let mut pre: GethStateUpdateVec = Default::default();
72 let mut post: GethStateUpdateVec = Default::default();
73
74 for trace_result in trace_result_vec.into_iter() {
75 if let TraceResult::Success { result, .. } = trace_result {
76 match result {
77 GethTrace::PreStateTracer(geth_trace_frame) => match geth_trace_frame {
78 PreStateFrame::Diff(diff_frame) => {
79 pre.push(diff_frame.pre);
80 post.push(diff_frame.post);
81 }
82 PreStateFrame::Default(diff_frame) => {
83 pre.push(diff_frame.0.into_iter().collect());
84 }
85 },
86 _ => {
87 return Err(eyre::eyre!("TRACE_RESULT_FAILED"));
88 }
89 }
90 }
91 }
92 Ok((pre, post))
93}
94
95async fn debug_trace_call<N: Network, C: DebugProviderExt<N>, TR: Into<N::TransactionRequest> + Send + Sync>(
96 client: C,
97 req: TR,
98 block: BlockId,
99 opts: Option<GethDebugTracingCallOptions>,
100 diff_mode: bool,
101) -> Result<(GethStateUpdate, GethStateUpdate)> {
102 let tracer_opts = GethDebugTracingOptions { config: GethDefaultTracingOptions::default(), ..GethDebugTracingOptions::default() }
103 .with_tracer(BuiltInTracer(PreStateTracer))
104 .with_prestate_config(PreStateConfig { diff_mode: Some(diff_mode), disable_code: Some(false), disable_storage: Some(false) });
105
106 let tracer_call_opts = GethDebugTracingCallOptions {
107 tracing_options: tracer_opts.clone(),
108 state_overrides: opts.clone().and_then(|x| x.state_overrides),
109 block_overrides: opts.and_then(|x| x.block_overrides),
110 };
111
112 let trace_result = client.geth_debug_trace_call(req.into(), block, tracer_call_opts.clone()).await?;
113 trace!(
114 "{} {} {:?} {:?}",
115 tracer_opts.config.is_stack_enabled(),
116 tracer_opts.config.is_storage_enabled(),
117 tracer_call_opts.clone(),
118 trace_result
119 );
120
121 match trace_result {
122 GethTrace::PreStateTracer(geth_trace_frame) => match geth_trace_frame {
123 PreStateFrame::Diff(diff_frame) => Ok((diff_frame.pre, diff_frame.post)),
124 PreStateFrame::Default(diff_frame) => Ok((diff_frame.0, Default::default())),
125 },
126 _ => Err(eyre::eyre!("TRACE_RESULT_FAILED")),
127 }
128}
129
130pub async fn debug_trace_call_pre_state<N: Network, C: DebugProviderExt<N>, TR: Into<N::TransactionRequest> + Send + Sync>(
131 client: C,
132 req: TR,
133 block: BlockId,
134 opts: Option<GethDebugTracingCallOptions>,
135) -> eyre::Result<GethStateUpdate> {
136 Ok(debug_trace_call(client, req, block, opts, false).await?.0)
137}
138
139pub async fn debug_trace_call_post_state<
140 N: Network<TransactionRequest = TransactionRequest>,
141 C: DebugProviderExt<N>,
142 TR: Into<TransactionRequest> + Send + Sync,
143>(
144 client: C,
145 req: TR,
146 block: BlockId,
147 opts: Option<GethDebugTracingCallOptions>,
148) -> eyre::Result<GethStateUpdate>
149where
150 <N as Network>::TransactionRequest: From<TR>,
151{
152 Ok(debug_trace_call(client, req, block, opts, true).await?.1)
153}
154
155pub async fn debug_trace_call_diff<N: Network, C: DebugProviderExt<N>, TR: Into<N::TransactionRequest> + Send + Sync>(
156 client: C,
157 req: TR,
158 block: BlockId,
159 call_opts: Option<GethDebugTracingCallOptions>,
160) -> eyre::Result<(GethStateUpdate, GethStateUpdate)> {
161 debug_trace_call(client, req, block, call_opts, true).await
162}
163
164pub async fn debug_trace_transaction<N: Network, P: Provider<N> + DebugApi<N>>(
165 client: P,
166 req: TxHash,
167 diff_mode: bool,
168) -> Result<(GethStateUpdate, GethStateUpdate)> {
169 let tracer_opts = GethDebugTracingOptions { config: GethDefaultTracingOptions::default(), ..GethDebugTracingOptions::default() }
170 .with_tracer(BuiltInTracer(PreStateTracer))
171 .with_prestate_config(PreStateConfig { diff_mode: Some(diff_mode), disable_code: Some(false), disable_storage: Some(false) });
172
173 let trace_result = client.debug_trace_transaction(req, tracer_opts).await?;
174 trace!("{:?}", trace_result);
175
176 match trace_result {
177 GethTrace::PreStateTracer(geth_trace_frame) => match geth_trace_frame {
178 PreStateFrame::Diff(diff_frame) => Ok((diff_frame.pre.into_iter().collect(), diff_frame.post.into_iter().collect())),
179 PreStateFrame::Default(diff_frame) => Ok((diff_frame.0.into_iter().collect(), Default::default())),
180 },
181 _ => Err(eyre::eyre!("TRACE_RESULT_FAILED")),
182 }
183}
184
#[cfg(test)]
mod test {
    use super::*;
    use alloy_primitives::map::B256HashMap;
    use alloy_primitives::{B256, U256};

    use alloy_provider::ProviderBuilder;
    use alloy_rpc_client::{ClientBuilder, WsConnect};
    use alloy_rpc_types::state::{AccountOverride, StateOverride};
    use env_logger::Env as EnvLog;
    use tracing::debug;

    /// Live-node smoke test: traces the latest block over WebSocket in diff mode.
    ///
    /// Requires the `MAINNET_WS` environment variable, optionally loaded from `.env.test`. The
    /// test returns an error when the variable is not set.
    #[tokio::test]
    async fn test_debug_block() -> Result<()> {
        dotenvy::from_filename(".env.test").ok();
        let node_url = std::env::var("MAINNET_WS")?;

        let _ = env_logger::try_init_from_env(EnvLog::default().default_filter_or("info,tokio_tungstenite=off,tungstenite=off"));
        let node_url = url::Url::parse(node_url.as_str())?;

        let ws_connect = WsConnect::new(node_url);
        let client = ClientBuilder::default().ws(ws_connect).await?;

        let client = ProviderBuilder::new().disable_recommended_fillers().connect_client(client);

        let blocknumber = client.get_block_number().await?;
        // Panics if the node does not return the block it just reported as latest.
        let _block = client.get_block_by_number(blocknumber.into()).await?.unwrap();

        let _ret = debug_trace_block(client, BlockId::Number(blocknumber.into()), true).await?;

        Ok(())
    }

    /// Checks that a state override carrying a `state_diff` serializes, and that the diff is
    /// emitted under the camelCase `stateDiff` key.
    #[test]
    fn test_encode_override() {
        let mut state_override: StateOverride = StateOverride::default();
        let address = Address::default();
        let mut account_override: AccountOverride = AccountOverride::default();
        let mut state_update_hashmap: B256HashMap<B256> = B256HashMap::default();
        state_update_hashmap.insert(B256::from(U256::from(1)), B256::from(U256::from(3)));
        account_override.state_diff = Some(state_update_hashmap);

        state_override.insert(address, account_override);

        let data = serde_json::to_string_pretty(&state_override).expect("SERIALIZATION_ERROR");
        debug!("{}", data);
        assert!(data.contains("\"stateDiff\""), "state_diff was not serialized as stateDiff: {data}");
    }
}