Skip to content

Commit

Permalink
✏️ chain copy elements for snapshot
Browse files Browse the repository at this point in the history
  • Loading branch information
MeditationDuck committed Dec 10, 2024
1 parent 6c57324 commit fba0c52
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"wake.compiler.solc.remappings": []
}
38 changes: 38 additions & 0 deletions wake/testing/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,24 @@ def snapshot(self) -> str:
"txs": dict(self._txs._transactions),
"tx_hashes": list(self._txs._tx_hashes),
"blocks": dict(self._blocks._blocks),
## for testing and shrinking
"accounts_set": self._accounts_set.copy(),
"default_estimate_account": self._default_estimate_account,
"default_access_list_account": self._default_access_list_account,
"default_tx_type": self._default_tx_type,
"default_tx_confirmations": self._default_tx_confirmations,
"deployed_libraries": dict(self._deployed_libraries),
"single_source_errors": self._single_source_errors.copy(),
"chain_id": self._chain_id,
"labels": self._labels.copy(),
"require_signed_txs": self._require_signed_txs,
"fork": self._fork,
"forked_chain_id": self._forked_chain_id,
"debug_trace_call_supported": self._debug_trace_call_supported,
"client_version": self._client_version,
"gas_price": self._gas_price,
"max_priority_fee_per_gas": self._max_priority_fee_per_gas,
"initial_base_fee_per_gas": self._initial_base_fee_per_gas,
}
return snapshot_id

Expand All @@ -124,6 +142,26 @@ def revert(self, snapshot_id: str) -> None:
self._txs._transactions = snapshot["txs"]
self._txs._tx_hashes = snapshot["tx_hashes"]
self._blocks._blocks = snapshot["blocks"]

# for testing and shrinking
self._accounts_set = snapshot["accounts_set"]
self._default_estimate_account = snapshot["default_estimate_account"]
self._default_access_list_account = snapshot["default_access_list_account"]
self._default_tx_type = snapshot["default_tx_type"]
self._default_tx_confirmations = snapshot["default_tx_confirmations"]
self._deployed_libraries = snapshot["deployed_libraries"]
self._single_source_errors = snapshot["single_source_errors"]
self._chain_id = snapshot["chain_id"]
self._labels = snapshot["labels"]
self._require_signed_txs = snapshot["require_signed_txs"]
self._fork = snapshot["fork"]
self._forked_chain_id = snapshot["forked_chain_id"]
self._debug_trace_call_supported = snapshot["debug_trace_call_supported"]
self._client_version = snapshot["client_version"]
self._gas_price = snapshot["gas_price"]
self._max_priority_fee_per_gas = snapshot["max_priority_fee_per_gas"]
self._initial_base_fee_per_gas = snapshot["initial_base_fee_per_gas"]

del self._snapshots[snapshot_id]

@property
Expand Down
8 changes: 7 additions & 1 deletion wake/testing/fuzzing/fuzz_shrink.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from wake.cli.console import console
from contextlib import contextmanager, redirect_stdout, redirect_stderr

from wake.testing.core import default_chain as global_default_chain

EXACT_FLOW_INDEX = False # False if you accept it could reproduce same error earlier.

Expand Down Expand Up @@ -109,6 +110,7 @@ class StateSnapShot:
chain_states: List[str]
flow_number: int | None # Current flow number
random_state: Any | None
default_chain: Chain | None

def __init__(self):
self._python_state = None
Expand All @@ -124,19 +126,22 @@ def take_snapshot(self, python_instance: FuzzTest, new_instance, chains: Tuple[C
assert self.chain_states != [], "Chain state is missing"
assert self.flow_number is not None, "Flow number is missing"
print("Overwriting state ", self.flow_number, " to ", python_instance._flow_num)
assert self.default_chain is not None, "Default chain is missing"
# assert self._python_state is None, "Python state already exists"

self._python_state = new_instance

self.flow_number = python_instance._flow_num
self._python_state.__dict__.update(copy.deepcopy(python_instance.__dict__))
self.chain_states = [chain.snapshot() for chain in chains]
self.default_chain = global_default_chain
self.random_state = random_state

def revert(self, python_instance: FuzzTest, chains: Tuple[Chain, ...], with_random_state: bool = False):
global global_default_chain
assert self.chain_states != [], "Chain snapshot is missing"
assert self._python_state is not None, "Python state snapshot is missing "
assert self.flow_number is not None, "Flow number is missing"
assert self.default_chain is not None, "Default chain is missing"

python_instance.__dict__ = self._python_state.__dict__

Expand All @@ -147,6 +152,7 @@ def revert(self, python_instance: FuzzTest, chains: Tuple[Chain, ...], with_rand
if with_random_state:
assert self.random_state is not None, "Random state is missing"
random.setstate(self.random_state)
global_default_chain = self.default_chain


class OverRunException(Exception):
Expand Down

0 comments on commit fba0c52

Please sign in to comment.