kabu_types_blockchain/
state_update.rs

1use alloy_primitives::{Address, TxHash};
2use alloy_provider::ext::DebugApi;
3use alloy_provider::{Network, Provider};
4use alloy_rpc_types::{BlockId, TransactionRequest};
5use alloy_rpc_types_trace::common::TraceResult;
6use alloy_rpc_types_trace::geth::GethDebugBuiltInTracerType::PreStateTracer;
7use alloy_rpc_types_trace::geth::GethDebugTracerType::BuiltInTracer;
8use alloy_rpc_types_trace::geth::{
9    AccountState, GethDebugBuiltInTracerType, GethDebugTracerConfig, GethDebugTracerType, GethDebugTracingCallOptions,
10    GethDebugTracingOptions, GethDefaultTracingOptions, GethTrace, PreStateConfig, PreStateFrame,
11};
12use eyre::Result;
13use lazy_static::lazy_static;
14use std::collections::BTreeMap;
15use tracing::{debug, trace};
16
17use kabu_node_debug_provider::DebugProviderExt;
18
19pub type GethStateUpdate = BTreeMap<Address, AccountState>;
20
21pub type GethStateUpdateVec = Vec<BTreeMap<Address, AccountState>>;
22
23lazy_static! {
24    pub static ref TRACING_OPTS: GethDebugTracingOptions = GethDebugTracingOptions {
25        tracer: Some(GethDebugTracerType::BuiltInTracer(GethDebugBuiltInTracerType::PreStateTracer,)),
26        tracer_config: GethDebugTracerConfig::default(),
27        config: GethDefaultTracingOptions::default().disable_storage().disable_stack().disable_memory().disable_return_data(),
28        timeout: None,
29    };
30    pub static ref TRACING_CALL_OPTS: GethDebugTracingCallOptions =
31        GethDebugTracingCallOptions { tracing_options: TRACING_OPTS.clone(), state_overrides: None, block_overrides: None };
32}
33
34pub fn get_touched_addresses(state_update: &GethStateUpdate) -> Vec<Address> {
35    let mut ret: Vec<Address> = Vec::new();
36
37    for (address, state) in state_update.iter() {
38        if !state.storage.is_empty() {
39            ret.push(*address)
40        }
41    }
42
43    ret
44}
45
46pub fn debug_log_geth_state_update(state_update: &GethStateUpdate) {
47    for (address, state) in state_update {
48        debug!("{} nonce {:?} balance {:?} is_code {}", address, state.nonce, state.balance, state.code.is_some())
49    }
50}
51
52pub async fn debug_trace_block<N: Network, P: Provider<N> + DebugProviderExt<N>>(
53    client: P,
54    block_id: BlockId,
55    diff_mode: bool,
56) -> eyre::Result<(GethStateUpdateVec, GethStateUpdateVec)> {
57    let tracer_opts = GethDebugTracingOptions { config: GethDefaultTracingOptions::default(), ..GethDebugTracingOptions::default() }
58        .with_tracer(BuiltInTracer(PreStateTracer))
59        .with_prestate_config(PreStateConfig { diff_mode: Some(diff_mode), disable_code: Some(false), disable_storage: Some(false) });
60
61    let trace_result_vec = match block_id {
62        BlockId::Number(block_number) => client.geth_debug_trace_block_by_number(block_number, tracer_opts).await?,
63        BlockId::Hash(rpc_block_hash) => {
64            //client.debug_trace_block_by_number(BlockNumber::from(19776525u32), tracer_opts).await?
65            client.geth_debug_trace_block_by_hash(rpc_block_hash.block_hash, tracer_opts).await?
66        }
67    };
68
69    trace!("block trace {}", trace_result_vec.len());
70
71    let mut pre: GethStateUpdateVec = Default::default();
72    let mut post: GethStateUpdateVec = Default::default();
73
74    for trace_result in trace_result_vec.into_iter() {
75        if let TraceResult::Success { result, .. } = trace_result {
76            match result {
77                GethTrace::PreStateTracer(geth_trace_frame) => match geth_trace_frame {
78                    PreStateFrame::Diff(diff_frame) => {
79                        pre.push(diff_frame.pre);
80                        post.push(diff_frame.post);
81                    }
82                    PreStateFrame::Default(diff_frame) => {
83                        pre.push(diff_frame.0.into_iter().collect());
84                    }
85                },
86                _ => {
87                    return Err(eyre::eyre!("TRACE_RESULT_FAILED"));
88                }
89            }
90        }
91    }
92    Ok((pre, post))
93}
94
95async fn debug_trace_call<N: Network, C: DebugProviderExt<N>, TR: Into<N::TransactionRequest> + Send + Sync>(
96    client: C,
97    req: TR,
98    block: BlockId,
99    opts: Option<GethDebugTracingCallOptions>,
100    diff_mode: bool,
101) -> Result<(GethStateUpdate, GethStateUpdate)> {
102    let tracer_opts = GethDebugTracingOptions { config: GethDefaultTracingOptions::default(), ..GethDebugTracingOptions::default() }
103        .with_tracer(BuiltInTracer(PreStateTracer))
104        .with_prestate_config(PreStateConfig { diff_mode: Some(diff_mode), disable_code: Some(false), disable_storage: Some(false) });
105
106    let tracer_call_opts = GethDebugTracingCallOptions {
107        tracing_options: tracer_opts.clone(),
108        state_overrides: opts.clone().and_then(|x| x.state_overrides),
109        block_overrides: opts.and_then(|x| x.block_overrides),
110    };
111
112    let trace_result = client.geth_debug_trace_call(req.into(), block, tracer_call_opts.clone()).await?;
113    trace!(
114        "{} {} {:?} {:?}",
115        tracer_opts.config.is_stack_enabled(),
116        tracer_opts.config.is_storage_enabled(),
117        tracer_call_opts.clone(),
118        trace_result
119    );
120
121    match trace_result {
122        GethTrace::PreStateTracer(geth_trace_frame) => match geth_trace_frame {
123            PreStateFrame::Diff(diff_frame) => Ok((diff_frame.pre, diff_frame.post)),
124            PreStateFrame::Default(diff_frame) => Ok((diff_frame.0, Default::default())),
125        },
126        _ => Err(eyre::eyre!("TRACE_RESULT_FAILED")),
127    }
128}
129
130pub async fn debug_trace_call_pre_state<N: Network, C: DebugProviderExt<N>, TR: Into<N::TransactionRequest> + Send + Sync>(
131    client: C,
132    req: TR,
133    block: BlockId,
134    opts: Option<GethDebugTracingCallOptions>,
135) -> eyre::Result<GethStateUpdate> {
136    Ok(debug_trace_call(client, req, block, opts, false).await?.0)
137}
138
139pub async fn debug_trace_call_post_state<
140    N: Network<TransactionRequest = TransactionRequest>,
141    C: DebugProviderExt<N>,
142    TR: Into<TransactionRequest> + Send + Sync,
143>(
144    client: C,
145    req: TR,
146    block: BlockId,
147    opts: Option<GethDebugTracingCallOptions>,
148) -> eyre::Result<GethStateUpdate>
149where
150    <N as Network>::TransactionRequest: From<TR>,
151{
152    Ok(debug_trace_call(client, req, block, opts, true).await?.1)
153}
154
155pub async fn debug_trace_call_diff<N: Network, C: DebugProviderExt<N>, TR: Into<N::TransactionRequest> + Send + Sync>(
156    client: C,
157    req: TR,
158    block: BlockId,
159    call_opts: Option<GethDebugTracingCallOptions>,
160) -> eyre::Result<(GethStateUpdate, GethStateUpdate)> {
161    debug_trace_call(client, req, block, call_opts, true).await
162}
163
164pub async fn debug_trace_transaction<N: Network, P: Provider<N> + DebugApi<N>>(
165    client: P,
166    req: TxHash,
167    diff_mode: bool,
168) -> Result<(GethStateUpdate, GethStateUpdate)> {
169    let tracer_opts = GethDebugTracingOptions { config: GethDefaultTracingOptions::default(), ..GethDebugTracingOptions::default() }
170        .with_tracer(BuiltInTracer(PreStateTracer))
171        .with_prestate_config(PreStateConfig { diff_mode: Some(diff_mode), disable_code: Some(false), disable_storage: Some(false) });
172
173    let trace_result = client.debug_trace_transaction(req, tracer_opts).await?;
174    trace!("{:?}", trace_result);
175
176    match trace_result {
177        GethTrace::PreStateTracer(geth_trace_frame) => match geth_trace_frame {
178            PreStateFrame::Diff(diff_frame) => Ok((diff_frame.pre.into_iter().collect(), diff_frame.post.into_iter().collect())),
179            PreStateFrame::Default(diff_frame) => Ok((diff_frame.0.into_iter().collect(), Default::default())),
180        },
181        _ => Err(eyre::eyre!("TRACE_RESULT_FAILED")),
182    }
183}
184
#[cfg(test)]
mod test {
    use super::*;
    use alloy_primitives::map::B256HashMap;
    use alloy_primitives::{B256, U256};

    use alloy_provider::ProviderBuilder;
    use alloy_rpc_client::{ClientBuilder, WsConnect};
    use alloy_rpc_types::state::{AccountOverride, StateOverride};
    use env_logger::Env as EnvLog;
    use tracing::debug;

    /// Network integration test: traces the latest block of the node at `MAINNET_WS` in diff mode.
    ///
    /// `MAINNET_WS` may be supplied via `.env.test`; the test returns an error when it is unset.
    #[tokio::test]
    async fn test_debug_block() -> Result<()> {
        dotenvy::from_filename(".env.test").ok();
        let node_url = std::env::var("MAINNET_WS")?;

        let _ = env_logger::try_init_from_env(EnvLog::default().default_filter_or("info,tokio_tungstenite=off,tungstenite=off"));
        let node_url = url::Url::parse(node_url.as_str())?;

        let ws_connect = WsConnect::new(node_url);
        let client = ClientBuilder::default().ws(ws_connect).await?;

        let client = ProviderBuilder::new().disable_recommended_fillers().connect_client(client);

        let blocknumber = client.get_block_number().await?;
        // Panics if the node does not return the block it just reported as latest.
        let _block = client.get_block_by_number(blocknumber.into()).await?.unwrap();

        let _ret = debug_trace_block(client, BlockId::Number(blocknumber.into()), true).await?;

        Ok(())
    }

    /// A `StateOverride` carrying a storage `state_diff` must survive a JSON round trip.
    #[test]
    fn test_encode_override() -> Result<()> {
        let mut state_diff: B256HashMap<B256> = B256HashMap::default();
        state_diff.insert(B256::from(U256::from(1)), B256::from(U256::from(3)));

        let mut account_override: AccountOverride = AccountOverride::default();
        account_override.state_diff = Some(state_diff);

        let mut state_override: StateOverride = StateOverride::default();
        state_override.insert(Address::default(), account_override);

        let data = serde_json::to_string_pretty(&state_override)?;
        debug!("{}", data);

        // Compare as JSON values so the check does not depend on hash-map iteration order.
        let decoded: StateOverride = serde_json::from_str(&data)?;
        assert_eq!(serde_json::to_value(&decoded)?, serde_json::to_value(&state_override)?);

        Ok(())
    }
}