Skip to content

Commit

Permalink
✨ chain copy success
Browse files Browse the repository at this point in the history
  • Loading branch information
MeditationDuck committed Oct 4, 2024
1 parent f1c00a8 commit ae04dd6
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 57 deletions.
3 changes: 3 additions & 0 deletions wake/development/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1458,6 +1458,9 @@ class Chain(ABC):

tx_callback: Optional[Callable[[TransactionAbc], None]]

def __deepcopy__(self, memo):
return self

@abstractmethod
def _connect_setup(
self, min_gas_price: Optional[int], block_base_fee_per_gas: Optional[int]
Expand Down
168 changes: 111 additions & 57 deletions wake/testing/fuzzing/fuzz_shrink.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


from collections import defaultdict
from typing import Callable, DefaultDict, List, Optional, Any
from typing import Callable, DefaultDict, List, Optional, Any, Tuple

from typing_extensions import get_type_hints

Expand All @@ -26,6 +26,7 @@
import traceback
from wake.utils.file_utils import is_relative_to
from wake.development.transactions import Error
import copy


def __get_methods(target, attr: str) -> List[Callable]:
Expand All @@ -38,6 +39,94 @@ def __get_methods(target, attr: str) -> List[Callable]:
return ret


def compare_exceptions(e1, e2):
if type(e1) != type(e2):
# print("type not equal")
return False

if type(e1) == Error and type(e2) == Error:
# if error was transaction message error the compare message content as well
if e1.message != e2.message:
return False

tb1 = traceback.extract_tb(e1.__traceback__)
tb2 = traceback.extract_tb(e2.__traceback__)

frame1 = None
for frame1 in tb1:
if is_relative_to(
Path(frame1.filename), Path.cwd()
) and not is_relative_to(
Path(frame1.filename), Path().cwd() / "pytypes"
):
break
frame2 = None
for frame2 in tb2:
if is_relative_to(
Path(frame2.filename), Path.cwd()
) and not is_relative_to(
Path(frame2.filename), Path().cwd() / "pytypes"
):
break

if frame1 is None or frame2 is None:
print("frame is none!!!!!!!!!!!!!!")
# return False
if frame1 is not None and frame2 is not None and (frame1.filename != frame2.filename
or frame1.lineno != frame2.lineno
or frame1.name != frame2.name
):
return False
return True


class StateSnapShot:
_python_state: FuzzTest | None
chain_states: List[str]
flow_number: int | None # current flow number

def __init__(self):
self._python_state = None
self.chain_states = []
self.flow_number = None

def take_snapshot(self, python_instance: FuzzTest, new_instance, chains: Tuple[Chain, ...], overwrite: bool):
if not overwrite:
assert self._python_state is None, "Python state already exists"
assert self.chain_states == [], "Chain state already exists"
else:
assert self._python_state is not None, "No python state (snapshot)"
assert self.chain_states != [], "No chain state"
assert self.flow_number is not None, "No flow number"
print("overwriting state ", self.flow_number, " to ", python_instance._flow_num)
# assert self._python_state is None, "Python state already exists"

self._python_state = new_instance

self.flow_number = python_instance._flow_num
self._python_state.__dict__.update(copy.deepcopy(python_instance.__dict__))
self.chain_states = [chain.snapshot() for chain in chains]


def revert(self, python_instance: FuzzTest, chains: Tuple[Chain, ...]):
assert self.chain_states != [], "No chain snapshot"
assert self._python_state is not None, "No python state"
assert self.flow_number is not None, "No flow number"

print("curr", python_instance._flow_num)
print("new", self._python_state._flow_num)
# assert python_instance._flow_num != self._python_state._flow_num, "Flow number mismatch"
python_instance.__dict__.update(copy.deepcopy(self._python_state.__dict__))

assert python_instance._flow_num == self._python_state._flow_num, "update failed"
self._python_state = None

for temp_chain, chain in zip(self.chain_states, chains):
chain.revert(temp_chain)
self.chain_states = []



@dataclass
class FlowState:
random_state: bytes
Expand Down Expand Up @@ -134,7 +223,7 @@ def shrink_test(test_class: type[FuzzTest], flows_count, dry_run: bool = False):
invariant_periods: DefaultDict[Callable[[None], None], int] = defaultdict(int)

# Snapshot all connected chains
snapshots = [chain.snapshot() for chain in chains]
initial_chain_state_snapshots = [chain.snapshot() for chain in chains]

initial_state = get_sequence_initial_internal_state() # argument

Expand Down Expand Up @@ -223,8 +312,9 @@ def shrink_test(test_class: type[FuzzTest], flows_count, dry_run: bool = False):
exception = True
assert test_instance._flow_num == error_flow_num, "Unexpected failing flow"
finally:
for snapshot, chain in zip(snapshots, chains):
for snapshot, chain in zip(initial_chain_state_snapshots, chains):
chain.revert(snapshot)
initial_chain_state_snapshots = []
if exception == False:
raise Exception("Exception not raised unexpected state changes")

Expand All @@ -237,26 +327,31 @@ class OverRunException(Exception):
def __init__(self):
super().__init__("Overrun")


random.setstate(pickle.loads(initial_state))

test_instance._flow_num = 0
test_instance._sequence_num = 0
test_instance.pre_sequence()

states = StateSnapShot()
states.take_snapshot(test_instance,test_class(), chains, overwrite=False)

while curr <= error_flow_num:
assert flow_state[curr].required == True
flow_state[curr].required = False
flows_counter: DefaultDict[Callable, int] = defaultdict(int)
invariant_periods: DefaultDict[Callable[[None], None], int] = defaultdict(
int
)
snapshots = [chain.snapshot() for chain in chains]
print("progress: ", (curr* 100) / (error_flow_num+1), "%")
random.setstate(pickle.loads(initial_state))

test_instance._flow_num = 0
test_instance._sequence_num = 0
test_instance.pre_sequence()
exception = False
try:
for j in range(0, flows_count):
if j > error_flow_num:
raise OverRunException()

print("flow: ", j, flow_state[j].flow.__name__ )

curr_flow_state = flow_state[j]
random.setstate(pickle.loads(curr_flow_state.random_state))
flow = curr_flow_state.flow
Expand All @@ -267,7 +362,6 @@ def __init__(self):
test_instance._flow_num = j
test_instance.pre_flow(flow)
flow(test_instance, *flow_params)
flows_counter[flow] += 1
test_instance.post_flow(flow)


Expand All @@ -292,46 +386,6 @@ def __init__(self):
except Exception as e:
exception = True

def compare_exceptions(e1, e2):
if type(e1) != type(e2):
# print("type not equal")
return False

if type(e1) == Error and type(e2) == Error:
# if error was transaction message error the compare message content as well
if e1.message != e2.message:
return False

tb1 = traceback.extract_tb(e1.__traceback__)
tb2 = traceback.extract_tb(e2.__traceback__)

frame1 = None
for frame1 in tb1:
if is_relative_to(
Path(frame1.filename), Path.cwd()
) and not is_relative_to(
Path(frame1.filename), Path().cwd() / "pytypes"
):
break
frame2 = None
for frame2 in tb2:
if is_relative_to(
Path(frame2.filename), Path.cwd()
) and not is_relative_to(
Path(frame2.filename), Path().cwd() / "pytypes"
):
break

if frame1 is None or frame2 is None:
print("frame is none!!!!!!!!!!!!!!")
# return False
if frame1 is not None and frame2 is not None and (frame1.filename != frame2.filename
or frame1.lineno != frame2.lineno
or frame1.name != frame2.name
):
return False
return True

# Check exception type and exception lines in the testing file.
ignore_flows = True

Expand All @@ -349,8 +403,11 @@ def compare_exceptions(e1, e2):

finally:
# revert to starting state
for snapshot, chain in zip(snapshots, chains):
chain.revert(snapshot)
# for snapshot, chain in zip(initial_chain_state_snapshots, chains):
# chain.revert(snapshot)
print("revert state!!")
states.revert(test_instance, chains)
states.take_snapshot(test_instance, test_class(), chains, overwrite=False)

if exception == False:
print("overrun!")
Expand Down Expand Up @@ -419,6 +476,7 @@ def single_fuzz_test(test_class: type[FuzzTest], sequences_count: int, flows_cou
test_instance._sequence_num = i
test_instance.pre_sequence()


for j in range(flows_count):
valid_flows = [
f
Expand Down Expand Up @@ -483,7 +541,3 @@ def single_fuzz_test(test_class: type[FuzzTest], sequences_count: int, flows_cou
# Revert all chains back to their initial snapshot
for snapshot, chain in zip(snapshots, chains):
chain.revert(snapshot)



#hehe

0 comments on commit ae04dd6

Please sign in to comment.