Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Utility kwargs #5

Merged
merged 5 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 0 additions & 93 deletions src/upper_envelope/interpolation.py

This file was deleted.

24 changes: 2 additions & 22 deletions src/upper_envelope/shared.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,14 @@
import functools
import inspect
from functools import partial


def determine_function_arguments_and_partial_options(func, options):
def process_function_args_to_kwargs(func):
signature = set(inspect.signature(func).parameters)
(
partialed_func,
signature,
) = partial_options_and_addtional_arguments_and_update_signature(
func=func,
signature=signature,
options=options,
)

@functools.wraps(func)
def processed_func(**kwargs):
func_kwargs = {key: kwargs[key] for key in signature}

return partialed_func(**func_kwargs)
return func(**func_kwargs)

return processed_func


def partial_options_and_addtional_arguments_and_update_signature(
func, signature, options
):
"""Partial in options and update signature."""
if "options" in signature:
func = partial(func, options=options)
signature = signature - {"options"}

return func, signature
26 changes: 13 additions & 13 deletions src/upper_envelope/upper_envelope_jax.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Jax implementation of the extended Fast Upper-Envelope Scan (FUES).

The original FUES algorithm is based on Loretti I. Dobrescu and Akshay Shanker (2022)
'Fast Upper-Envelope Scan for Discrete-Continuous Dynamic Programming',
'Fast Upper-Envelope Scan for Solving Dynamic Optimization Problems',
https://dx.doi.org/10.2139/ssrn.4181302

"""
Expand Down Expand Up @@ -63,7 +63,7 @@ def fast_upper_envelope_wrapper(
expected_value_zero_savings (float): The agent's expected value given that she
saves zero.
choice (int): The current choice.
compute_value (callable): Function to compute the agent's value.
compute_value (callable): Function to compute the agent's utility.
params (dict): Dictionary containing the model parameters.

Returns:
Expand Down Expand Up @@ -128,17 +128,6 @@ def fast_upper_envelope_wrapper(
)


def _compute_value(
consumption, next_period_value, state_choice_vec, params, compute_utility
):
utility = compute_utility(
consumption=consumption,
params=params,
**state_choice_vec,
)
return utility + params["beta"] * next_period_value


def fast_upper_envelope(
endog_grid: jnp.ndarray,
value: jnp.ndarray,
Expand Down Expand Up @@ -1367,3 +1356,14 @@ def create_indicator_if_value_function_is_switched(
is_switched = gradient_exog_abs > jump_thresh

return is_switched


def _compute_value(
consumption, next_period_value, state_choice_vec, params, compute_utility
):
utility = compute_utility(
consumption=consumption,
params=params,
**state_choice_vec,
)
return utility + params["beta"] * next_period_value
39 changes: 24 additions & 15 deletions src/upper_envelope/upper_envelope_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

"""
from typing import Callable
from typing import Dict
from typing import Optional
from typing import Tuple

Expand All @@ -19,8 +20,9 @@ def fast_upper_envelope_wrapper(
value: np.ndarray,
exog_grid: np.ndarray,
expected_value_zero_savings: float,
choice: int,
compute_value: Callable,
state_choice_vec: np.ndarray,
params: Dict[str, float],
compute_utility: Callable,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Drop suboptimal points and refine the endogenous grid, policy, and value.

Expand Down Expand Up @@ -57,7 +59,7 @@ def fast_upper_envelope_wrapper(
expected_value_zero_savings (float): The agent's expected value given that she
saves zero.
choice (int): The current choice.
compute_value (callable): Function to compute the agent's value.
compute_utility (callable): Function to compute the agent's utility.

Returns:
tuple:
Expand All @@ -70,8 +72,12 @@ def fast_upper_envelope_wrapper(
containing refined state- and choice-specific value function.

"""
n_grid_wealth = len(exog_grid)
n_grid_wealth = len(endog_grid)
min_wealth_grid = np.min(endog_grid)
# exog_grid = np.append(
# 0, np.linspace(min_wealth_grid, endog_grid[-1], n_grid_wealth - 1)
# )

if endog_grid[0] > min_wealth_grid:
# Non-concave region coincides with credit constraint.
# This happens when there is a non-monotonicity in the endogenous wealth grid
Expand All @@ -83,11 +89,12 @@ def fast_upper_envelope_wrapper(
endog_grid=endog_grid,
value=value,
policy=policy,
choice=choice,
state_choice_vec=state_choice_vec,
expected_value_zero_savings=expected_value_zero_savings,
min_wealth_grid=min_wealth_grid,
n_grid_wealth=n_grid_wealth,
compute_value=compute_value,
params=params,
compute_utility=compute_utility,
)
exog_grid = np.append(np.zeros(n_grid_wealth // 10 - 1), exog_grid)

Expand Down Expand Up @@ -725,11 +732,12 @@ def _augment_grids(
endog_grid: np.ndarray,
value: np.ndarray,
policy: np.ndarray,
choice: int,
state_choice_vec: np.ndarray,
expected_value_zero_savings: np.ndarray,
min_wealth_grid: float,
n_grid_wealth: int,
compute_value: Callable,
compute_utility: Callable,
params: Dict[str, float],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Extends the endogenous wealth grid, value, and policy functions to the left.

Expand All @@ -753,9 +761,8 @@ def _augment_grids(
choice (int): The agent's choice.
expected_value_zero_savings (float): The agent's expected value given that she
saves zero.
min_wealth_grid (float): Minimal wealth level in the endogenous wealth grid.
n_grid_wealth (int): Number of grid points in the exogenous wealth grid.
compute_value (callable): Function to compute the agent's value.
compute_utility (callable): Function to compute the agent's utility.

Returns:
tuple:
Expand All @@ -772,12 +779,14 @@ def _augment_grids(
min_wealth_grid, endog_grid[0], n_grid_wealth // 10
)[:-1]

grid_augmented = np.append(grid_points_to_add, endog_grid)
values_to_add = compute_value(
grid_points_to_add,
expected_value_zero_savings,
choice,
utility = compute_utility(
consumption=grid_points_to_add,
params=params,
**state_choice_vec,
)
values_to_add = utility + params["beta"] * expected_value_zero_savings

grid_augmented = np.append(grid_points_to_add, endog_grid)
value_augmented = np.append(values_to_add, value)
policy_augmented = np.append(grid_points_to_add, policy)

Expand Down
8 changes: 0 additions & 8 deletions tests/resources/replication_tests/deaton/options.yaml

This file was deleted.

14 changes: 0 additions & 14 deletions tests/resources/replication_tests/deaton/params.csv

This file was deleted.

Binary file not shown.
Binary file not shown.

This file was deleted.

This file was deleted.

Binary file not shown.
Binary file not shown.

This file was deleted.

This file was deleted.

Binary file not shown.
Binary file not shown.
Loading
Loading