diff --git a/src/upper_envelope/interpolation.py b/src/upper_envelope/interpolation.py deleted file mode 100644 index c4765b5..0000000 --- a/src/upper_envelope/interpolation.py +++ /dev/null @@ -1,93 +0,0 @@ -import jax.numpy as jnp - - -def linear_interpolation_formula( - y_high: float | jnp.ndarray, - y_low: float | jnp.ndarray, - x_high: float | jnp.ndarray, - x_low: float | jnp.ndarray, - x_new: float | jnp.ndarray, -): - """Linear interpolation formula.""" - interpolate_dist = x_new - x_low - interpolate_slope = (y_high - y_low) / (x_high - x_low) - interpol_res = (interpolate_slope * interpolate_dist) + y_low - - return interpol_res - - -def interpolate_policy_and_value_on_wealth_grid( - wealth_beginning_of_period: jnp.ndarray, - endog_wealth_grid: jnp.ndarray, - policy_left_grid: jnp.ndarray, - policy_right_grid: jnp.ndarray, - value_grid: jnp.ndarray, -): - """Interpolate policy and value functions on the wealth grid. - - This function uses the left and right policy function. - For a more detailed description, see calc_intersection_and_extrapolate_policy - in fast_upper_envelope.py. - - Args: - wealth_beginning_of_period (jnp.ndarray): 1d array of shape (n,) containing the - begin of period wealth. - endog_wealth_grid (jnp.array): 1d array of shape (n,) containing the endogenous - wealth grid. - policy_left_grid (jnp.ndarray): 1d array of shape (n,) containing the - left policy function corresponding to the endogenous wealth grid. - policy_right_grid (jnp.ndarray): 1d array of shape (n,) containing the - left policy function corresponding to the endogenous wealth grid. - value_grid (jnp.ndarray): 1d array of shape (n,) containing the value function - values corresponding to the endogenous wealth grid. - - Returns: - tuple: - - - policy_new (jnp.ndarray): 1d array of shape (n,) containing the interpolated - policy function values corresponding to the begin of period wealth. - - value_new (jnp.ndarray): 1d array of shape (n,) containing the interpolated - value function values corresponding to the begin of period wealth. - - """ - ind_high, ind_low = get_index_high_and_low( - x=endog_wealth_grid, x_new=wealth_beginning_of_period - ) - - wealth_low = jnp.take(endog_wealth_grid, ind_low) - wealth_high = jnp.take(endog_wealth_grid, ind_high) - - policy_new = linear_interpolation_formula( - y_high=jnp.take(policy_left_grid, ind_high), - y_low=jnp.take(policy_right_grid, ind_low), - x_high=wealth_high, - x_low=wealth_low, - x_new=wealth_beginning_of_period, - ) - - value_new = linear_interpolation_formula( - y_high=jnp.take(value_grid, ind_high), - y_low=jnp.take(value_grid, ind_low), - x_high=wealth_high, - x_low=wealth_low, - x_new=wealth_beginning_of_period, - ) - - return policy_new, value_new - - -def get_index_high_and_low(x, x_new): - """Get index of the highest value in x that is smaller than x_new. - - Args: - x (np.ndarray): 1d array of shape (n,) containing the x-values. - x_new (float): The new x-value at which to evaluate the interpolation function. - - Returns: - int: Index of the value in the wealth grid which is higher than x_new. Or in - case of extrapolation last or first index of not nan element. - - """ - ind_high = jnp.searchsorted(x, x_new).clip(max=(x.shape[0] - 1), min=1) - ind_high -= jnp.isnan(x[ind_high]).astype(int) - return ind_high, ind_high - 1 diff --git a/src/upper_envelope/shared.py b/src/upper_envelope/shared.py index 9908387..8341685 100644 --- a/src/upper_envelope/shared.py +++ b/src/upper_envelope/shared.py @@ -1,34 +1,14 @@ import functools import inspect -from functools import partial -def determine_function_arguments_and_partial_options(func, options): +def process_function_args_to_kwargs(func): signature = set(inspect.signature(func).parameters) - ( - partialed_func, - signature, - ) = partial_options_and_addtional_arguments_and_update_signature( - func=func, - signature=signature, - options=options, - ) @functools.wraps(func) def processed_func(**kwargs): func_kwargs = {key: kwargs[key] for key in signature} - return partialed_func(**func_kwargs) + return func(**func_kwargs) return processed_func - - -def partial_options_and_addtional_arguments_and_update_signature( - func, signature, options -): - """Partial in options and update signature.""" - if "options" in signature: - func = partial(func, options=options) - signature = signature - {"options"} - - return func, signature diff --git a/src/upper_envelope/upper_envelope_jax.py b/src/upper_envelope/upper_envelope_jax.py index 9be268f..60bcb34 100644 --- a/src/upper_envelope/upper_envelope_jax.py +++ b/src/upper_envelope/upper_envelope_jax.py @@ -1,7 +1,7 @@ """Jax implementation of the extended Fast Upper-Envelope Scan (FUES). The original FUES algorithm is based on Loretti I. Dobrescu and Akshay Shanker (2022) -'Fast Upper-Envelope Scan for Discrete-Continuous Dynamic Programming', +'Fast Upper-Envelope Scan for Solving Dynamic Optimization Problems', https://dx.doi.org/10.2139/ssrn.4181302 """ @@ -63,7 +63,7 @@ def fast_upper_envelope_wrapper( expected_value_zero_savings (float): The agent's expected value given that she saves zero. choice (int): The current choice. - compute_value (callable): Function to compute the agent's value. + compute_value (callable): Function to compute the agent's utility. params (dict): Dictionary containing the model parameters. Returns: @@ -128,17 +128,6 @@ def fast_upper_envelope_wrapper( ) -def _compute_value( - consumption, next_period_value, state_choice_vec, params, compute_utility -): - utility = compute_utility( - consumption=consumption, - params=params, - **state_choice_vec, - ) - return utility + params["beta"] * next_period_value - - def fast_upper_envelope( endog_grid: jnp.ndarray, value: jnp.ndarray, @@ -1367,3 +1356,14 @@ def create_indicator_if_value_function_is_switched( is_switched = gradient_exog_abs > jump_thresh return is_switched + + +def _compute_value( + consumption, next_period_value, state_choice_vec, params, compute_utility +): + utility = compute_utility( + consumption=consumption, + params=params, + **state_choice_vec, + ) + return utility + params["beta"] * next_period_value diff --git a/src/upper_envelope/upper_envelope_numba.py b/src/upper_envelope/upper_envelope_numba.py index f74af1b..f269409 100644 --- a/src/upper_envelope/upper_envelope_numba.py +++ b/src/upper_envelope/upper_envelope_numba.py @@ -6,6 +6,7 @@ """ from typing import Callable +from typing import Dict from typing import Optional from typing import Tuple @@ -19,8 +20,9 @@ def fast_upper_envelope_wrapper( value: np.ndarray, exog_grid: np.ndarray, expected_value_zero_savings: float, - choice: int, - compute_value: Callable, + state_choice_vec: np.ndarray, + params: Dict[str, float], + compute_utility: Callable, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """Drop suboptimal points and refine the endogenous grid, policy, and value. @@ -57,7 +59,7 @@ def fast_upper_envelope_wrapper( expected_value_zero_savings (float): The agent's expected value given that she saves zero. choice (int): The current choice. - compute_value (callable): Function to compute the agent's value. + compute_utility (callable): Function to compute the agent's utility. Returns: tuple: @@ -70,8 +72,12 @@ def fast_upper_envelope_wrapper( containing refined state- and choice-specific value function. """ - n_grid_wealth = len(exog_grid) + n_grid_wealth = len(endog_grid) min_wealth_grid = np.min(endog_grid) + # exog_grid = np.append( + # 0, np.linspace(min_wealth_grid, endog_grid[-1], n_grid_wealth - 1) + # ) + if endog_grid[0] > min_wealth_grid: # Non-concave region coincides with credit constraint. # This happens when there is a non-monotonicity in the endogenous wealth grid @@ -83,11 +89,12 @@ def fast_upper_envelope_wrapper( endog_grid=endog_grid, value=value, policy=policy, - choice=choice, + state_choice_vec=state_choice_vec, expected_value_zero_savings=expected_value_zero_savings, min_wealth_grid=min_wealth_grid, n_grid_wealth=n_grid_wealth, - compute_value=compute_value, + params=params, + compute_utility=compute_utility, ) exog_grid = np.append(np.zeros(n_grid_wealth // 10 - 1), exog_grid) @@ -725,11 +732,12 @@ def _augment_grids( endog_grid: np.ndarray, value: np.ndarray, policy: np.ndarray, - choice: int, + state_choice_vec: np.ndarray, expected_value_zero_savings: np.ndarray, min_wealth_grid: float, n_grid_wealth: int, - compute_value: Callable, + compute_utility: Callable, + params: Dict[str, float], ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """Extends the endogenous wealth grid, value, and policy functions to the left. @@ -753,9 +761,8 @@ def _augment_grids( choice (int): The agent's choice. expected_value_zero_savings (float): The agent's expected value given that she saves zero. - min_wealth_grid (float): Minimal wealth level in the endogenous wealth grid. n_grid_wealth (int): Number of grid points in the exogenous wealth grid. - compute_value (callable): Function to compute the agent's value. + compute_utility (callable): Function to compute the agent's utility. Returns: tuple: @@ -772,12 +779,14 @@ def _augment_grids( min_wealth_grid, endog_grid[0], n_grid_wealth // 10 )[:-1] - grid_augmented = np.append(grid_points_to_add, endog_grid) - values_to_add = compute_value( - grid_points_to_add, - expected_value_zero_savings, - choice, + utility = compute_utility( + consumption=grid_points_to_add, + params=params, + **state_choice_vec, ) + values_to_add = utility + params["beta"] * expected_value_zero_savings + + grid_augmented = np.append(grid_points_to_add, endog_grid) value_augmented = np.append(values_to_add, value) policy_augmented = np.append(grid_points_to_add, policy) diff --git a/tests/resources/replication_tests/deaton/options.yaml b/tests/resources/replication_tests/deaton/options.yaml deleted file mode 100644 index 744dc57..0000000 --- a/tests/resources/replication_tests/deaton/options.yaml +++ /dev/null @@ -1,8 +0,0 @@ ---- -n_periods: 25 -min_age: 20 -n_discrete_choices: 1 -n_grid_points: 100 -max_wealth: 75 -quadrature_points_stochastic: 10 -n_simulations: 10 diff --git a/tests/resources/replication_tests/deaton/params.csv b/tests/resources/replication_tests/deaton/params.csv deleted file mode 100644 index 5400654..0000000 --- a/tests/resources/replication_tests/deaton/params.csv +++ /dev/null @@ -1,14 +0,0 @@ -category,name,value,comment -beta,beta,0.95,discount factor -delta,delta,0,disutility of work -utility_function,rho,1,CRRA coefficient -wage,constant,0.75,age-independent labor income -wage,exp,0.04,return to experience -wage,exp_squared,-0.0004,return to experience squared -shocks,sigma,0.25,shock on labor income sigma parameter/standard deviation -shocks,lambda,2.2204e-16,taste shock (scale) parameter -assets,interest_rate,0.05,interest rate on capital -assets,initial_wealth_low,0,lowest level of initial wealth (relevant for simulation) -assets,initial_wealth_high,30,highest level of initial wealth (relevant for simulation) -assets,max_wealth,75,maximum level of wealth -assets,consumption_floor,0.0,consumption floor/retirement safety net (only relevant in the dc-egm retirement model) diff --git a/tests/resources/replication_tests/deaton/policy.pkl b/tests/resources/replication_tests/deaton/policy.pkl deleted file mode 100644 index 200b5aa..0000000 Binary files a/tests/resources/replication_tests/deaton/policy.pkl and /dev/null differ diff --git a/tests/resources/replication_tests/deaton/value.pkl b/tests/resources/replication_tests/deaton/value.pkl deleted file mode 100644 index c43efe4..0000000 Binary files a/tests/resources/replication_tests/deaton/value.pkl and /dev/null differ diff --git a/tests/resources/replication_tests/retirement_no_taste_shocks/options.yaml b/tests/resources/replication_tests/retirement_no_taste_shocks/options.yaml deleted file mode 100644 index 12a3d60..0000000 --- a/tests/resources/replication_tests/retirement_no_taste_shocks/options.yaml +++ /dev/null @@ -1,8 +0,0 @@ ---- -n_periods: 25 -min_age: 20 -n_discrete_choices: 2 -n_grid_points: 500 -max_wealth: 50 -quadrature_points_stochastic: 5 -n_simulations: 10 diff --git a/tests/resources/replication_tests/retirement_no_taste_shocks/params.csv b/tests/resources/replication_tests/retirement_no_taste_shocks/params.csv deleted file mode 100644 index 1d90dab..0000000 --- a/tests/resources/replication_tests/retirement_no_taste_shocks/params.csv +++ /dev/null @@ -1,14 +0,0 @@ -category,name,value,comment -beta,beta,0.95,discount factor -delta,delta,0.35,disutility of work -utility_function,rho,1.95,CRRA coefficient -wage,constant,0.75,age-independent labor income -wage,exp,0.04,return to experience -wage,exp_squared,-0.0002,return to experience squared -shocks,sigma,0.00,shock on labor income sigma parameter/standard deviation -shocks,lambda,2.2204e-16,taste shock (scale) parameter -assets,interest_rate,0.05,interest rate on capital -assets,initial_wealth_low,0,lowest level of initial wealth (relevant for simulation) -assets,initial_wealth_high,30,highest level of initial wealth (relevant for simulation) -assets,max_wealth,50,maximum level of wealth -assets,consumption_floor,0.001,consumption floor/retirement safety net (only relevant in the dc-egm retirement model) diff --git a/tests/resources/replication_tests/retirement_no_taste_shocks/policy.pkl b/tests/resources/replication_tests/retirement_no_taste_shocks/policy.pkl deleted file mode 100644 index 42cc11b..0000000 Binary files a/tests/resources/replication_tests/retirement_no_taste_shocks/policy.pkl and /dev/null differ diff --git a/tests/resources/replication_tests/retirement_no_taste_shocks/value.pkl b/tests/resources/replication_tests/retirement_no_taste_shocks/value.pkl deleted file mode 100644 index bc167c9..0000000 Binary files a/tests/resources/replication_tests/retirement_no_taste_shocks/value.pkl and /dev/null differ diff --git a/tests/resources/replication_tests/retirement_taste_shocks/options.yaml b/tests/resources/replication_tests/retirement_taste_shocks/options.yaml deleted file mode 100644 index 12a3d60..0000000 --- a/tests/resources/replication_tests/retirement_taste_shocks/options.yaml +++ /dev/null @@ -1,8 +0,0 @@ ---- -n_periods: 25 -min_age: 20 -n_discrete_choices: 2 -n_grid_points: 500 -max_wealth: 50 -quadrature_points_stochastic: 5 -n_simulations: 10 diff --git a/tests/resources/replication_tests/retirement_taste_shocks/params.csv b/tests/resources/replication_tests/retirement_taste_shocks/params.csv deleted file mode 100644 index 897e885..0000000 --- a/tests/resources/replication_tests/retirement_taste_shocks/params.csv +++ /dev/null @@ -1,14 +0,0 @@ -category,name,value,comment -beta,beta,0.9523809523809523,discount factor -delta,delta,0.35,disutility of work -utility_function,rho,1.95,CRRA coefficient -wage,constant,0.75,age-independent labor income -wage,exp,0.04,return to experience -wage,exp_squared,-0.0002,return to experience squared -shocks,sigma,0.35,shock on labor income sigma parameter/standard deviation -shocks,lambda,0.2,taste shock (scale) parameter -assets,interest_rate,0.05,interest rate on capital -assets,initial_wealth_low,0,lowest level of initial wealth (relevant for simulation) -assets,initial_wealth_high,30,highest level of initial wealth (relevant for simulation) -assets,max_wealth,50,maximum level of wealth -assets,consumption_floor,0.001,consumption floor/retirement safety net (only relevant in the dc-egm retirement model) diff --git a/tests/resources/replication_tests/retirement_taste_shocks/policy.pkl b/tests/resources/replication_tests/retirement_taste_shocks/policy.pkl deleted file mode 100644 index acb1559..0000000 Binary files a/tests/resources/replication_tests/retirement_taste_shocks/policy.pkl and /dev/null differ diff --git a/tests/resources/replication_tests/retirement_taste_shocks/value.pkl b/tests/resources/replication_tests/retirement_taste_shocks/value.pkl deleted file mode 100644 index aa455a8..0000000 Binary files a/tests/resources/replication_tests/retirement_taste_shocks/value.pkl and /dev/null differ diff --git a/tests/test_upper_envelope_jax.py b/tests/test_upper_envelope_jax.py index 8fb3600..3a5d486 100644 --- a/tests/test_upper_envelope_jax.py +++ b/tests/test_upper_envelope_jax.py @@ -6,15 +6,15 @@ import numpy as np import pytest from numpy.testing import assert_array_almost_equal as aaae -from upper_envelope.interpolation import interpolate_policy_and_value_on_wealth_grid -from upper_envelope.shared import determine_function_arguments_and_partial_options +from upper_envelope.shared import process_function_args_to_kwargs from upper_envelope.upper_envelope_jax import fast_upper_envelope from upper_envelope.upper_envelope_jax import ( fast_upper_envelope_wrapper, ) from tests.utils.fast_upper_envelope_org import fast_upper_envelope_wrapper_org -from tests.utils.interpolations import linear_interpolation_with_extrapolation +from tests.utils.interpolation import interpolate_policy_and_value_on_wealth_grid +from tests.utils.interpolation import linear_interpolation_with_extrapolation from tests.utils.upper_envelope_fedor import upper_envelope # Obtain the test directory of the package. @@ -80,9 +80,8 @@ def setup_model(): state_choice_vars = {"lagged_choice": 0, "choice": 0} options["state_space"]["exogenous_states"] = {"exog_state": [0]} - compute_utility = determine_function_arguments_and_partial_options( - utility_crra, options=options - ) + + compute_utility = process_function_args_to_kwargs(utility_crra) return params, exog_savings_grid, state_choice_vars, compute_utility @@ -226,10 +225,10 @@ def test_fast_upper_envelope_against_fedor(period, setup_model): ] ( - endog_grid_calc, - policy_calc_left, - policy_calc_right, - value_calc, + endog_grid_fues, + policy_fues_left, + policy_fues_right, + value_fues, ) = fast_upper_envelope_wrapper( endog_grid=policy_egm[0, 1:], policy=policy_egm[1, 1:], @@ -239,26 +238,23 @@ def test_fast_upper_envelope_against_fedor(period, setup_model): params=params, compute_utility=compute_utility, ) - wealth_max_to_test = np.max(endog_grid_calc[~np.isnan(endog_grid_calc)]) + 100 - wealth_grid_to_test = np.linspace(endog_grid_calc[1], wealth_max_to_test, 1000) + + wealth_max_to_test = np.max(endog_grid_fues[~np.isnan(endog_grid_fues)]) + 100 + wealth_grid_to_test = np.linspace(endog_grid_fues[1], wealth_max_to_test, 1000) value_expec_interp = linear_interpolation_with_extrapolation( x_new=wealth_grid_to_test, x=value_expected[0], y=value_expected[1] ) - policy_expec_interp = linear_interpolation_with_extrapolation( x_new=wealth_grid_to_test, x=policy_expected[0], y=policy_expected[1] ) - ( - policy_calc_interp, - value_calc_interp, - ) = interpolate_policy_and_value_on_wealth_grid( + policy_interp, value_interp = interpolate_policy_and_value_on_wealth_grid( wealth_beginning_of_period=wealth_grid_to_test, - endog_wealth_grid=endog_grid_calc, - policy_left_grid=policy_calc_left, - policy_right_grid=policy_calc_right, - value_grid=value_calc, + endog_wealth_grid=endog_grid_fues, + policy_left_grid=policy_fues_left, + policy_right_grid=policy_fues_right, + value_grid=value_fues, ) - aaae(value_calc_interp, value_expec_interp) - aaae(policy_calc_interp, policy_expec_interp) + aaae(value_interp, value_expec_interp) + aaae(policy_interp, policy_expec_interp) diff --git a/tests/test_upper_envelope_numba.py b/tests/test_upper_envelope_numba.py index 183d7cf..4364247 100644 --- a/tests/test_upper_envelope_numba.py +++ b/tests/test_upper_envelope_numba.py @@ -1,15 +1,16 @@ """Test the numba implementation of the fast upper envelope scan.""" -from functools import partial from pathlib import Path -from typing import Callable import numpy as np import pytest from numpy.testing import assert_array_almost_equal as aaae +from upper_envelope.shared import process_function_args_to_kwargs from upper_envelope.upper_envelope_numba import fast_upper_envelope from upper_envelope.upper_envelope_numba import fast_upper_envelope_wrapper from tests.utils.fast_upper_envelope_org import fast_upper_envelope_wrapper_org +from tests.utils.interpolation import interpolate_single_policy_and_value_on_wealth_grid +from tests.utils.interpolation import linear_interpolation_with_extrapolation from tests.utils.upper_envelope_fedor import upper_envelope # Obtain the test directory of the package. @@ -19,37 +20,6 @@ TEST_RESOURCES_DIR = TEST_DIR / "resources" -def calc_current_value( - consumption: np.ndarray, - next_period_value: np.ndarray, - choice: int, - discount_factor: float, - compute_utility: Callable, -) -> np.ndarray: - """Compute the agent's current value. - - We only support the standard value function, where the current utility and - the discounted next period value have a sum format. - - Args: - consumption (np.ndarray): Level of the agent's consumption. - Array of shape (n_quad_stochastic * n_grid_wealth,). - next_period_value (np.ndarray): The value in the next period. - choice (int): The current discrete choice. - compute_utility (callable): User-defined function to compute the agent's - utility. The input ``params``` is already partialled in. - discount_factor (float): The discount factor. - - Returns: - np.ndarray: The current value. - - """ - utility = compute_utility(consumption, choice) - value = utility + discount_factor * next_period_value - - return value - - def utility_crra(consumption: np.array, choice: int, params: dict) -> np.array: """Computes the agent's current utility based on a CRRA utility function. @@ -87,16 +57,11 @@ def setup_model(): params["rho"] = 1.95 params["delta"] = 0.35 - state_choice_vec = {"choice": 0} + state_choice_vec = {"choice": 0, "lagged_choice": 0} - compute_utility = partial(utility_crra, params=params) - compute_value = partial( - calc_current_value, - discount_factor=params["beta"], - compute_utility=compute_utility, - ) + compute_utility = process_function_args_to_kwargs(utility_crra) - return params, state_choice_vec, exog_savings_grid, compute_utility, compute_value + return params, state_choice_vec, exog_savings_grid, compute_utility @pytest.mark.parametrize("period", [2, 4, 9, 10, 18]) @@ -125,33 +90,44 @@ def test_fast_upper_envelope_wrapper(period, setup_model): ~np.isnan(value_refined_fedor).any(axis=0), ] - ( - _params, - state_choice_vec, - exog_savings_grid, - _compute_utility, - compute_value, - ) = setup_model + params, state_choice_vec, _exog_savings_grid, compute_utility = setup_model endog_grid_refined, policy_refined, value_refined = fast_upper_envelope_wrapper( endog_grid=policy_egm[0, 1:], policy=policy_egm[1, 1:], value=value_egm[1, 1:], expected_value_zero_savings=value_egm[1, 0], - exog_grid=np.append(0, exog_savings_grid), - choice=state_choice_vec["choice"], - compute_value=compute_value, + exog_grid=_exog_savings_grid, + state_choice_vec=state_choice_vec, + params=params, + compute_utility=compute_utility, + ) + + wealth_max_to_test = np.max(endog_grid_refined[~np.isnan(endog_grid_refined)]) + 100 + wealth_grid_to_test = np.linspace( + endog_grid_refined[1], wealth_max_to_test, 1000, dtype=float + ) + + value_expec_interp = linear_interpolation_with_extrapolation( + x_new=wealth_grid_to_test, x=value_expected[0], y=value_expected[1] + ) + + policy_expec_interp = linear_interpolation_with_extrapolation( + x_new=wealth_grid_to_test, x=policy_expected[0], y=policy_expected[1] ) - endog_grid_got = endog_grid_refined[~np.isnan(endog_grid_refined)] - policy_got = policy_refined[~np.isnan(policy_refined)] - value_got = value_refined[~np.isnan(value_refined)] - aaae(endog_grid_got, policy_expected[0]) - aaae(policy_got, policy_expected[1]) - value_expected_interp = np.interp( - endog_grid_got, value_expected[0], value_expected[1] + ( + policy_calc_interp, + value_calc_interp, + ) = interpolate_single_policy_and_value_on_wealth_grid( + wealth_beginning_of_period=wealth_grid_to_test, + endog_wealth_grid=endog_grid_refined, + policy_grid=policy_refined, + value_grid=value_refined, ) - aaae(value_got, value_expected_interp) + + aaae(value_calc_interp, value_expec_interp) + aaae(policy_calc_interp, policy_expec_interp) def test_fast_upper_envelope_against_org_fues(setup_model): @@ -162,13 +138,7 @@ def test_fast_upper_envelope_against_org_fues(setup_model): TEST_RESOURCES_DIR / "upper_envelope_period_tests/val10.csv", delimiter="," ) - ( - _params, - state_choice_vec, - exog_savings_grid, - compute_utility, - _compute_value, - ) = setup_model + _params, state_choice_vec, exog_savings_grid, compute_utility = setup_model endog_grid_refined, value_refined, policy_refined = fast_upper_envelope( endog_grid=policy_egm[0], @@ -206,13 +176,7 @@ def test_fast_upper_envelope_against_fedor(period, setup_model): delimiter=",", ) - ( - params, - state_choice_vec, - exog_savings_grid, - compute_utility, - compute_value, - ) = setup_model + params, state_choice_vec, exog_savings_grid, compute_utility = setup_model _policy_fedor, _value_fedor = upper_envelope( policy=policy_egm, @@ -228,22 +192,32 @@ def test_fast_upper_envelope_against_fedor(period, setup_model): ~np.isnan(_value_fedor).any(axis=0), ] - _endog_grid_fues, _policy_fues, _value_fues = fast_upper_envelope_wrapper( + endog_grid_fues, policy_fues, value_fues = fast_upper_envelope_wrapper( endog_grid=policy_egm[0, 1:], policy=policy_egm[1, 1:], value=value_egm[1, 1:], - expected_value_zero_savings=value_egm[1, 0], exog_grid=np.append(0, exog_savings_grid), - choice=state_choice_vec["choice"], - compute_value=compute_value, + expected_value_zero_savings=value_egm[1, 0], + state_choice_vec=state_choice_vec, + params=params, + compute_utility=compute_utility, + ) + + wealth_max_to_test = np.max(endog_grid_fues[~np.isnan(endog_grid_fues)]) + 100 + wealth_grid_to_test = np.linspace(endog_grid_fues[1], wealth_max_to_test, 1000) + + value_expec_interp = linear_interpolation_with_extrapolation( + x_new=wealth_grid_to_test, x=value_expected[0], y=value_expected[1] + ) + policy_expec_interp = linear_interpolation_with_extrapolation( + x_new=wealth_grid_to_test, x=policy_expected[0], y=policy_expected[1] ) - endog_grid_got = _endog_grid_fues[~np.isnan(_endog_grid_fues)] - policy_got = _policy_fues[~np.isnan(_policy_fues)] - value_got = _value_fues[~np.isnan(_value_fues)] - aaae(endog_grid_got, policy_expected[0]) - aaae(policy_got, policy_expected[1]) - value_expected_interp = np.interp( - endog_grid_got, value_expected[0], value_expected[1] + policy_interp, value_interp = interpolate_single_policy_and_value_on_wealth_grid( + wealth_beginning_of_period=wealth_grid_to_test, + endog_wealth_grid=endog_grid_fues, + policy_grid=policy_fues, + value_grid=value_fues, ) - aaae(value_got, value_expected_interp) + aaae(value_interp, value_expec_interp) + aaae(policy_interp, policy_expec_interp) diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/utils/interpolation.py b/tests/utils/interpolation.py new file mode 100644 index 0000000..267ceaa --- /dev/null +++ b/tests/utils/interpolation.py @@ -0,0 +1,222 @@ +import jax.numpy as jnp +import numpy as np + + +def interpolate_policy_and_value_on_wealth_grid( + wealth_beginning_of_period: jnp.ndarray, + endog_wealth_grid: jnp.ndarray, + policy_left_grid: jnp.ndarray, + policy_right_grid: jnp.ndarray, + value_grid: jnp.ndarray, +): + """Interpolate policy and value functions on the wealth grid. + + This function uses the left and right policy function. + For a more detailed description, see calc_intersection_and_extrapolate_policy + in fast_upper_envelope.py. + + Args: + wealth_beginning_of_period (jnp.ndarray): 1d array of shape (n,) containing the + begin of period wealth. + endog_wealth_grid (jnp.array): 1d array of shape (n,) containing the endogenous + wealth grid. + policy_left_grid (jnp.ndarray): 1d array of shape (n,) containing the + left policy function corresponding to the endogenous wealth grid. + policy_right_grid (jnp.ndarray): 1d array of shape (n,) containing the + left policy function corresponding to the endogenous wealth grid. + value_grid (jnp.ndarray): 1d array of shape (n,) containing the value function + values corresponding to the endogenous wealth grid. + + Returns: + tuple: + + - policy_new (jnp.ndarray): 1d array of shape (n,) containing the interpolated + policy function values corresponding to the begin of period wealth. + - value_new (jnp.ndarray): 1d array of shape (n,) containing the interpolated + value function values corresponding to the begin of period wealth. + + """ + ind_high, ind_low = get_index_high_and_low( + x=endog_wealth_grid, x_new=wealth_beginning_of_period + ) + + wealth_low = jnp.take(endog_wealth_grid, ind_low) + wealth_high = jnp.take(endog_wealth_grid, ind_high) + + policy_new = linear_interpolation_formula( + y_high=jnp.take(policy_left_grid, ind_high), + y_low=jnp.take(policy_right_grid, ind_low), + x_high=wealth_high, + x_low=wealth_low, + x_new=wealth_beginning_of_period, + ) + + value_new = linear_interpolation_formula( + y_high=jnp.take(value_grid, ind_high), + y_low=jnp.take(value_grid, ind_low), + x_high=wealth_high, + x_low=wealth_low, + x_new=wealth_beginning_of_period, + ) + + return policy_new, value_new + + +def interpolate_single_policy_and_value_on_wealth_grid( + wealth_beginning_of_period: jnp.ndarray, + endog_wealth_grid: jnp.ndarray, + policy_grid: jnp.ndarray, + value_grid: jnp.ndarray, +): + """Interpolate policy and value functions on the wealth grid. + + This function uses the left and right policy function. + For a more detailed description, see calc_intersection_and_extrapolate_policy + in fast_upper_envelope.py. + + Args: + wealth_beginning_of_period (jnp.ndarray): 1d array of shape (n,) containing the + begin of period wealth. + endog_wealth_grid (jnp.array): 1d array of shape (n,) containing the endogenous + wealth grid. + policy_left_grid (jnp.ndarray): 1d array of shape (n,) containing the + left policy function corresponding to the endogenous wealth grid. + policy_right_grid (jnp.ndarray): 1d array of shape (n,) containing the + left policy function corresponding to the endogenous wealth grid. + value_grid (jnp.ndarray): 1d array of shape (n,) containing the value function + values corresponding to the endogenous wealth grid. + + Returns: + tuple: + + - policy_new (jnp.ndarray): 1d array of shape (n,) containing the interpolated + policy function values corresponding to the begin of period wealth. + - value_new (jnp.ndarray): 1d array of shape (n,) containing the interpolated + value function values corresponding to the begin of period wealth. + + """ + ind_high, ind_low = get_index_high_and_low( + x=endog_wealth_grid, x_new=wealth_beginning_of_period + ) + + wealth_low = jnp.take(endog_wealth_grid, ind_low) + wealth_high = jnp.take(endog_wealth_grid, ind_high) + + policy_new = linear_interpolation_formula( + y_high=jnp.take(policy_grid, ind_high), + y_low=jnp.take(policy_grid, ind_low), + x_high=wealth_high, + x_low=wealth_low, + x_new=wealth_beginning_of_period, + ) + + value_new = linear_interpolation_formula( + y_high=jnp.take(value_grid, ind_high), + y_low=jnp.take(value_grid, ind_low), + x_high=wealth_high, + x_low=wealth_low, + x_new=wealth_beginning_of_period, + ) + + return policy_new, value_new + + +def linear_interpolation_formula( + y_high: float | jnp.ndarray, + y_low: float | jnp.ndarray, + x_high: float | jnp.ndarray, + x_low: float | jnp.ndarray, + x_new: float | jnp.ndarray, +): + """Linear interpolation formula.""" + interpolate_dist = x_new - x_low + interpolate_slope = (y_high - y_low) / (x_high - x_low) + interpol_res = (interpolate_slope * interpolate_dist) + y_low + + return interpol_res + + +def get_index_high_and_low(x, x_new): + """Get index of the highest value in x that is smaller than x_new. + + Args: + x (np.ndarray): 1d array of shape (n,) containing the x-values. + x_new (float): The new x-value at which to evaluate the interpolation function. + + Returns: + int: Index of the value in the wealth grid which is higher than x_new. Or in + case of extrapolation last or first index of not nan element. + + """ + ind_high = jnp.searchsorted(x, x_new).clip(max=(x.shape[0] - 1), min=1) + ind_high -= jnp.isnan(x[ind_high]).astype(int) + return ind_high, ind_high - 1 + + +# ===================================================================================== +# Numpy utils +# ===================================================================================== + + +def linear_interpolation_with_extrapolation(x, y, x_new): + """Linear interpolation with extrapolation. + + Args: + x (np.ndarray): 1d array of shape (n,) containing the x-values. + y (np.ndarray): 1d array of shape (n,) containing the y-values + corresponding to the x-values. + x_new (np.ndarray or float): 1d array of shape (m,) or float containing + the new x-values at which to evaluate the interpolation function. + + Returns: + np.ndarray or float: 1d array of shape (m,) or float containing + the new y-values corresponding to the new x-values. + In case x_new contains values outside of the range of x, these + values are extrapolated. + + """ + # make sure that the function also works for unsorted x-arrays + # taken from scipy.interpolate.interp1d + ind = np.argsort(x, kind="mergesort") + x = x[ind] + y = np.take(y, ind) + + ind_high = np.searchsorted(x, x_new).clip(max=(x.shape[0] - 1), min=1) + ind_low = ind_high - 1 + + y_high = y[ind_high] + y_low = y[ind_low] + x_high = x[ind_high] + x_low = x[ind_low] + + interpolate_dist = x_new - x_low + interpolate_slope = (y_high - y_low) / (x_high - x_low) + interpol_res = (interpolate_slope * interpolate_dist) + y_low + + return interpol_res + + +def linear_interpolation_with_inserting_missing_values(x, y, x_new, missing_value): + """Linear interpolation with inserting missing values. + + Args: + x (np.ndarray): 1d array of shape (n,) containing the x-values. + y (np.ndarray): 1d array of shape (n,) containing the y-values + corresponding to the x-values. + x_new (np.ndarray or float): 1d array of shape (m,) or float containing + the new x-values at which to evaluate the interpolation function. + missing_value (np.ndarray or float): Flat array of shape (1,) or float + to set for values of x_new outside of the range of x. + + Returns: + np.ndarray or float: 1d array of shape (m,) or float containing the + new y-values corresponding to the new x-values. + In case x_new contains values outside of the range of x, these + values are set equal to missing_value. + + """ + interpol_res = linear_interpolation_with_extrapolation(x, y, x_new) + where_to_miss = (x_new < x.min()) | (x_new > x.max()) + interpol_res[where_to_miss] = missing_value + + return interpol_res diff --git a/tests/utils/interpolations.py b/tests/utils/interpolations.py deleted file mode 100644 index 9595819..0000000 --- a/tests/utils/interpolations.py +++ /dev/null @@ -1,64 +0,0 @@ -import numpy as np - - -def linear_interpolation_with_extrapolation(x, y, x_new): - """Linear interpolation with extrapolation. - - Args: - x (np.ndarray): 1d array of shape (n,) containing the x-values. - y (np.ndarray): 1d array of shape (n,) containing the y-values - corresponding to the x-values. - x_new (np.ndarray or float): 1d array of shape (m,) or float containing - the new x-values at which to evaluate the interpolation function. - - Returns: - np.ndarray or float: 1d array of shape (m,) or float containing - the new y-values corresponding to the new x-values. - In case x_new contains values outside of the range of x, these - values are extrapolated. - - """ - # make sure that the function also works for unsorted x-arrays - # taken from scipy.interpolate.interp1d - ind = np.argsort(x, kind="mergesort") - x = x[ind] - y = np.take(y, ind) - - ind_high = np.searchsorted(x, x_new).clip(max=(x.shape[0] - 1), min=1) - ind_low = ind_high - 1 - - y_high = y[ind_high] - y_low = y[ind_low] - x_high = x[ind_high] - x_low = x[ind_low] - - interpolate_dist = x_new - x_low - interpolate_slope = (y_high - y_low) / (x_high - x_low) - interpol_res = (interpolate_slope * interpolate_dist) + y_low - - return interpol_res - - -def linear_interpolation_with_inserting_missing_values(x, y, x_new, missing_value): - """Linear interpolation with inserting missing values. - - Args: - x (np.ndarray): 1d array of shape (n,) containing the x-values. - y (np.ndarray): 1d array of shape (n,) containing the y-values - corresponding to the x-values. - x_new (np.ndarray or float): 1d array of shape (m,) or float containing - the new x-values at which to evaluate the interpolation function. - missing_value (np.ndarray or float): Flat array of shape (1,) or float - to set for values of x_new outside of the range of x. - - Returns: - np.ndarray or float: 1d array of shape (m,) or float containing the - new y-values corresponding to the new x-values. - In case x_new contains values outside of the range of x, these - values are set equal to missing_value. - - """ - interpol_res = linear_interpolation_with_extrapolation(x, y, x_new) - where_to_miss = (x_new < x.min()) | (x_new > x.max()) - interpol_res[where_to_miss] = missing_value - return interpol_res diff --git a/tests/utils/upper_envelope_fedor.py b/tests/utils/upper_envelope_fedor.py index 9e5de59..0676125 100644 --- a/tests/utils/upper_envelope_fedor.py +++ b/tests/utils/upper_envelope_fedor.py @@ -12,12 +12,8 @@ import numpy as np from scipy.optimize import brenth as root -from tests.utils.interpolations import linear_interpolation_with_extrapolation -from tests.utils.interpolations import ( - linear_interpolation_with_inserting_missing_values, -) -eps = 2.2204e-16 +EPS = 2.2204e-16 def upper_envelope( @@ -247,7 +243,7 @@ def compute_upper_envelope( values_interp = np.empty((len(segments), len(endog_wealth_grid))) for i, segment in enumerate(segments): - values_interp[i, :] = linear_interpolation_with_inserting_missing_values( + values_interp[i, :] = _linear_interpolation_with_inserting_missing_values( x=segment[0], y=segment[1], x_new=endog_wealth_grid, @@ -280,7 +276,7 @@ def compute_upper_envelope( second_grid_point = endog_wealth_grid[i] values_first_segment = ( - linear_interpolation_with_inserting_missing_values( + _linear_interpolation_with_inserting_missing_values( x=segments[first_segment][0], y=segments[first_segment][1], x_new=np.array([first_grid_point, second_grid_point]), @@ -288,7 +284,7 @@ def compute_upper_envelope( ) ) values_second_segment = ( - linear_interpolation_with_inserting_missing_values( + _linear_interpolation_with_inserting_missing_values( x=segments[second_segment][0], y=segments[second_segment][1], x_new=np.array([first_grid_point, second_grid_point]), @@ -311,7 +307,7 @@ def compute_upper_envelope( ), ) value_intersect = ( - linear_interpolation_with_inserting_missing_values( + _linear_interpolation_with_inserting_missing_values( x=segments[first_segment][0], y=segments[first_segment][1], x_new=np.array([intersect_point]), @@ -323,7 +319,7 @@ def compute_upper_envelope( for segment in range(len(segments)): values_all_segments[ segment - ] = linear_interpolation_with_inserting_missing_values( + ] = _linear_interpolation_with_inserting_missing_values( x=segments[segment][0], y=segments[segment][1], x_new=np.array([intersect_point]), @@ -351,7 +347,7 @@ def compute_upper_envelope( # Add point if it lies currently on the highest segment if ( - any(abs(segments[index_second_segment][0] - endog_wealth_grid[i]) < eps) + any(abs(segments[index_second_segment][0] - endog_wealth_grid[i]) < EPS) is True ): grid_points_upper_env.append(endog_wealth_grid[i]) @@ -461,7 +457,7 @@ def refine_policy( ) # Find (scalar) point interpolated from the left - interp_from_the_left = linear_interpolation_with_extrapolation( + interp_from_the_left = _linear_interpolation_with_extrapolation( x=policy[0, :][last_point_to_the_left : last_point_to_the_left + 2], y=policy[1, :][last_point_to_the_left : last_point_to_the_left + 2], x_new=points_to_add[0][new_grid_point], @@ -474,7 +470,7 @@ def refine_policy( ) # Find (scalar) point interpolated from the right - interp_from_the_right = linear_interpolation_with_extrapolation( + interp_from_the_right = _linear_interpolation_with_extrapolation( x=policy[0, :][first_point_to_the_right - 1 : first_point_to_the_right + 1], y=policy[1, :][first_point_to_the_right - 1 : first_point_to_the_right + 1], x_new=points_to_add[0, new_grid_point], @@ -641,13 +637,77 @@ def _partition_grid( def _subtract_values(grid_point: float, first_segment, second_segment): """Subtracts the interpolated values of the two uppermost segments.""" - values_first_segment = linear_interpolation_with_extrapolation( + values_first_segment = _linear_interpolation_with_extrapolation( x=first_segment[0], y=first_segment[1], x_new=grid_point ) - values_second_segment = linear_interpolation_with_extrapolation( + values_second_segment = _linear_interpolation_with_extrapolation( x=second_segment[0], y=second_segment[1], x_new=grid_point ) diff_values_segments = values_first_segment - values_second_segment return diff_values_segments + + +def _linear_interpolation_with_extrapolation(x, y, x_new): + """Linear interpolation with extrapolation. + + Args: + x (np.ndarray): 1d array of shape (n,) containing the x-values. + y (np.ndarray): 1d array of shape (n,) containing the y-values + corresponding to the x-values. + x_new (np.ndarray or float): 1d array of shape (m,) or float containing + the new x-values at which to evaluate the interpolation function. + + Returns: + np.ndarray or float: 1d array of shape (m,) or float containing + the new y-values corresponding to the new x-values. + In case x_new contains values outside of the range of x, these + values are extrapolated. + + """ + # make sure that the function also works for unsorted x-arrays + # taken from scipy.interpolate.interp1d + ind = np.argsort(x, kind="mergesort") + x = x[ind] + y = np.take(y, ind) + + ind_high = np.searchsorted(x, x_new).clip(max=(x.shape[0] - 1), min=1) + ind_low = ind_high - 1 + + y_high = y[ind_high] + y_low = y[ind_low] + x_high = x[ind_high] + x_low = x[ind_low] + + interpolate_dist = x_new - x_low + interpolate_slope = (y_high - y_low) / (x_high - x_low) + interpol_res = (interpolate_slope * interpolate_dist) + y_low + + return interpol_res + + +def _linear_interpolation_with_inserting_missing_values(x, y, x_new, missing_value): + """Linear interpolation with inserting missing values. + + Args: + x (np.ndarray): 1d array of shape (n,) containing the x-values. + y (np.ndarray): 1d array of shape (n,) containing the y-values + corresponding to the x-values. + x_new (np.ndarray or float): 1d array of shape (m,) or float containing + the new x-values at which to evaluate the interpolation function. + missing_value (np.ndarray or float): Flat array of shape (1,) or float + to set for values of x_new outside of the range of x. + + Returns: + np.ndarray or float: 1d array of shape (m,) or float containing the + new y-values corresponding to the new x-values. + In case x_new contains values outside of the range of x, these + values are set equal to missing_value. + + """ + interpol_res = _linear_interpolation_with_extrapolation(x, y, x_new) + where_to_miss = (x_new < x.min()) | (x_new > x.max()) + interpol_res[where_to_miss] = missing_value + + return interpol_res