Skip to content

Commit

Permalink
Changed test_query_eval to use pytest fixtures
Browse files Browse the repository at this point in the history
  • Loading branch information
demianw committed Dec 11, 2024
1 parent 674d9d5 commit eec3e80
Showing 1 changed file with 66 additions and 35 deletions.
101 changes: 66 additions & 35 deletions tract_querier/tests/test_query_eval.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,12 @@
from .. import query_processor

import pytest

import ast
import numpy as np


# Ten tracts traversing random labels
another_set = True
while (another_set):
rng = np.random.default_rng()
tracts_labels = dict([(i, set(rng.integers(100, size=2))) for i in range(100)])
labels_tracts = query_processor.labels_for_tracts(tracts_labels)
another_set = 0 not in labels_tracts.keys() or 1 not in labels_tracts.keys()


tracts_in_0 = set().union(*[labels_tracts[label] for label in labels_tracts if label == 0])
tracts_in_all_but_0 = set().union(*[labels_tracts[label] for label in labels_tracts if label != 0])
tract_in_label_0_uniquely = labels_tracts[0].difference(tracts_in_all_but_0)


class DummySpatialIndexing:

def __init__(
self,
crossing_tracts_labels, crossing_labels_tracts,
Expand All @@ -33,11 +20,44 @@ def __init__(
self.label_bounding_boxes = label_bounding_boxes
self.tract_bounding_boxes = tract_bounding_boxes

dummy_spatial_indexing = DummySpatialIndexing(tracts_labels, labels_tracts, ({}, {}), ({}, {}), {}, {})
empty_spatial_indexing = DummySpatialIndexing({}, {}, ({}, {}), ({}, {}), {}, {})
@pytest.fixture
def dummy_spatial_indexing():
# Ten tracts traversing random labels
another_set = True
while (another_set):
rng = np.random.default_rng()
tracts_labels = dict([(i, set(rng.integers(100, size=2))) for i in range(100)])
labels_tracts = query_processor.labels_for_tracts(tracts_labels)
another_set = 0 not in labels_tracts.keys() or 1 not in labels_tracts.keys()

return DummySpatialIndexing(tracts_labels, labels_tracts, ({}, {}), ({}, {}), {}, {})


@pytest.fixture
def tracts_in_0(dummy_spatial_indexing):
labels_tracts = dummy_spatial_indexing.crossing_labels_tracts
return set().union(*[labels_tracts[label] for label in labels_tracts if label == 0])


@pytest.fixture
def tracts_in_all_but_0(dummy_spatial_indexing: DummySpatialIndexing):
labels_tracts = dummy_spatial_indexing.crossing_labels_tracts
return set().union(*[labels_tracts[label] for label in labels_tracts if label != 0])


@pytest.fixture
def tract_in_label_0_uniquely(dummy_spatial_indexing: DummySpatialIndexing, tracts_in_all_but_0):
labels_tracts = dummy_spatial_indexing.crossing_labels_tracts
return labels_tracts[0].difference(tracts_in_all_but_0)


@pytest.fixture
def empty_spatial_indexing():
return DummySpatialIndexing({}, {}, ({}, {}), ({}, {}), {}, {})


def test_assign():
def test_assign(dummy_spatial_indexing):
labels_tracts = dummy_spatial_indexing.crossing_labels_tracts
query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing)
query_evaluator.visit(ast.parse("A=0"))
assert ((
Expand All @@ -47,7 +67,8 @@ def test_assign():
))


def test_assign_attr():
def test_assign_attr(dummy_spatial_indexing):
labels_tracts = dummy_spatial_indexing.crossing_labels_tracts
query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing)
query_evaluator.visit(ast.parse("a.left=0"))
assert ((
Expand All @@ -57,7 +78,7 @@ def test_assign_attr():
))


def test_assign_side():
def test_assign_side(empty_spatial_indexing):
query_evaluator = query_processor.EvaluateQueries(empty_spatial_indexing)

queries_labels = {
Expand Down Expand Up @@ -92,7 +113,7 @@ def test_assign_side():
assert {k: v.tracts for k, v in query_evaluator.evaluated_queries_info.items()} == queries_tracts


def test_assign_str():
def test_assign_str(empty_spatial_indexing):
query_evaluator = query_processor.EvaluateQueries(empty_spatial_indexing)

queries_labels = {
Expand Down Expand Up @@ -125,7 +146,7 @@ def test_assign_str():
assert {k: v.tracts for k, v in query_evaluator.evaluated_queries_info.items()} == queries_tracts


def test_for_list():
def test_for_list(empty_spatial_indexing):
query_evaluator = query_processor.EvaluateQueries(empty_spatial_indexing)

queries_tracts = {
Expand Down Expand Up @@ -155,7 +176,7 @@ def test_for_list():
assert {k: v.tracts for k, v in query_evaluator.evaluated_queries_info.items()} == queries_tracts


def test_for_str():
def test_for_str(empty_spatial_indexing):
query_evaluator = query_processor.EvaluateQueries(empty_spatial_indexing)

queries_tracts = {
Expand Down Expand Up @@ -185,7 +206,8 @@ def test_for_str():
assert {k: v.tracts for k, v in query_evaluator.evaluated_queries_info.items()} == queries_tracts


def test_add():
def test_add(dummy_spatial_indexing: DummySpatialIndexing):
labels_tracts = dummy_spatial_indexing.crossing_labels_tracts
query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing)
query_evaluator.visit(ast.parse("A=0+1"))
assert ((
Expand All @@ -195,7 +217,8 @@ def test_add():
))


def test_mult():
def test_mult(dummy_spatial_indexing: DummySpatialIndexing):
labels_tracts = dummy_spatial_indexing.crossing_labels_tracts
query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing)
query_evaluator.visit(ast.parse("A=0 * 1"))
assert ((
Expand All @@ -205,7 +228,8 @@ def test_mult():
))


def test_sub():
def test_sub(dummy_spatial_indexing):
labels_tracts = dummy_spatial_indexing.crossing_labels_tracts
query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing)
query_evaluator.visit(ast.parse("A=(0 + 1) - 1"))
assert ((
Expand All @@ -215,7 +239,8 @@ def test_sub():
))


def test_or():
def test_or(dummy_spatial_indexing):
labels_tracts = dummy_spatial_indexing.crossing_labels_tracts
query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing)
query_evaluator.visit(ast.parse("A=0 or 1"))
assert ((
Expand All @@ -225,7 +250,8 @@ def test_or():
))


def test_and():
def test_and(dummy_spatial_indexing):
labels_tracts = dummy_spatial_indexing.crossing_labels_tracts
query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing)
query_evaluator.visit(ast.parse("A=0 and 1"))
assert ((
Expand All @@ -235,7 +261,8 @@ def test_and():
))


def test_not_in():
def test_not_in(dummy_spatial_indexing):
labels_tracts = dummy_spatial_indexing.crossing_labels_tracts
query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing)
query_evaluator.visit(ast.parse("A=0 or 1 not in 1"))
assert ((
Expand All @@ -245,7 +272,7 @@ def test_not_in():
))


def test_only_sign():
def test_only_sign(dummy_spatial_indexing, tract_in_label_0_uniquely):
query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing)
query_evaluator.visit(ast.parse("A=~0"))
assert ((
Expand All @@ -255,7 +282,7 @@ def test_only_sign():
))


def test_only():
def test_only(dummy_spatial_indexing, tract_in_label_0_uniquely):
query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing)
query_evaluator.visit(ast.parse("A=only(0)"))
assert ((
Expand All @@ -265,7 +292,8 @@ def test_only():
))


def test_unsaved_query():
def test_unsaved_query(dummy_spatial_indexing):
labels_tracts = dummy_spatial_indexing.crossing_labels_tracts
query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing)
query_evaluator.visit(ast.parse("A|=0"))
assert ((
Expand All @@ -275,7 +303,8 @@ def test_unsaved_query():
))


def test_symbolic_assignment():
def test_symbolic_assignment(dummy_spatial_indexing):
labels_tracts = dummy_spatial_indexing.crossing_labels_tracts
query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing)
query_evaluator.visit(ast.parse("A=0; B=A"))
assert ((
Expand All @@ -285,7 +314,8 @@ def test_symbolic_assignment():
))


def test_unarySub():
def test_unarySub(dummy_spatial_indexing, tracts_in_all_but_0):
labels_tracts = dummy_spatial_indexing.crossing_labels_tracts
query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing)
query_evaluator.visit(ast.parse("B=0; A=-B"))
assert ((
Expand All @@ -295,7 +325,8 @@ def test_unarySub():
))


def test_not():
def test_not(dummy_spatial_indexing, tracts_in_all_but_0):
labels_tracts = dummy_spatial_indexing.crossing_labels_tracts
query_evaluator = query_processor.EvaluateQueries(dummy_spatial_indexing)
query_evaluator.visit(ast.parse("A= not 0"))
assert ((
Expand Down

0 comments on commit eec3e80

Please sign in to comment.