From 32f4264e5c7adaa6f1ea317ebc8cc52ea45f7ec2 Mon Sep 17 00:00:00 2001 From: MeditationDuck Date: Tue, 17 Dec 2024 13:26:19 +0100 Subject: [PATCH] :pencil2: fix and remove random state argument --- wake/cli/test.py | 52 ++----------------- wake/testing/pytest_plugin_multiprocess.py | 12 ++--- .../pytest_plugin_multiprocess_server.py | 4 -- wake/testing/pytest_plugin_single.py | 13 +---- 4 files changed, 10 insertions(+), 71 deletions(-) diff --git a/wake/cli/test.py b/wake/cli/test.py index d1759461..105f338e 100644 --- a/wake/cli/test.py +++ b/wake/cli/test.py @@ -70,14 +70,6 @@ def shell_complete( type=str, help="Random seeds", ) -@click.option( - "--random-state", - "-RS", - "random_states", - multiple=True, - type=str, - help="Random statuses", -) @click.option( "--attach-first", is_flag=True, @@ -129,7 +121,6 @@ def run_test( proc_count: Optional[int], coverage: int, seeds: Tuple[str], - random_states: Tuple[str], attach_first: bool, dist: str, verbosity: int, @@ -164,13 +155,6 @@ def run_test( except ValueError: raise click.BadParameter("Seeds must be hex numbers.") - try: - random_states_byte = [ - bytes.fromhex(random_state) for random_state in random_states - ] - except ValueError: - raise click.BadParameter("Random states must be hex numbers.") - config = WakeConfig(local_config_path=context.obj.get("local_config_path", None)) config.load_configs() @@ -221,7 +205,6 @@ def run_test( coverage, proc_count, random_seeds, - random_states_byte, attach_first, debug, dist, @@ -263,7 +246,7 @@ def extract_crash_log_dict(crash_log_file_path: Path) -> dict: except json.JSONDecodeError: raise ValueError(f"Invalid JSON format in crash log file: {crash_log_file_path}") - def get_shrink_argument_path(shrink_path_str: str) -> Path: + def get_shrink_argument_path(shrink_path_str: str, dir_name: str) -> Path: try: path = Path(shrink_path_str) if not path.exists(): @@ -273,7 +256,7 @@ def get_shrink_argument_path(shrink_path_str: str) -> Path: pass crash_logs_dir = ( - get_config().project_root_path / ".wake" / "logs" / "crashes" + get_config().project_root_path / ".wake" / "logs" / dir_name ) if not crash_logs_dir.exists(): raise click.BadParameter( @@ -287,30 +270,6 @@ def get_shrink_argument_path(shrink_path_str: str) -> Path: raise click.BadParameter(f"Invalid crash log index: {index}") return Path(crash_logs[index]) - def get_shrank_argument_path(shrank_path_str: str) -> Path: - try: - shrank_path = Path(shrank_path_str) - if not shrank_path.exists(): - raise ValueError(f"Shrank data file not found: {shrank_path}") - return shrank_path - except ValueError: - pass - shrank_data_path = ( - get_config().project_root_path / ".wake" / "logs" / "shrank" - ) - if not shrank_data_path.exists(): - raise click.BadParameter( - f"Shrank data file not found: {shrank_data_path}" - ) - - index = int(shrank_path_str) - shrank_files = sorted( - shrank_data_path.glob("*.bin"), key=os.path.getmtime, reverse=True - ) - if abs(index) > len(shrank_files): - raise click.BadParameter(f"Invalid crash log index: {index}") - return Path(shrank_files[index]) - if shrank is not None and shrink is not None: raise click.BadParameter( "Both shrink and shrieked cannot be provided at the same time." @@ -318,10 +277,9 @@ def get_shrank_argument_path(shrank_path_str: str) -> Path: pytest_path_specified, test_path = get_single_test_path(pytest_args) - if shrink is not None: set_fuzz_mode(1) - shrink_crash_path = get_shrink_argument_path(shrink) + shrink_crash_path = get_shrink_argument_path(shrink, "crashes") print("shrink from crash log: ", shrink_crash_path) crash_log_dict = extract_crash_log_dict(shrink_crash_path) path = crash_log_dict["test_file"] @@ -336,7 +294,7 @@ def get_shrank_argument_path(shrank_path_str: str) -> Path: if shrank: set_fuzz_mode(2) - shrank_data_path = get_shrank_argument_path(shrank) + shrank_data_path = get_shrink_argument_path(shrank, "shrank") print("shrank from shrank data: ", shrank_data_path) with open(shrank_data_path, "r") as f: target_fuzz_path = json.load(f)["target_fuzz_path"] @@ -352,7 +310,7 @@ def get_shrank_argument_path(shrank_path_str: str) -> Path: pytest_args, plugins=[ PytestWakePluginSingle( - config, debug, coverage, random_seeds, random_states_byte + config, debug, coverage, random_seeds ) ], ) diff --git a/wake/testing/pytest_plugin_multiprocess.py b/wake/testing/pytest_plugin_multiprocess.py index 0f812f37..5b0508ed 100644 --- a/wake/testing/pytest_plugin_multiprocess.py +++ b/wake/testing/pytest_plugin_multiprocess.py @@ -44,7 +44,6 @@ class PytestWakePluginMultiprocess: _log_file: Path _crash_log_file: Path _random_seed: bytes - _random_state: Optional[bytes] _tee: bool _debug: bool _exception_handled: bool @@ -62,7 +61,6 @@ def __init__( log_dir: Path, crash_log_dir: Path, random_seed: bytes, - random_state: Optional[bytes], tee: bool, debug: bool, ): @@ -74,7 +72,6 @@ def __init__( self._log_file = log_dir / sanitize_filename(f"process-{index}.ansi") self._crash_log_dir = crash_log_dir self._random_seed = random_seed - self._random_state = random_state self._tee = tee self._debug = debug self._exception_handled = False @@ -309,12 +306,9 @@ def sigint_handler(signum, frame): indexes = self._conn.recv() for i in range(len(indexes)): # set random seed before each test item - if self._random_state is not None: - random.setstate(pickle.loads(self._random_state)) - console.print(f"Using random state '{random.getstate()[1]}'") - else: - random.seed(self._random_seed) - console.print(f"Setting random seed '{self._random_seed.hex()}'") + + random.seed(self._random_seed) + console.print(f"Setting random seed '{self._random_seed.hex()}'") item = session.items[indexes[i]] nextitem = ( diff --git a/wake/testing/pytest_plugin_multiprocess_server.py b/wake/testing/pytest_plugin_multiprocess_server.py index 2f4c9fbf..26e221c5 100644 --- a/wake/testing/pytest_plugin_multiprocess_server.py +++ b/wake/testing/pytest_plugin_multiprocess_server.py @@ -33,7 +33,6 @@ class PytestWakePluginMultiprocessServer: int, Tuple[multiprocessing.Process, multiprocessing.connection.Connection] ] _random_seeds: List[bytes] - _random_states: Optional[List[bytes]] _attach_first: bool _debug: bool _pytest_args: List[str] @@ -48,7 +47,6 @@ def __init__( coverage: int, proc_count: int, random_seeds: List[bytes], - raondom_states: Optional[List[bytes]], attach_first: bool, debug: bool, dist: str, @@ -59,7 +57,6 @@ def __init__( self._proc_count = proc_count self._processes = {} self._random_seeds = random_seeds - self._random_states = raondom_states self._attach_first = attach_first self._debug = debug self._dist = dist @@ -103,7 +100,6 @@ def pytest_sessionstart(self, session: pytest.Session): logs_dir, crash_logs_process_dir, self._random_seeds[i], - self._random_states[i] if self._random_states else None, self._attach_first and i == 0, self._debug, ), diff --git a/wake/testing/pytest_plugin_single.py b/wake/testing/pytest_plugin_single.py index 09f5a135..bd813e89 100644 --- a/wake/testing/pytest_plugin_single.py +++ b/wake/testing/pytest_plugin_single.py @@ -29,7 +29,6 @@ class PytestWakePluginSingle: _config: WakeConfig _cov_proc_count: Optional[int] _random_seeds: List[bytes] - _random_states: List[Optional[bytes]] _debug: bool def __init__( @@ -38,13 +37,11 @@ def __init__( debug: bool, cov_proc_count: Optional[int], random_seeds: Iterable[bytes], - random_states: Iterable[Optional[bytes]], ): self._config = config self._debug = debug self._cov_proc_count = cov_proc_count self._random_seeds = list(random_seeds) - self._random_states = list(random_states) def pytest_runtest_setup(self, item): reset_exception_handled() @@ -124,14 +121,8 @@ def pytest_runtestloop(self, session: Session): coverage = self._cov_proc_count == 1 or self._cov_proc_count == -1 - - if len(self._random_states) > 0: - assert self._random_states[0] is not None - random.setstate(pickle.loads(self._random_states[0])) - console.print(f"Using random state '{random.getstate()[1]}'") - else: - random.seed(self._random_seeds[0]) - console.print(f"Using random seed '{self._random_seeds[0].hex()}'") + random.seed(self._random_seeds[0]) + console.print(f"Using random seed '{self._random_seeds[0].hex()}'") if self._debug: set_exception_handler(partial(attach_debugger, seed=self._random_seeds[0]))