diff --git a/tract_querier/tests/test_query_eval.py b/tract_querier/tests/test_query_eval.py index f8f3511..7783da1 100644 --- a/tract_querier/tests/test_query_eval.py +++ b/tract_querier/tests/test_query_eval.py @@ -1,25 +1,12 @@ from .. import query_processor +import pytest + import ast import numpy as np -# Ten tracts traversing random labels -another_set = True -while (another_set): - rng = np.random.default_rng() - tracts_labels = dict([(i, set(rng.integers(100, size=2))) for i in range(100)]) - labels_tracts = query_processor.labels_for_tracts(tracts_labels) - another_set = 0 not in labels_tracts.keys() or 1 not in labels_tracts.keys() - - -tracts_in_0 = set().union(*[labels_tracts[label] for label in labels_tracts if label == 0]) -tracts_in_all_but_0 = set().union(*[labels_tracts[label] for label in labels_tracts if label != 0]) -tract_in_label_0_uniquely = labels_tracts[0].difference(tracts_in_all_but_0) - - class DummySpatialIndexing: - def __init__( self, crossing_tracts_labels, crossing_labels_tracts, @@ -33,11 +20,44 @@ def __init__( self.label_bounding_boxes = label_bounding_boxes self.tract_bounding_boxes = tract_bounding_boxes -dummy_spatial_indexing = DummySpatialIndexing(tracts_labels, labels_tracts, ({}, {}), ({}, {}), {}, {}) -empty_spatial_indexing = DummySpatialIndexing({}, {}, ({}, {}), ({}, {}), {}, {}) +@pytest.fixture +def dummy_spatial_indexing(): + # Ten tracts traversing random labels + another_set = True + while (another_set): + rng = np.random.default_rng() + tracts_labels = dict([(i, set(rng.integers(100, size=2))) for i in range(100)]) + labels_tracts = query_processor.labels_for_tracts(tracts_labels) + another_set = 0 not in labels_tracts.keys() or 1 not in labels_tracts.keys() + + return DummySpatialIndexing(tracts_labels, labels_tracts, ({}, {}), ({}, {}), {}, {}) + + +@pytest.fixture +def tracts_in_0(dummy_spatial_indexing): + labels_tracts = dummy_spatial_indexing.crossing_labels_tracts + return set().union(*[labels_tracts[label] for label in labels_tracts if label == 0]) + + +@pytest.fixture +def tracts_in_all_but_0(dummy_spatial_indexing: DummySpatialIndexing): + labels_tracts = dummy_spatial_indexing.crossing_labels_tracts + return set().union(*[labels_tracts[label] for label in labels_tracts if label != 0]) + + +@pytest.fixture +def tract_in_label_0_uniquely(dummy_spatial_indexing: DummySpatialIndexing, tracts_in_all_but_0): + labels_tracts = dummy_spatial_indexing.crossing_labels_tracts + return labels_tracts[0].difference(tracts_in_all_but_0) + + +@pytest.fixture +def empty_spatial_indexing(): + return DummySpatialIndexing({}, {}, ({}, {}), ({}, {}), {}, {}) -def test_assign(): +def test_assign(dummy_spatial_indexing): + labels_tracts = dummy_spatial_indexing.crossing_labels_tracts query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing) query_evaluator.visit(ast.parse("A=0")) assert (( @@ -47,7 +67,8 @@ def test_assign(): )) -def test_assign_attr(): +def test_assign_attr(dummy_spatial_indexing): + labels_tracts = dummy_spatial_indexing.crossing_labels_tracts query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing) query_evaluator.visit(ast.parse("a.left=0")) assert (( @@ -57,7 +78,7 @@ def test_assign_attr(): )) -def test_assign_side(): +def test_assign_side(empty_spatial_indexing): query_evaluator = query_processor.EvaluateQueries(empty_spatial_indexing) queries_labels = { @@ -92,7 +113,7 @@ def test_assign_side(): assert {k: v.tracts for k, v in query_evaluator.evaluated_queries_info.items()} == queries_tracts -def test_assign_str(): +def test_assign_str(empty_spatial_indexing): query_evaluator = query_processor.EvaluateQueries(empty_spatial_indexing) queries_labels = { @@ -125,7 +146,7 @@ def test_assign_str(): assert {k: v.tracts for k, v in query_evaluator.evaluated_queries_info.items()} == queries_tracts -def test_for_list(): +def test_for_list(empty_spatial_indexing): query_evaluator = query_processor.EvaluateQueries(empty_spatial_indexing) queries_tracts = { @@ -155,7 +176,7 @@ def test_for_list(): assert {k: v.tracts for k, v in query_evaluator.evaluated_queries_info.items()} == queries_tracts -def test_for_str(): +def test_for_str(empty_spatial_indexing): query_evaluator = query_processor.EvaluateQueries(empty_spatial_indexing) queries_tracts = { @@ -185,7 +206,8 @@ def test_for_str(): assert {k: v.tracts for k, v in query_evaluator.evaluated_queries_info.items()} == queries_tracts -def test_add(): +def test_add(dummy_spatial_indexing: DummySpatialIndexing): + labels_tracts = dummy_spatial_indexing.crossing_labels_tracts query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing) query_evaluator.visit(ast.parse("A=0+1")) assert (( @@ -195,7 +217,8 @@ def test_add(): )) -def test_mult(): +def test_mult(dummy_spatial_indexing: DummySpatialIndexing): + labels_tracts = dummy_spatial_indexing.crossing_labels_tracts query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing) query_evaluator.visit(ast.parse("A=0 * 1")) assert (( @@ -205,7 +228,8 @@ def test_mult(): )) -def test_sub(): +def test_sub(dummy_spatial_indexing): + labels_tracts = dummy_spatial_indexing.crossing_labels_tracts query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing) query_evaluator.visit(ast.parse("A=(0 + 1) - 1")) assert (( @@ -215,7 +239,8 @@ def test_sub(): )) -def test_or(): +def test_or(dummy_spatial_indexing): + labels_tracts = dummy_spatial_indexing.crossing_labels_tracts query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing) query_evaluator.visit(ast.parse("A=0 or 1")) assert (( @@ -225,7 +250,8 @@ def test_or(): )) -def test_and(): +def test_and(dummy_spatial_indexing): + labels_tracts = dummy_spatial_indexing.crossing_labels_tracts query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing) query_evaluator.visit(ast.parse("A=0 and 1")) assert (( @@ -235,7 +261,8 @@ def test_and(): )) -def test_not_in(): +def test_not_in(dummy_spatial_indexing): + labels_tracts = dummy_spatial_indexing.crossing_labels_tracts query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing) query_evaluator.visit(ast.parse("A=0 or 1 not in 1")) assert (( @@ -245,7 +272,7 @@ def test_not_in(): )) -def test_only_sign(): +def test_only_sign(dummy_spatial_indexing, tract_in_label_0_uniquely): query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing) query_evaluator.visit(ast.parse("A=~0")) assert (( @@ -255,7 +282,7 @@ def test_only_sign(): )) -def test_only(): +def test_only(dummy_spatial_indexing, tract_in_label_0_uniquely): query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing) query_evaluator.visit(ast.parse("A=only(0)")) assert (( @@ -265,7 +292,8 @@ def test_only(): )) -def test_unsaved_query(): +def test_unsaved_query(dummy_spatial_indexing): + labels_tracts = dummy_spatial_indexing.crossing_labels_tracts query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing) query_evaluator.visit(ast.parse("A|=0")) assert (( @@ -275,7 +303,8 @@ def test_unsaved_query(): )) -def test_symbolic_assignment(): +def test_symbolic_assignment(dummy_spatial_indexing): + labels_tracts = dummy_spatial_indexing.crossing_labels_tracts query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing) query_evaluator.visit(ast.parse("A=0; B=A")) assert (( @@ -285,7 +314,8 @@ def test_symbolic_assignment(): )) -def test_unarySub(): +def test_unarySub(dummy_spatial_indexing, tracts_in_all_but_0): + labels_tracts = dummy_spatial_indexing.crossing_labels_tracts query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing) query_evaluator.visit(ast.parse("B=0; A=-B")) assert (( @@ -295,7 +325,8 @@ def test_unarySub(): )) -def test_not(): +def test_not(dummy_spatial_indexing, tracts_in_all_but_0): + labels_tracts = dummy_spatial_indexing.crossing_labels_tracts query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing) query_evaluator.visit(ast.parse("A= not 0")) assert ((