diff --git a/src/fava/beans/account.py b/src/fava/beans/account.py index 5a6c67e7a..98e8b8b29 100644 --- a/src/fava/beans/account.py +++ b/src/fava/beans/account.py @@ -16,28 +16,46 @@ from fava.beans.abc import Directive -def parent(acc: str) -> str | None: +def parent(account: str) -> str | None: """Get the name of the parent of the given account.""" - parts = acc.rsplit(":", maxsplit=1) + parts = account.rsplit(":", maxsplit=1) return parts[0] if len(parts) == 2 else None -def root(acc: str) -> str: +def root(account: str) -> str: """Get root account of the given account.""" - parts = acc.split(":", maxsplit=1) + parts = account.split(":", maxsplit=1) return parts[0] -def child_account_tester(acc: str) -> Callable[[str], bool]: +def child_account_tester(account: str) -> Callable[[str], bool]: """Get a function to check if an account is a descendant of the account.""" - acc_as_parent = acc + ":" + account_as_parent = account + ":" - def is_child_account(a: str) -> bool: - return a == acc or a.startswith(acc_as_parent) + def is_child_account(other: str) -> bool: + return other == account or other.startswith(account_as_parent) return is_child_account +def account_tester( + account: str, *, with_children: bool +) -> Callable[[str], bool]: + """Get a function to check if an account is equal to the account. + + Arguments: + account: An account name to check. + with_children: Whether to include all child accounts. + """ + if with_children: + return child_account_tester(account) + + def is_account(other: str) -> bool: + return other == account + + return is_account + + def get_entry_accounts(entry: Directive) -> list[str]: """Accounts for an entry. diff --git a/src/fava/core/__init__.py b/src/fava/core/__init__.py index f12d8b66c..97f2e707e 100644 --- a/src/fava/core/__init__.py +++ b/src/fava/core/__init__.py @@ -6,6 +6,7 @@ from datetime import timedelta from functools import cached_property from functools import lru_cache +from itertools import takewhile from pathlib import Path from typing import TYPE_CHECKING @@ -16,7 +17,7 @@ from fava.beans.abc import Balance from fava.beans.abc import Price from fava.beans.abc import Transaction -from fava.beans.account import child_account_tester +from fava.beans.account import account_tester from fava.beans.account import get_entry_accounts from fava.beans.funcs import get_position from fava.beans.funcs import hash_entry @@ -459,12 +460,8 @@ def account_journal( Yields: Tuples of ``(entry, change, balance)``. """ - - def is_account(a: str) -> bool: - return a == account_name - - relevant_account = ( - child_account_tester(account_name) if with_children else is_account + relevant_account = account_tester( + account_name, with_children=with_children ) prices = self.prices @@ -539,9 +536,7 @@ def context( entry_accounts = get_entry_accounts(entry) balances = {account: Inventory() for account in entry_accounts} - for entry_ in self.all_entries: - if entry_ is entry: - break + for entry_ in takewhile(lambda e: e is not entry, self.all_entries): if isinstance(entry_, Transaction): for posting in entry_.postings: balance = balances.get(posting.account, None) diff --git a/src/fava/core/charts.py b/src/fava/core/charts.py index 88f2f8b90..d56eb6a35 100644 --- a/src/fava/core/charts.py +++ b/src/fava/core/charts.py @@ -23,7 +23,7 @@ from fava.beans.abc import Amount from fava.beans.abc import Position from fava.beans.abc import Transaction -from fava.beans.account import child_account_tester +from fava.beans.account import account_tester from fava.beans.flags import FLAG_UNREALIZED from fava.core.conversion import cost_or_value from fava.core.inventory import CounterInventory @@ -229,7 +229,7 @@ def linechart( def _balances() -> Iterable[tuple[date, CounterInventory]]: last_date = None running_balance = CounterInventory() - is_child_account = child_account_tester(account_name) + is_child_account = account_tester(account_name, with_children=True) for entry in filtered.entries: for posting in getattr(entry, "postings", []): diff --git a/tests/test_beans.py b/tests/test_beans.py index d8afc527a..ae446b836 100644 --- a/tests/test_beans.py +++ b/tests/test_beans.py @@ -10,6 +10,7 @@ from fava.beans import create from fava.beans.abc import Note from fava.beans.abc import Price +from fava.beans.account import account_tester from fava.beans.account import parent from fava.beans.account import root from fava.beans.funcs import get_position @@ -26,10 +27,27 @@ def test_account_parent() -> None: assert parent("Assets:Cash") == "Assets" assert parent("Assets:Cash:AA") == "Assets:Cash" assert parent("Assets:asdfasdf") == "Assets" + + +def test_account_root() -> None: assert root("Assets:asdfasdf:asdfasdf") == "Assets" assert root("Assets:asdfasdf") == "Assets" +def test_account_tester() -> None: + is_child = account_tester("Assets:Cash", with_children=True) + assert not is_child("Assets") + assert not is_child("Assets:CashOther") + assert is_child("Assets:Cash") + assert is_child("Assets:Cash:Test") + + is_equal = account_tester("Assets:Cash", with_children=False) + assert not is_equal("Assets") + assert not is_equal("Assets:CashOther") + assert is_equal("Assets:Cash") + assert not is_equal("Assets:Cash:Test") + + def test_hash_entry() -> None: date = datetime.date(2022, 4, 2) note = create.note( diff --git a/tests/test_core_charts.py b/tests/test_core_charts.py index a223f3cca..c148f3d08 100644 --- a/tests/test_core_charts.py +++ b/tests/test_core_charts.py @@ -56,6 +56,12 @@ def test_linechart_data( ) snapshot(data, json=True) + assert not example_ledger.charts.linechart( + filtered, + "Assets:Testing:MultipleCommodities:NotAnAccount", + "units", + ) + def test_net_worth(example_ledger: FavaLedger, snapshot: SnapshotFunc) -> None: filtered = example_ledger.get_filtered() diff --git a/tests/test_json_api.py b/tests/test_json_api.py index 8823b3991..4530fd7e1 100644 --- a/tests/test_json_api.py +++ b/tests/test_json_api.py @@ -263,6 +263,17 @@ def test_api_context( HTTPStatus.BAD_REQUEST, ) + balance_entry_hash = hash_entry( + example_ledger.all_entries_by_type.Balance[0] + ) + response = test_client.get( + "/long-example/api/context", + query_string={"entry_hash": balance_entry_hash}, + ) + data = assert_api_success(response) + assert data["balances_before"] + assert not data["balances_after"] + entry_hash = hash_entry( next( entry