Skip to content

Commit

Permalink
Removed exog grid from numba inputs.
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxBlesch committed Jun 3, 2024
1 parent debff28 commit 2683232
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 14 deletions.
10 changes: 1 addition & 9 deletions src/upper_envelope/fues_numba/fues_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def fues_numba(
endog_grid: np.ndarray,
policy: np.ndarray,
value: np.ndarray,
exog_grid: np.ndarray,
expected_value_zero_savings: np.ndarray | float,
value_function: Callable,
value_function_args: Tuple,
Expand Down Expand Up @@ -57,8 +56,6 @@ def fues_numba(
containing the current state- and choice-specific policy function.
value (np.ndarray): 1d array of shape (n_grid_wealth + 1,)
containing the current state- and choice-specific value function.
exog_grid (np.ndarray): 1d array of shape (n_grid_wealth,) of the
exogenous savings grid.
expected_value_zero_savings (float): The agent's expected value given that she
saves zero.
Expand Down Expand Up @@ -98,18 +95,15 @@ def fues_numba(
value_function=value_function,
value_function_args=value_function_args,
)
exog_grid = np.append(np.zeros(n_constrained_points_to_add), exog_grid)

endog_grid = np.append(0, endog_grid)
policy = np.append(0, policy)
value = np.append(expected_value_zero_savings, value)
exog_grid = np.append(0, exog_grid)

endog_grid_refined, value_refined, policy_refined = fues_numba_unconstrained(
endog_grid,
value,
policy,
exog_grid,
jump_thresh=jump_thresh,
n_points_to_scan=n_points_to_scan,
)
Expand Down Expand Up @@ -143,7 +137,6 @@ def fues_numba_unconstrained(
endog_grid: np.ndarray,
value: np.ndarray,
policy: np.ndarray,
exog_grid: np.ndarray,
jump_thresh=2,
n_points_to_scan=10,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
Expand Down Expand Up @@ -176,14 +169,13 @@ def fues_numba_unconstrained(

endog_grid = endog_grid[np.where(~np.isnan(value))[0]]
policy = policy[np.where(~np.isnan(value))]
exog_grid = exog_grid[np.where(~np.isnan(value))[0]]
value = value[np.where(~np.isnan(value))]

idx_sort = np.argsort(endog_grid, kind="mergesort")
value = np.take(value, idx_sort)
policy = np.take(policy, idx_sort)
exog_grid = np.take(exog_grid, idx_sort)
endog_grid = np.take(endog_grid, idx_sort)
exog_grid = endog_grid - policy

(
value_clean_with_nans,
Expand Down
1 change: 0 additions & 1 deletion tests/test_fues_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ def test_fast_upper_envelope_against_numba(setup_model):
endog_grid=policy_egm[0],
value=value_egm[1],
policy=policy_egm[1],
exog_grid=np.append(0, exog_savings_grid),
)

(
Expand Down
3 changes: 0 additions & 3 deletions tests/test_fues_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ def value_func(consumption, choice, params):
policy=policy_egm[1, 1:],
value=value_egm[1, 1:],
expected_value_zero_savings=value_egm[1, 0],
exog_grid=_exog_savings_grid,
value_function=value_func,
value_function_args=(state_choice_vec["choice"], params),
)
Expand Down Expand Up @@ -144,7 +143,6 @@ def test_fast_upper_envelope_against_org_fues(setup_model):
endog_grid=policy_egm[0],
value=value_egm[1],
policy=policy_egm[1],
exog_grid=np.append(0, exog_savings_grid),
)

endog_grid_org, policy_org, value_org = fast_upper_envelope_wrapper_org(
Expand Down Expand Up @@ -201,7 +199,6 @@ def value_func(consumption, choice, params):
endog_grid=policy_egm[0, 1:],
policy=policy_egm[1, 1:],
value=value_egm[1, 1:],
exog_grid=np.append(0, exog_savings_grid),
expected_value_zero_savings=value_egm[1, 0],
value_function=value_func,
value_function_args=(state_choice_vec["choice"], params),
Expand Down
14 changes: 13 additions & 1 deletion tests/utils/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ def interpolate_policy_and_value_on_wealth_grid(
value function values corresponding to the begin of period wealth.
"""
# Make sure that these are numpy arrays
wealth_beginning_of_period = np.asarray(wealth_beginning_of_period)
endog_wealth_grid = np.asarray(endog_wealth_grid)
policy_grid = np.asarray(policy_grid)
value_function_grid = np.asarray(value_function_grid)
ind_high, ind_low = get_index_high_and_low(
x=endog_wealth_grid, x_new=wealth_beginning_of_period
)
Expand Down Expand Up @@ -87,6 +92,11 @@ def interpolate_single_policy_and_value_on_wealth_grid(
value function values corresponding to the begin of period wealth.
"""
# Make sure that these are numpy arrays
wealth_beginning_of_period = np.asarray(wealth_beginning_of_period)
endog_wealth_grid = np.asarray(endog_wealth_grid)
policy_grid = np.asarray(policy_grid)
value_function_grid = np.asarray(value_function_grid)
ind_high, ind_low = get_index_high_and_low(
x=endog_wealth_grid, x_new=wealth_beginning_of_period
)
Expand Down Expand Up @@ -128,7 +138,9 @@ def linear_interpolation_formula(
return interpol_res


def get_index_high_and_low(x: np.ndarray, x_new: np.ndarray | float) -> Tuple[int, int]:
def get_index_high_and_low(
x: np.ndarray, x_new: np.ndarray | float
) -> Tuple[np.ndarray, np.ndarray]:
"""Get index of the highest value in x that is smaller than x_new.
Args:
Expand Down

0 comments on commit 2683232

Please sign in to comment.