Skip to content

Commit

Permalink
Additional updates to results tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jlumpe committed Dec 1, 2024
1 parent 881c87c commit bc6e177
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 31 deletions.
63 changes: 36 additions & 27 deletions tests/cli/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pytest

from gambit.seq import SequenceFile
from gambit.query import QueryInput, QueryResults
from gambit.query import QueryResults
from gambit.util.misc import zip_strict
from gambit.util.io import write_lines, FilePath
from gambit.cli.common import strip_seq_file_ext
Expand All @@ -20,13 +20,13 @@


def make_args(testdb: TestDB, *,
positional_files: Optional[Iterable[SequenceFile]] = None,
list_file: Optional['FilePath'] = None,
sig_file: bool = False,
output: Optional['FilePath'] = None,
outfmt: Optional[str] = None,
strict: bool=False,
) -> list[str]:
positional_files: Optional[Iterable[SequenceFile]] = None,
list_file: Optional['FilePath'] = None,
sig_file: bool = False,
output: Optional['FilePath'] = None,
outfmt: Optional[str] = None,
strict: bool=False,
) -> list[str]:
"""Make command line arguments for querying."""

args: list[str] = [f'--db={testdb.paths.root}', 'query']
Expand All @@ -50,16 +50,29 @@ def make_args(testdb: TestDB, *,
return args


def make_ref_results(testdb: TestDB, inputs: Iterable[QueryInput], strict: bool, nqueries: Optional[int]):
def make_ref_results(testdb: TestDB,
labels: Iterable[str],
strict: bool,
files: Optional[Iterable[FilePath]],
nqueries: Optional[int] = None,
):
"""
Make a copy of the reference query results to compare to, modifying to account for possibly
different query inputs and # of queries.
different query labels/files and # of queries.
"""
ref_results = copy(testdb.get_query_results(strict))
ref_results.items = ref_results.items[:nqueries]

for item, input in zip_strict(ref_results.items, inputs):
item.input = input
for item, label in zip_strict(ref_results.items, labels):
item.label = label

if files is None:
for item in ref_results.items:
item.file = None

if files is not None:
for item, file in zip_strict(ref_results.items, files):
item.file = Path(file)

return ref_results

Expand Down Expand Up @@ -94,21 +107,18 @@ def check_results(results_file: Path, out_fmt: str, ref_results: QueryResults):
],
)
def test_full_query(testdb: TestDB,
nqueries: Optional[int],
use_list_file: bool,
out_fmt: str,
strict: bool,
gzipped: bool,
tmp_path: Path,
):
nqueries: Optional[int],
use_list_file: bool,
out_fmt: str,
strict: bool,
gzipped: bool,
tmp_path: Path,
):
"""Run a full query using the command line interface."""

query_files = testdb.get_query_files(gzipped)[:nqueries]
inputs = [
QueryInput(strip_seq_file_ext(file.path.name), file)
for file in query_files
]
ref_results: QueryResults = make_ref_results(testdb, inputs, strict, nqueries)
query_files = [sigfile.path for sigfile in testdb.get_query_files(gzipped)[:nqueries]]
labels = [strip_seq_file_ext(file.name) for file in query_files]
ref_results: QueryResults = make_ref_results(testdb, labels, strict, query_files, nqueries=nqueries)

results_file = tmp_path / ('results.' + out_fmt)

Expand Down Expand Up @@ -137,8 +147,7 @@ def test_full_query(testdb: TestDB,
def test_sigfile(testdb: TestDB, out_fmt: str, strict: bool, tmp_path: Path):
"""Test using signature file instead of parsing genome files."""

inputs = list(map(QueryInput, testdb.query_signatures.ids))
ref_results = make_ref_results(testdb, inputs, strict, None)
ref_results = make_ref_results(testdb, testdb.query_signatures.ids, strict, None)

results_file = tmp_path / ('results.' + out_fmt)

Expand Down
8 changes: 4 additions & 4 deletions tests/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,12 @@ def cmp_genomematch_json(data, match: GenomeMatch):
cmp_taxon_json(data['matched_taxon'], match.matched_taxon)


def check_json_results(file: TextIO,
results: QueryResults,
strict: bool = False,
):
def check_json_results(file: TextIO, results: QueryResults, strict: bool = False):
"""Assert exported JSON data matches the given results object.
"Strict" mode also compares the ``timestamp``, ``gambit_version``, and ``extra`` attributes
at the top level and expects that full input file paths must match instead of just file names.
Parameters
----------
file
Expand Down

0 comments on commit bc6e177

Please sign in to comment.