Skip to content

Commit

Permalink
increase test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
yagebu committed Oct 29, 2024
1 parent 3a0900c commit 442ff24
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 20 deletions.
34 changes: 26 additions & 8 deletions src/fava/beans/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,46 @@
from fava.beans.abc import Directive


def parent(acc: str) -> str | None:
def parent(account: str) -> str | None:
"""Get the name of the parent of the given account."""
parts = acc.rsplit(":", maxsplit=1)
parts = account.rsplit(":", maxsplit=1)
return parts[0] if len(parts) == 2 else None


def root(acc: str) -> str:
def root(account: str) -> str:
"""Get root account of the given account."""
parts = acc.split(":", maxsplit=1)
parts = account.split(":", maxsplit=1)
return parts[0]


def child_account_tester(acc: str) -> Callable[[str], bool]:
def child_account_tester(account: str) -> Callable[[str], bool]:
"""Get a function to check if an account is a descendant of the account."""
acc_as_parent = acc + ":"
account_as_parent = account + ":"

def is_child_account(a: str) -> bool:
return a == acc or a.startswith(acc_as_parent)
def is_child_account(other: str) -> bool:
return other == account or other.startswith(account_as_parent)

return is_child_account


def account_tester(
account: str, *, with_children: bool
) -> Callable[[str], bool]:
"""Get a function to check if an account is equal to the account.
Arguments:
account: An account name to check.
with_children: Whether to include all child accounts.
"""
if with_children:
return child_account_tester(account)

def is_account(other: str) -> bool:
return other == account

return is_account


def get_entry_accounts(entry: Directive) -> list[str]:
"""Accounts for an entry.
Expand Down
15 changes: 5 additions & 10 deletions src/fava/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from datetime import timedelta
from functools import cached_property
from functools import lru_cache
from itertools import takewhile
from pathlib import Path
from typing import TYPE_CHECKING

Expand All @@ -16,7 +17,7 @@
from fava.beans.abc import Balance
from fava.beans.abc import Price
from fava.beans.abc import Transaction
from fava.beans.account import child_account_tester
from fava.beans.account import account_tester
from fava.beans.account import get_entry_accounts
from fava.beans.funcs import get_position
from fava.beans.funcs import hash_entry
Expand Down Expand Up @@ -459,12 +460,8 @@ def account_journal(
Yields:
Tuples of ``(entry, change, balance)``.
"""

def is_account(a: str) -> bool:
return a == account_name

relevant_account = (
child_account_tester(account_name) if with_children else is_account
relevant_account = account_tester(
account_name, with_children=with_children
)

prices = self.prices
Expand Down Expand Up @@ -539,9 +536,7 @@ def context(

entry_accounts = get_entry_accounts(entry)
balances = {account: Inventory() for account in entry_accounts}
for entry_ in self.all_entries:
if entry_ is entry:
break
for entry_ in takewhile(lambda e: e is not entry, self.all_entries):
if isinstance(entry_, Transaction):
for posting in entry_.postings:
balance = balances.get(posting.account, None)
Expand Down
4 changes: 2 additions & 2 deletions src/fava/core/charts.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from fava.beans.abc import Amount
from fava.beans.abc import Position
from fava.beans.abc import Transaction
from fava.beans.account import child_account_tester
from fava.beans.account import account_tester
from fava.beans.flags import FLAG_UNREALIZED
from fava.core.conversion import cost_or_value
from fava.core.inventory import CounterInventory
Expand Down Expand Up @@ -229,7 +229,7 @@ def linechart(
def _balances() -> Iterable[tuple[date, CounterInventory]]:
last_date = None
running_balance = CounterInventory()
is_child_account = child_account_tester(account_name)
is_child_account = account_tester(account_name, with_children=True)

for entry in filtered.entries:
for posting in getattr(entry, "postings", []):
Expand Down
18 changes: 18 additions & 0 deletions tests/test_beans.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from fava.beans import create
from fava.beans.abc import Note
from fava.beans.abc import Price
from fava.beans.account import account_tester
from fava.beans.account import parent
from fava.beans.account import root
from fava.beans.funcs import get_position
Expand All @@ -26,10 +27,27 @@ def test_account_parent() -> None:
assert parent("Assets:Cash") == "Assets"
assert parent("Assets:Cash:AA") == "Assets:Cash"
assert parent("Assets:asdfasdf") == "Assets"


def test_account_root() -> None:
assert root("Assets:asdfasdf:asdfasdf") == "Assets"
assert root("Assets:asdfasdf") == "Assets"


def test_account_tester() -> None:
is_child = account_tester("Assets:Cash", with_children=True)
assert not is_child("Assets")
assert not is_child("Assets:CashOther")
assert is_child("Assets:Cash")
assert is_child("Assets:Cash:Test")

is_equal = account_tester("Assets:Cash", with_children=False)
assert not is_equal("Assets")
assert not is_equal("Assets:CashOther")
assert is_equal("Assets:Cash")
assert not is_equal("Assets:Cash:Test")


def test_hash_entry() -> None:
date = datetime.date(2022, 4, 2)
note = create.note(
Expand Down
6 changes: 6 additions & 0 deletions tests/test_core_charts.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ def test_linechart_data(
)
snapshot(data, json=True)

assert not example_ledger.charts.linechart(
filtered,
"Assets:Testing:MultipleCommodities:NotAnAccount",
"units",
)


def test_net_worth(example_ledger: FavaLedger, snapshot: SnapshotFunc) -> None:
filtered = example_ledger.get_filtered()
Expand Down
11 changes: 11 additions & 0 deletions tests/test_json_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,17 @@ def test_api_context(
HTTPStatus.BAD_REQUEST,
)

balance_entry_hash = hash_entry(
example_ledger.all_entries_by_type.Balance[0]
)
response = test_client.get(
"/long-example/api/context",
query_string={"entry_hash": balance_entry_hash},
)
data = assert_api_success(response)
assert data["balances_before"]
assert not data["balances_after"]

entry_hash = hash_entry(
next(
entry
Expand Down

0 comments on commit 442ff24

Please sign in to comment.