Skip to content

Commit

Permalink
✏️ fix and remove random state argument
Browse files Browse the repository at this point in the history
  • Loading branch information
MeditationDuck committed Dec 17, 2024
1 parent c309f7f commit 32f4264
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 71 deletions.
52 changes: 5 additions & 47 deletions wake/cli/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,6 @@ def shell_complete(
type=str,
help="Random seeds",
)
@click.option(
"--random-state",
"-RS",
"random_states",
multiple=True,
type=str,
help="Random statuses",
)
@click.option(
"--attach-first",
is_flag=True,
Expand Down Expand Up @@ -129,7 +121,6 @@ def run_test(
proc_count: Optional[int],
coverage: int,
seeds: Tuple[str],
random_states: Tuple[str],
attach_first: bool,
dist: str,
verbosity: int,
Expand Down Expand Up @@ -164,13 +155,6 @@ def run_test(
except ValueError:
raise click.BadParameter("Seeds must be hex numbers.")

try:
random_states_byte = [
bytes.fromhex(random_state) for random_state in random_states
]
except ValueError:
raise click.BadParameter("Random states must be hex numbers.")

config = WakeConfig(local_config_path=context.obj.get("local_config_path", None))
config.load_configs()

Expand Down Expand Up @@ -221,7 +205,6 @@ def run_test(
coverage,
proc_count,
random_seeds,
random_states_byte,
attach_first,
debug,
dist,
Expand Down Expand Up @@ -263,7 +246,7 @@ def extract_crash_log_dict(crash_log_file_path: Path) -> dict:
except json.JSONDecodeError:
raise ValueError(f"Invalid JSON format in crash log file: {crash_log_file_path}")

def get_shrink_argument_path(shrink_path_str: str) -> Path:
def get_shrink_argument_path(shrink_path_str: str, dir_name: str) -> Path:
try:
path = Path(shrink_path_str)
if not path.exists():
Expand All @@ -273,7 +256,7 @@ def get_shrink_argument_path(shrink_path_str: str) -> Path:
pass

crash_logs_dir = (
get_config().project_root_path / ".wake" / "logs" / "crashes"
get_config().project_root_path / ".wake" / "logs" / dir_name
)
if not crash_logs_dir.exists():
raise click.BadParameter(
Expand All @@ -287,41 +270,16 @@ def get_shrink_argument_path(shrink_path_str: str) -> Path:
raise click.BadParameter(f"Invalid crash log index: {index}")
return Path(crash_logs[index])

def get_shrank_argument_path(shrank_path_str: str) -> Path:
try:
shrank_path = Path(shrank_path_str)
if not shrank_path.exists():
raise ValueError(f"Shrank data file not found: {shrank_path}")
return shrank_path
except ValueError:
pass
shrank_data_path = (
get_config().project_root_path / ".wake" / "logs" / "shrank"
)
if not shrank_data_path.exists():
raise click.BadParameter(
f"Shrank data file not found: {shrank_data_path}"
)

index = int(shrank_path_str)
shrank_files = sorted(
shrank_data_path.glob("*.bin"), key=os.path.getmtime, reverse=True
)
if abs(index) > len(shrank_files):
raise click.BadParameter(f"Invalid crash log index: {index}")
return Path(shrank_files[index])

if shrank is not None and shrink is not None:
raise click.BadParameter(
"Both shrink and shrieked cannot be provided at the same time."
)

pytest_path_specified, test_path = get_single_test_path(pytest_args)


if shrink is not None:
set_fuzz_mode(1)
shrink_crash_path = get_shrink_argument_path(shrink)
shrink_crash_path = get_shrink_argument_path(shrink, "crashes")
print("shrink from crash log: ", shrink_crash_path)
crash_log_dict = extract_crash_log_dict(shrink_crash_path)
path = crash_log_dict["test_file"]
Expand All @@ -336,7 +294,7 @@ def get_shrank_argument_path(shrank_path_str: str) -> Path:

if shrank:
set_fuzz_mode(2)
shrank_data_path = get_shrank_argument_path(shrank)
shrank_data_path = get_shrink_argument_path(shrank, "shrank")
print("shrank from shrank data: ", shrank_data_path)
with open(shrank_data_path, "r") as f:
target_fuzz_path = json.load(f)["target_fuzz_path"]
Expand All @@ -352,7 +310,7 @@ def get_shrank_argument_path(shrank_path_str: str) -> Path:
pytest_args,
plugins=[
PytestWakePluginSingle(
config, debug, coverage, random_seeds, random_states_byte
config, debug, coverage, random_seeds
)
],
)
Expand Down
12 changes: 3 additions & 9 deletions wake/testing/pytest_plugin_multiprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ class PytestWakePluginMultiprocess:
_log_file: Path
_crash_log_file: Path
_random_seed: bytes
_random_state: Optional[bytes]
_tee: bool
_debug: bool
_exception_handled: bool
Expand All @@ -62,7 +61,6 @@ def __init__(
log_dir: Path,
crash_log_dir: Path,
random_seed: bytes,
random_state: Optional[bytes],
tee: bool,
debug: bool,
):
Expand All @@ -74,7 +72,6 @@ def __init__(
self._log_file = log_dir / sanitize_filename(f"process-{index}.ansi")
self._crash_log_dir = crash_log_dir
self._random_seed = random_seed
self._random_state = random_state
self._tee = tee
self._debug = debug
self._exception_handled = False
Expand Down Expand Up @@ -309,12 +306,9 @@ def sigint_handler(signum, frame):
indexes = self._conn.recv()
for i in range(len(indexes)):
# set random seed before each test item
if self._random_state is not None:
random.setstate(pickle.loads(self._random_state))
console.print(f"Using random state '{random.getstate()[1]}'")
else:
random.seed(self._random_seed)
console.print(f"Setting random seed '{self._random_seed.hex()}'")

random.seed(self._random_seed)
console.print(f"Setting random seed '{self._random_seed.hex()}'")

item = session.items[indexes[i]]
nextitem = (
Expand Down
4 changes: 0 additions & 4 deletions wake/testing/pytest_plugin_multiprocess_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ class PytestWakePluginMultiprocessServer:
int, Tuple[multiprocessing.Process, multiprocessing.connection.Connection]
]
_random_seeds: List[bytes]
_random_states: Optional[List[bytes]]
_attach_first: bool
_debug: bool
_pytest_args: List[str]
Expand All @@ -48,7 +47,6 @@ def __init__(
coverage: int,
proc_count: int,
random_seeds: List[bytes],
raondom_states: Optional[List[bytes]],
attach_first: bool,
debug: bool,
dist: str,
Expand All @@ -59,7 +57,6 @@ def __init__(
self._proc_count = proc_count
self._processes = {}
self._random_seeds = random_seeds
self._random_states = raondom_states
self._attach_first = attach_first
self._debug = debug
self._dist = dist
Expand Down Expand Up @@ -103,7 +100,6 @@ def pytest_sessionstart(self, session: pytest.Session):
logs_dir,
crash_logs_process_dir,
self._random_seeds[i],
self._random_states[i] if self._random_states else None,
self._attach_first and i == 0,
self._debug,
),
Expand Down
13 changes: 2 additions & 11 deletions wake/testing/pytest_plugin_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ class PytestWakePluginSingle:
_config: WakeConfig
_cov_proc_count: Optional[int]
_random_seeds: List[bytes]
_random_states: List[Optional[bytes]]
_debug: bool

def __init__(
Expand All @@ -38,13 +37,11 @@ def __init__(
debug: bool,
cov_proc_count: Optional[int],
random_seeds: Iterable[bytes],
random_states: Iterable[Optional[bytes]],
):
self._config = config
self._debug = debug
self._cov_proc_count = cov_proc_count
self._random_seeds = list(random_seeds)
self._random_states = list(random_states)

def pytest_runtest_setup(self, item):
reset_exception_handled()
Expand Down Expand Up @@ -124,14 +121,8 @@ def pytest_runtestloop(self, session: Session):

coverage = self._cov_proc_count == 1 or self._cov_proc_count == -1


if len(self._random_states) > 0:
assert self._random_states[0] is not None
random.setstate(pickle.loads(self._random_states[0]))
console.print(f"Using random state '{random.getstate()[1]}'")
else:
random.seed(self._random_seeds[0])
console.print(f"Using random seed '{self._random_seeds[0].hex()}'")
random.seed(self._random_seeds[0])
console.print(f"Using random seed '{self._random_seeds[0].hex()}'")

if self._debug:
set_exception_handler(partial(attach_debugger, seed=self._random_seeds[0]))
Expand Down

0 comments on commit 32f4264

Please sign in to comment.