diff --git a/.envs/testenv.yml b/.envs/testenv.yml index 2317018..9e36df6 100644 --- a/.envs/testenv.yml +++ b/.envs/testenv.yml @@ -1,10 +1,8 @@ --- name: upper-envelope - channels: - conda-forge - nodefaults - dependencies: - pip - setuptools_scm @@ -24,4 +22,4 @@ dependencies: # Install locally - pip: - - -e ../ \ No newline at end of file + - -e ../ diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 6cdd13f..acfc3f6 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -45,4 +45,4 @@ jobs: if: runner.os == 'Linux' && matrix.python-version == '3.10' uses: codecov/codecov-action@v3 with: - token: ${{ secrets.CODECOV_TOKEN }} \ No newline at end of file + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.yamllint.yml b/.yamllint.yml index 83dc25a..72f64be 100644 --- a/.yamllint.yml +++ b/.yamllint.yml @@ -33,4 +33,4 @@ rules: quoted-strings: disable trailing-spaces: enable truthy: - level: warning \ No newline at end of file + level: warning diff --git a/README.md b/README.md index 672b7ed..56469da 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,9 @@ [![Codecov](https://codecov.io/gh/OpenSourceEconomics/upper-envelope/branch/main/graph/badge.svg)](https://app.codecov.io/gh/OpenSourceEconomics/upper-envelope) [![Black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) -Extension of the Fast Upper-Envelope Scan (FUES) for solving discrete-continuous dynamic programming problems based on Dobrescu & Shanker (2022). Both -`jax` and `numba` versions are available. +Extension of the Fast Upper-Envelope Scan (FUES) for solving discrete-continuous dynamic +programming problems based on Dobrescu & Shanker (2022). Both `jax` and `numba` versions +are available. ## References @@ -13,6 +14,5 @@ Extension of the Fast Upper-Envelope Scan (FUES) for solving discrete-continuous [The Endogenous Grid Method for Discrete-Continuous Dynamic Choice Models with (or without) Taste Shocks](http://onlinelibrary.wiley.com/doi/10.3982/QE643/full). *Quantitative Economics* - 1. Loretti I. Dobrescu & Akshay Shanker (2022). - [Fast Upper-Envelope Scan for Discrete-Continuous Dynamic Programming](https://dx.doi.org/10.2139/ssrn.4181302). \ No newline at end of file + [Fast Upper-Envelope Scan for Discrete-Continuous Dynamic Programming](https://dx.doi.org/10.2139/ssrn.4181302). diff --git a/codecov.yml b/codecov.yml index db82677..78d43ec 100644 --- a/codecov.yml +++ b/codecov.yml @@ -9,7 +9,7 @@ coverage: status: patch: default: - target: 70% + target: 80% project: default: target: 90% diff --git a/docs/tutorials/getting_started.ipynb b/docs/tutorials/getting_started.ipynb new file mode 100644 index 0000000..f343212 --- /dev/null +++ b/docs/tutorials/getting_started.ipynb @@ -0,0 +1,192 @@ +{ + "cells": [ + { + "cell_type": "code", + "outputs": [], + "source": [ + "import numpy as np\n", + "from upper_envelope.fues_numba.fues_numba import fast_upper_envelope_wrapper\n", + "import numba as nb\n", + "from collections import namedtuple\n" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-24T06:42:04.589484405Z", + "start_time": "2024-05-24T06:42:04.167227740Z" + } + }, + "id": "c71bb4b54fd1da58", + "execution_count": 1 + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "@nb.njit\n", + "def utility_crra(consumption: np.array, params_ntuple) -> np.array:\n", + " \"\"\"Computes the agent's current utility based on a CRRA utility function.\n", + "\n", + " Args:\n", + " consumption (jnp.array): Level of the agent's consumption.\n", + " Array of shape (i) (n_quad_stochastic * n_grid_wealth,)\n", + " when called by :func:`~dcgm.call_egm_step.map_exog_to_endog_grid`\n", + " and :func:`~dcgm.call_egm_step.get_next_period_value`, or\n", + " (ii) of shape (n_grid_wealth,) when called by\n", + " :func:`~dcgm.call_egm_step.get_current_period_value`.\n", + " choice (int): Choice of the agent, e.g. 0 = \"retirement\", 1 = \"working\".\n", + " params_dict (dict): Dictionary containing model parameters.\n", + " Relevant here is the CRRA coefficient theta.\n", + "\n", + " Returns:\n", + " utility (jnp.array): Agent's utility . Array of shape\n", + " (n_quad_stochastic * n_grid_wealth,) or (n_grid_wealth,).\n", + "\n", + " \"\"\"\n", + " utility_consumption = (consumption ** (1 - params_ntuple.rho) - 1) / (\n", + " 1 - params_ntuple.rho\n", + " )\n", + "\n", + " utility = utility_consumption - (1 - params_ntuple.choice) * params_ntuple.delta\n", + "\n", + " return utility" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-24T06:42:04.611236211Z", + "start_time": "2024-05-24T06:42:04.609614665Z" + } + }, + "id": "ab4b0dcb970dac41", + "execution_count": 2 + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "initial_id", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2024-05-24T06:42:04.613370376Z", + "start_time": "2024-05-24T06:42:04.609954700Z" + } + }, + "outputs": [], + "source": [ + "max_wealth = 50\n", + "n_grid_wealth = 500\n", + "exog_savings_grid = np.linspace(0, max_wealth, n_grid_wealth)\n", + "\n", + "beta = 0.95 # discount_factor\n", + "\n", + "utility_kwargs = {\n", + " \"choice\": 0,\n", + " \"rho\": 1.95,\n", + " \"delta\": 0.35,\n", + "}" + ] + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "resource_dir = \"../../tests/resources/\"\n", + "value_egm = np.genfromtxt(\n", + " resource_dir + \"upper_envelope_period_tests/val2.csv\",\n", + " delimiter=\",\",\n", + ")\n", + "policy_egm = np.genfromtxt(\n", + " resource_dir + \"upper_envelope_period_tests/pol2.csv\",\n", + " delimiter=\",\",\n", + ")\n", + "n_constrained_points_to_add = int(0.1 * len(policy_egm[0]))\n", + "n_final_wealth_grid = int(1.2 * (len(policy_egm[0])))\n", + "tuning_params = {\n", + " \"n_final_wealth_grid\": n_final_wealth_grid,\n", + " \"jump_thresh\": 2,\n", + " \"n_constrained_points_to_add\": n_constrained_points_to_add,\n", + " \"n_points_to_scan\": 10,\n", + "}" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-24T06:42:04.614201145Z", + "start_time": "2024-05-24T06:42:04.610074325Z" + } + }, + "id": "61b7b971c6f81237", + "execution_count": 4 + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "\n", + "utility_ntuple = namedtuple(\"utility_params\", utility_kwargs.keys())(\n", + " *utility_kwargs.values()\n", + ")\n", + "tuning_params_tuple = namedtuple(\"tunings\", tuning_params.keys())(\n", + " *tuning_params.values()\n", + ")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-24T06:42:04.614607376Z", + "start_time": "2024-05-24T06:42:04.610206856Z" + } + }, + "id": "415a872b98d65631", + "execution_count": 5 + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "endog_grid_refined, policy_refined, value_refined = fast_upper_envelope_wrapper(\n", + " endog_grid=policy_egm[0, 1:],\n", + " policy=policy_egm[1, 1:],\n", + " value=value_egm[1, 1:],\n", + " expected_value_zero_savings=value_egm[1, 0],\n", + " exog_grid=exog_savings_grid,\n", + " utility_function=utility_crra,\n", + " utility_kwargs=utility_ntuple,\n", + " discount_factor=beta,\n", + " tuning_params=tuning_params_tuple,\n", + ")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-24T06:42:14.783923576Z", + "start_time": "2024-05-24T06:42:04.611563430Z" + } + }, + "id": "df76ba014c57934d", + "execution_count": 6 + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/environment.yml b/environment.yml index 41b3182..0a57a9a 100644 --- a/environment.yml +++ b/environment.yml @@ -1,10 +1,8 @@ --- name: upper-envelope - channels: - conda-forge - defaults - dependencies: - python=3.10 - pip @@ -37,8 +35,7 @@ dependencies: - conda-build - conda-verify - tox-conda - - pip: - blackcellmagic - furo - - -e . # Install locally \ No newline at end of file + - -e . # Install locally diff --git a/pyproject.toml b/pyproject.toml index 1bf4a32..70d1386 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,7 @@ write_to = "src/upper_envelope/_version.py" [tool.ruff] target-version = "py310" fix = true +ignore = ["F401"] [tool.yamlfix] line_length = 88 diff --git a/setup.py b/setup.py index 57c026b..7f1a176 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ from setuptools import setup if __name__ == "__main__": - setup() \ No newline at end of file + setup() diff --git a/src/upper_envelope/__init__.py b/src/upper_envelope/__init__.py index e69de29..0ae40f1 100644 --- a/src/upper_envelope/__init__.py +++ b/src/upper_envelope/__init__.py @@ -0,0 +1,4 @@ +from upper_envelope.fues_jax.fues_jax import fues_jax +from upper_envelope.fues_jax.fues_jax import fues_jax_unconstrained +from upper_envelope.fues_numba.fues_numba import fues_numba +from upper_envelope.fues_numba.fues_numba import fues_numba_unconstrained diff --git a/src/upper_envelope/fues_jax/fues_jax.py b/src/upper_envelope/fues_jax/fues_jax.py index a116a64..927aabc 100644 --- a/src/upper_envelope/fues_jax/fues_jax.py +++ b/src/upper_envelope/fues_jax/fues_jax.py @@ -23,14 +23,28 @@ ) -def fast_upper_envelope_wrapper( +@partial( + jax.jit, + static_argnames=[ + "value_function", + "n_constrained_points_to_add", + "n_final_wealth_grid", + "jump_thresh", + "n_points_to_scan", + ], +) +def fues_jax( endog_grid: jnp.ndarray, policy: jnp.ndarray, value: jnp.ndarray, - expected_value_zero_savings: float, - utility_function: Callable, - utility_kwargs: Dict, - disc_factor: float, + expected_value_zero_savings: jnp.ndarray | float, + value_function: Callable, + value_function_args: Optional[Tuple] = (), + value_function_kwargs: Optional[Dict] = {}, + n_constrained_points_to_add=None, + n_final_wealth_grid=None, + jump_thresh=2, + n_points_to_scan=10, ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Drop suboptimal points and refines the endogenous grid, policy, and value. @@ -56,59 +70,71 @@ def fast_upper_envelope_wrapper( subsequent periods t + 1, t + 2, ..., T under the optimal consumption policy. Args: - endog_grid (np.ndarray): 1d array of shape (n_grid_wealth + 1,) + endog_grid (jnp.ndarray): 1d array of shape (n_grid_wealth + 1,) containing the current state- and choice-specific endogenous grid. - policy (np.ndarray): 1d array of shape (n_grid_wealth + 1,) + policy (jnp.ndarray): 1d array of shape (n_grid_wealth + 1,) containing the current state- and choice-specific policy function. - value (np.ndarray): 1d array of shape (n_grid_wealth + 1,) + value (jnp.ndarray): 1d array of shape (n_grid_wealth + 1,) containing the current state- and choice-specific value function. - expected_value_zero_savings (float): The agent's expected value given that she - saves zero. - utility_function (callable): The utility function. The first argument is - assumed to be consumption. - utility_kwargs (dict): The keyword arguments to be passed to the utility + expected_value_zero_savings (jnp.ndarray | float): The agent's expected value + given that she saves zero. + value_function (callable): The value function for calculating the value if + nothing is saved. + value_function_args (Tuple): The positional arguments to be passed to the value function. + value_function_kwargs (dict): The keyword arguments to be passed to the value + function. + n_constrained_points_to_add (int): Number of constrained points to add to the + left of the first grid point if there is an area with credit-constrain. + n_final_wealth_grid (int): Size of final function grid. Determines number of + iterations for the scan in the fues_jax. + jump_thresh (float): Jump detection threshold. + n_points_to_scan (int): Number of points to scan for suboptimal points. Returns: tuple: - - endog_grid_refined (np.ndarray): 1d array of shape (1.1 * n_grid_wealth,) - containing the refined state- and choice-specific endogenous grid. - - policy_refined_with_nans (np.ndarray): 1d array of shape (1.1 * n_grid_wealth) - containing refined state- and choice-specificconsumption policy. - - value_refined_with_nans (np.ndarray): 1d array of shape (1.1 * n_grid_wealth) - containing refined state- and choice-specific value function. + - endog_grid_refined (jnp.ndarray): 1d array of shape (n_final_wealth_grid,) + containing the refined endogenous wealth grid. + - policy_refined (jnp.ndarray): 1d array of shape (n_final_wealth_grid,) + containing refined consumption policy. + - value_refined (jnp.ndarray): 1d array of shape (n_final_wealth_grid,) + containing refined value function. """ + # Set default of n_constrained_points_to_add to 10% of the grid size + n_constrained_points_to_add = ( + endog_grid.shape[0] // 10 + if n_constrained_points_to_add is None + else n_constrained_points_to_add + ) + + # Check if a non-concave region coincides with the credit constrained region. + # This happens when there is a non-monotonicity in the endogenous wealth grid + # that goes below the first point (the minimal wealth, below it is optimal to + # consume everything). + + # If there is such a non-concave region, we extend the value function to the left + # of the first point and calculate the value function there with the supplied value + # function. + + # Because of jax, we always need to perform the same set of computations. Hence, + # if there is no wealth grid point below the first, we just add nans thereafter. min_id = np.argmin(endog_grid) min_wealth_grid = endog_grid[min_id] - # These tuning parameters should be set outside. Don't want to touch solve.py now - points_to_add = len(endog_grid) // 10 - num_iter = int(1.2 * value.shape[0]) - jump_thresh = 2 - # Non-concave region coincides with credit constraint. - # This happens when there is a non-monotonicity in the endogenous wealth grid - # that goes below the first point. - # Solution: Value function to the left of the first point is analytical, - # so we just need to add some points to the left of the first grid point. - # We do that independent of whether the condition is fulfilled or not. - # If the condition is not fulfilled this is points_to_add times the same point. # This is the condition, which we do not use at the moment. # closed_form_cond = min_wealth_grid < endog_grid[0] - grid_points_to_add = jnp.linspace(min_wealth_grid, endog_grid[0], points_to_add)[ - :-1 - ] + grid_points_to_add = jnp.linspace( + min_wealth_grid, endog_grid[0], n_constrained_points_to_add + 1 + )[:-1] # Compute closed form values - values_to_add = vmap(_compute_value, in_axes=(0, None, None, None, None))( - grid_points_to_add, - expected_value_zero_savings, - utility_function, - utility_kwargs, - disc_factor, + values_to_add = vmap(_compute_value, in_axes=(0, None, None, None))( + grid_points_to_add, value_function, value_function_args, value_function_kwargs ) - # Now determine if we actually had to extend the grid. If not, we just add nans. + # Now determine if we actually had to extend the grid. + # If not, we just add nans. no_need_to_add = min_id == 0 multiplikator = jax.lax.select(no_need_to_add, jnp.nan, 1.0) grid_points_to_add *= multiplikator @@ -122,13 +148,14 @@ def fast_upper_envelope_wrapper( endog_grid_refined, value_refined, policy_refined, - ) = fast_upper_envelope( + ) = fues_jax_unconstrained( grid_augmented, value_augmented, policy_augmented, expected_value_zero_savings, - num_iter=num_iter, + n_final_wealth_grid=n_final_wealth_grid, jump_thresh=jump_thresh, + n_points_to_scan=n_points_to_scan, ) return ( endog_grid_refined, @@ -137,50 +164,49 @@ def fast_upper_envelope_wrapper( ) -def fast_upper_envelope( +@partial( + jax.jit, static_argnames=["n_final_wealth_grid", "jump_thresh", "n_points_to_scan"] +) +def fues_jax_unconstrained( endog_grid: jnp.ndarray, value: jnp.ndarray, policy: jnp.ndarray, expected_value_zero_savings: float, - num_iter: int, - jump_thresh: Optional[float] = 2, + n_final_wealth_grid=None, + jump_thresh=2, + n_points_to_scan=10, ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Remove suboptimal points from the endogenous grid, policy, and value function. Args: - endog_grid (np.ndarray): 1d array containing the unrefined endogenous wealth - grid of shape (n_grid_wealth + 1,). - value (np.ndarray): 1d array containing the unrefined value correspondence - of shape (n_grid_wealth + 1,). - policy (np.ndarray): 1d array containing the unrefined policy correspondence - of shape (n_grid_wealth + 1,). - expected_value_zero_savings (float): The agent's expected value given that she - saves zero. - num_iter (int): Number of iterations to execute the fues. Recommended to use - twenty percent more than the actual array size. + endog_grid (jnp.ndarray): 1d array of shape (n_grid_wealth + 1,) + containing the current state- and choice-specific endogenous grid. + policy (jnp.ndarray): 1d array of shape (n_grid_wealth + 1,) + containing the current state- and choice-specific policy function. + value (jnp.ndarray): 1d array of shape (n_grid_wealth + 1,) + containing the current state- and choice-specific value function. + expected_value_zero_savings (jnp.ndarray | float): The agent's expected value + given that she saves zero. + n_final_wealth_grid (int): Size of final function grid. Determines number of + iterations for the scan in the fues_jax. jump_thresh (float): Jump detection threshold. + n_points_to_scan (int): Number of points to scan for suboptimal points. Returns: tuple: - - endog_grid_refined (np.ndarray): 1d array containing the refined endogenous - wealth grid of shape (n_grid_clean,), which maps only to the optimal points - in the value function. - - value_refined (np.ndarray): 1d array containing the refined value function - of shape (n_grid_clean,). Overlapping segments have been removed and only - the optimal points are kept. - - policy_refined (np.ndarray): 1d array containing the refined policy function - of shape (n_grid_clean,). Overlapping segments have been removed and only - the optimal points are kept. + - endog_grid_refined (jnp.ndarray): 1d array of shape (n_final_wealth_grid,) + containing the refined endogenous wealth grid. + - policy_refined (jnp.ndarray): 1d array of shape (n_final_wealth_grid,) + containing refined consumption policy. + - value_refined (jnp.ndarray): 1d array of shape (n_final_wealth_grid,) + containing refined value function. """ - # Comment by Akshay: Determine locations where endogenous grid points are - # equal to the lower bound. Not relevant for us. - # mask = endog_grid <= lower_bound_wealth - # if jnp.any(mask): - # max_value_lower_bound = jnp.nanmax(value[mask]) - # mask &= value < max_value_lower_bound - # value[mask] = jnp.nan + # Set default value of final grid size to 1.2 times current if not defined + n_final_wealth_grid = ( + int(1.2 * (len(policy))) if n_final_wealth_grid is None else n_final_wealth_grid + ) idx_sort = jnp.argsort(endog_grid) value = jnp.take(value, idx_sort) @@ -196,9 +222,9 @@ def fast_upper_envelope( value=value, policy=policy, expected_value_zero_savings=expected_value_zero_savings, - num_iter=num_iter, + n_final_wealth_grid=n_final_wealth_grid, jump_thresh=jump_thresh, - n_points_to_scan=10, + n_points_to_scan=n_points_to_scan, ) return endog_grid_refined, value_refined, policy_refined @@ -209,7 +235,7 @@ def scan_value_function( value: jnp.ndarray, policy: jnp.ndarray, expected_value_zero_savings, - num_iter: int, + n_final_wealth_grid: int, jump_thresh: float, n_points_to_scan: Optional[int] = 0, ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: @@ -224,8 +250,8 @@ def scan_value_function( of shape (n_grid_wealth + 1,). expected_value_zero_savings (float): The agent's expected value given that she saves zero. - num_iter (int): Number of iterations to execute the fues. Recommended to use - twenty percent more than the actual array size. + n_final_wealth_grid (int): Size of final grid. Determines number of + iterations for the scan in the fues_jax. jump_thresh (float): Jump detection threshold. n_points_to_scan (int): Number of points to scan for suboptimal points. @@ -281,7 +307,7 @@ def scan_value_function( partial_body, carry_init, xs=None, - length=num_iter, + length=n_final_wealth_grid, ) result_arrays, sort_index = result value, policy, endog_grid = result_arrays @@ -794,10 +820,11 @@ def select_and_calculate_intersection( def _compute_value( - consumption, next_period_value, utility_function, utility_kwargs, discount_factor + consumption, value_function, value_function_args, value_function_kwargs ): - utility = utility_function( - consumption=consumption, - **utility_kwargs, + value = value_function( + consumption, + *value_function_args, + **value_function_kwargs, ) - return utility + discount_factor * next_period_value + return value diff --git a/src/upper_envelope/fues_numba/fues_numba.py b/src/upper_envelope/fues_numba/fues_numba.py index 186cf52..92ad6a7 100644 --- a/src/upper_envelope/fues_numba/fues_numba.py +++ b/src/upper_envelope/fues_numba/fues_numba.py @@ -6,7 +6,6 @@ """ from typing import Callable -from typing import Dict from typing import Optional from typing import Tuple @@ -14,15 +13,18 @@ from numba import njit -def fast_upper_envelope_wrapper( +@njit +def fues_numba( endog_grid: np.ndarray, policy: np.ndarray, value: np.ndarray, - exog_grid: np.ndarray, - expected_value_zero_savings: float, - utility_function: Callable, - utility_kwargs: Dict, - discount_factor: float, + expected_value_zero_savings: np.ndarray | float, + value_function: Callable, + value_function_args: Tuple, + n_constrained_points_to_add=None, + n_final_wealth_grid=None, + jump_thresh=2, + n_points_to_scan=10, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """Drop suboptimal points and refine the endogenous grid, policy, and value. @@ -54,11 +56,17 @@ def fast_upper_envelope_wrapper( containing the current state- and choice-specific policy function. value (np.ndarray): 1d array of shape (n_grid_wealth + 1,) containing the current state- and choice-specific value function. - exog_grid (np.ndarray): 1d array of shape (n_grid_wealth,) of the - exogenous savings grid. - expected_value_zero_savings (float): The agent's expected value given that she - saves zero. - + expected_value_zero_savings (np.ndarray | float): The agent's expected value + given that she saves zero. + value_function (callable): The value function for calculating the value if + nothing is saved. + value_function_args (Tuple): The positional arguments to be passed to the value + function. + n_constrained_points_to_add (int): Number of constrained points to add to the + left of the first grid point if there is an area with credit-constrain. + n_final_wealth_grid (int): Size of final function grid. + jump_thresh (float): Jump detection threshold. + n_points_to_scan (int): Number of points to scan for suboptimal points. Returns: tuple: @@ -70,11 +78,7 @@ def fast_upper_envelope_wrapper( containing refined state- and choice-specific value function. """ - n_grid_wealth = len(endog_grid) min_wealth_grid = np.min(endog_grid) - # exog_grid = np.append( - # 0, np.linspace(min_wealth_grid, endog_grid[-1], n_grid_wealth - 1) - # ) if endog_grid[0] > min_wealth_grid: # Non-concave region coincides with credit constraint. @@ -83,32 +87,44 @@ def fast_upper_envelope_wrapper( # Solution: Value function to the left of the first point is analytical, # so we just need to add some points to the left of the first grid point. + # Set default of n_constrained_points_to_add to 10% of the grid size + n_constrained_points_to_add = ( + endog_grid.shape[0] // 10 + if n_constrained_points_to_add is None + else n_constrained_points_to_add + ) + endog_grid, value, policy = _augment_grids( endog_grid=endog_grid, value=value, policy=policy, - expected_value_zero_savings=expected_value_zero_savings, min_wealth_grid=min_wealth_grid, - n_grid_wealth=n_grid_wealth, - utility_function=utility_function, - utility_kwargs=utility_kwargs, - discount_factor=discount_factor, + n_constrained_points_to_add=n_constrained_points_to_add, + value_function=value_function, + value_function_args=value_function_args, ) - exog_grid = np.append(np.zeros(n_grid_wealth // 10 - 1), exog_grid) endog_grid = np.append(0, endog_grid) policy = np.append(0, policy) value = np.append(expected_value_zero_savings, value) - exog_grid = np.append(0, exog_grid) - endog_grid_refined, value_refined, policy_refined = fast_upper_envelope( - endog_grid, value, policy, exog_grid, jump_thresh=2 + endog_grid_refined, value_refined, policy_refined = fues_numba_unconstrained( + endog_grid, + value, + policy, + jump_thresh=jump_thresh, + n_points_to_scan=n_points_to_scan, + ) + + # Set default value of final grid size to 1.2 times current if not defined + n_final_wealth_grid = ( + int(1.2 * (len(policy))) if n_final_wealth_grid is None else n_final_wealth_grid ) # Fill array with nans to fit 10% extra grid points - endog_grid_refined_with_nans = np.empty(int(1.1 * n_grid_wealth)) - policy_refined_with_nans = np.empty(int(1.1 * n_grid_wealth)) - value_refined_with_nans = np.empty(int(1.1 * n_grid_wealth)) + endog_grid_refined_with_nans = np.empty(n_final_wealth_grid) + policy_refined_with_nans = np.empty(n_final_wealth_grid) + value_refined_with_nans = np.empty(n_final_wealth_grid) endog_grid_refined_with_nans[:] = np.nan policy_refined_with_nans[:] = np.nan value_refined_with_nans[:] = np.nan @@ -125,13 +141,12 @@ def fast_upper_envelope_wrapper( @njit -def fast_upper_envelope( +def fues_numba_unconstrained( endog_grid: np.ndarray, value: np.ndarray, policy: np.ndarray, - exog_grid: np.ndarray, - jump_thresh: Optional[float] = 2, - lower_bound_wealth: Optional[float] = 1e-10, + jump_thresh=2, + n_points_to_scan=10, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """Remove suboptimal points from the endogenous grid, policy, and value function. @@ -145,40 +160,27 @@ def fast_upper_envelope( exog_grid (np.ndarray): 1d array containing the exogenous wealth grid of shape (n_grid_wealth + 1,). jump_thresh (float): Jump detection threshold. - lower_bound_wealth (float): Lower bound on wealth. - Returns: tuple: - - endog_grid_refined (np.ndarray): 1d array containing the refined endogenous - wealth grid of shape (n_grid_clean,), which maps only to the optimal points - in the value function. - - value_refined (np.ndarray): 1d array containing the refined value function - of shape (n_grid_clean,). Overlapping segments have been removed and only - the optimal points are kept. - - policy_refined (np.ndarray): 1d array containing the refined policy function - of shape (n_grid_clean,). Overlapping segments have been removed and only - the optimal points are kept. + - endog_grid_refined (np.ndarray): 1d array of shape (n_final_wealth_grid,) + containing the refined endogenous wealth grid. + - policy_refined (np.ndarray): 1d array of shape (n_final_wealth_grid,) + containing refined consumption policy. + - value_refined (np.ndarray): 1d array of shape (n_final_wealth_grid,) + containing refined value function. """ - # TODO: determine locations where endogenous grid points are # noqa: T000 - # equal to the lower bound - mask = endog_grid <= lower_bound_wealth - if np.any(mask): - max_value_lower_bound = np.nanmax(value[mask]) - mask &= value < max_value_lower_bound - value[mask] = np.nan endog_grid = endog_grid[np.where(~np.isnan(value))[0]] policy = policy[np.where(~np.isnan(value))] - exog_grid = exog_grid[np.where(~np.isnan(value))[0]] value = value[np.where(~np.isnan(value))] idx_sort = np.argsort(endog_grid, kind="mergesort") value = np.take(value, idx_sort) policy = np.take(policy, idx_sort) - exog_grid = np.take(exog_grid, idx_sort) endog_grid = np.take(endog_grid, idx_sort) + exog_grid = endog_grid - policy ( value_clean_with_nans, @@ -190,7 +192,7 @@ def fast_upper_envelope( policy=policy, exog_grid=exog_grid, jump_thresh=jump_thresh, - n_points_to_scan=10, + n_points_to_scan=n_points_to_scan, ) endog_grid_refined = endog_grid_clean_with_nans[ @@ -208,7 +210,7 @@ def scan_value_function( value: np.ndarray, policy: np.ndarray, exog_grid: np.ndarray, - jump_thresh: float, + jump_thresh: Optional[float] = 2, n_points_to_scan: Optional[int] = 0, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """Scan the value function to remove suboptimal points and add intersection points. @@ -732,16 +734,15 @@ def _append_index(x_array: np.ndarray, m: int): return x_array +@njit def _augment_grids( endog_grid: np.ndarray, value: np.ndarray, policy: np.ndarray, - expected_value_zero_savings: float, min_wealth_grid: float, - n_grid_wealth: int, - utility_function: Callable, - utility_kwargs: Dict[str, float], - discount_factor: float, + n_constrained_points_to_add: int, + value_function: Callable, + value_function_args: Tuple, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """Extends the endogenous wealth grid, value, and policy functions to the left. @@ -778,14 +779,13 @@ def _augment_grids( """ grid_points_to_add = np.linspace( - min_wealth_grid, endog_grid[0], n_grid_wealth // 10 + min_wealth_grid, endog_grid[0], n_constrained_points_to_add + 1 )[:-1] - utility = utility_function( - consumption=grid_points_to_add, - **utility_kwargs, - ) - values_to_add = utility + discount_factor * expected_value_zero_savings + values_to_add = np.empty_like(grid_points_to_add) + + for i, grid_point in enumerate(grid_points_to_add): + values_to_add[i] = value_function(grid_point, *value_function_args) grid_augmented = np.append(grid_points_to_add, endog_grid) value_augmented = np.append(values_to_add, value) diff --git a/tests/test_upper_envelope_jax.py b/tests/test_fues_jax.py similarity index 82% rename from tests/test_upper_envelope_jax.py rename to tests/test_fues_jax.py index 422ae16..764e454 100644 --- a/tests/test_upper_envelope_jax.py +++ b/tests/test_fues_jax.py @@ -2,21 +2,20 @@ from pathlib import Path from typing import Dict +import jax import jax.numpy as jnp import numpy as np import pytest +import upper_envelope as upenv from numpy.testing import assert_array_almost_equal as aaae from upper_envelope.fues_jax.check_and_scan_funcs import back_and_forward_scan_wrapper -from upper_envelope.fues_jax.fues_jax import fast_upper_envelope -from upper_envelope.fues_jax.fues_jax import ( - fast_upper_envelope_wrapper, -) -from upper_envelope.fues_numba import fues_numba as fues_nb from tests.utils.interpolation import interpolate_policy_and_value_on_wealth_grid from tests.utils.interpolation import linear_interpolation_with_extrapolation from tests.utils.upper_envelope_fedor import upper_envelope +jax.config.update("jax_enable_x64", True) + # Obtain the test directory of the package. TEST_DIR = Path(__file__).parent @@ -89,18 +88,22 @@ def test_fast_upper_envelope_wrapper(period, setup_model): value_egm = np.genfromtxt( TEST_RESOURCES_DIR / f"upper_envelope_period_tests/val{period}.csv", delimiter=",", + dtype=float, ) policy_egm = np.genfromtxt( TEST_RESOURCES_DIR / f"upper_envelope_period_tests/pol{period}.csv", delimiter=",", + dtype=float, ) value_refined_fedor = np.genfromtxt( TEST_RESOURCES_DIR / f"upper_envelope_period_tests/expec_val{period}.csv", delimiter=",", + dtype=float, ) policy_refined_fedor = np.genfromtxt( TEST_RESOURCES_DIR / f"upper_envelope_period_tests/expec_pol{period}.csv", delimiter=",", + dtype=float, ) policy_expected = policy_refined_fedor[ :, ~np.isnan(policy_refined_fedor).any(axis=0) @@ -112,26 +115,31 @@ def test_fast_upper_envelope_wrapper(period, setup_model): params, _exog_savings_grid, state_choice_vars = setup_model - utility_kwargs = { + value_function_kwargs = { "choice": state_choice_vars["choice"], "params": params, } + + def value_func(consumption, choice, params): + return ( + utility_crra(consumption, choice, params) + params["beta"] * value_egm[1, 0] + ) + ( endog_grid_refined, policy_refined, value_refined, - ) = fast_upper_envelope_wrapper( - endog_grid=policy_egm[0, 1:], - policy=policy_egm[1, 1:], - value=value_egm[1, 1:], + ) = upenv.fues_jax( + endog_grid=jnp.asarray(policy_egm[0, 1:]), + policy=jnp.asarray(policy_egm[1, 1:]), + value=jnp.asarray(value_egm[1, 1:]), expected_value_zero_savings=value_egm[1, 0], - utility_function=utility_crra, - utility_kwargs=utility_kwargs, - disc_factor=params["beta"], + value_function=value_func, + value_function_kwargs=value_function_kwargs, ) wealth_max_to_test = np.max(endog_grid_refined[~np.isnan(endog_grid_refined)]) + 100 - wealth_grid_to_test = jnp.linspace( + wealth_grid_to_test = np.linspace( endog_grid_refined[1], wealth_max_to_test, 1000, dtype=float ) @@ -149,8 +157,8 @@ def test_fast_upper_envelope_wrapper(period, setup_model): ) = interpolate_policy_and_value_on_wealth_grid( wealth_beginning_of_period=wealth_grid_to_test, endog_wealth_grid=endog_grid_refined, - policy=policy_refined, - value_grid=value_refined, + policy_grid=policy_refined, + value_function_grid=value_refined, ) aaae(value_calc_interp, value_expec_interp) @@ -166,23 +174,21 @@ def test_fast_upper_envelope_against_numba(setup_model): ) _params, exog_savings_grid, state_choice_vars = setup_model - endog_grid_org, value_org, policy_org = fues_nb.fast_upper_envelope( + endog_grid_org, value_org, policy_org = upenv.fues_numba_unconstrained( endog_grid=policy_egm[0], value=value_egm[1], policy=policy_egm[1], - exog_grid=np.append(0, exog_savings_grid), ) ( endog_grid_refined, value_refined, policy_refined, - ) = fast_upper_envelope( + ) = jax.jit(upenv.fues_jax_unconstrained)( endog_grid=policy_egm[0, 1:], value=value_egm[1, 1:], policy=policy_egm[1, 1:], expected_value_zero_savings=value_egm[1, 0], - num_iter=int(1.2 * value_egm.shape[1]), ) wealth_max_to_test = np.max(endog_grid_refined[~np.isnan(endog_grid_refined)]) + 100 @@ -196,8 +202,8 @@ def test_fast_upper_envelope_against_numba(setup_model): ) = interpolate_policy_and_value_on_wealth_grid( wealth_beginning_of_period=wealth_grid_to_test, endog_wealth_grid=endog_grid_refined, - policy=policy_refined, - value_grid=value_refined, + policy_grid=policy_refined, + value_function_grid=value_refined, ) ( @@ -206,8 +212,8 @@ def test_fast_upper_envelope_against_numba(setup_model): ) = interpolate_policy_and_value_on_wealth_grid( wealth_beginning_of_period=wealth_grid_to_test, endog_wealth_grid=endog_grid_org, - policy=policy_org, - value_grid=value_org, + policy_grid=policy_org, + value_function_grid=value_org, ) aaae(value_calc_interp_calc, value_calc_interp_org) @@ -241,27 +247,27 @@ def test_fast_upper_envelope_against_fedor(period, setup_model): ~np.isnan(_value_fedor).any(axis=0), ] - utility_kwargs = { - "choice": state_choice_vec["choice"], - "params": params, - } + def value_func(consumption, choice, params): + return ( + utility_crra(consumption, choice, params) + params["beta"] * value_egm[1, 0] + ) ( endog_grid_fues, policy_fues, value_fues, - ) = fast_upper_envelope_wrapper( - endog_grid=policy_egm[0, 1:], - policy=policy_egm[1, 1:], - value=value_egm[1, 1:], + ) = upenv.fues_jax( + endog_grid=jnp.asarray(policy_egm[0, 1:]), + policy=jnp.asarray(policy_egm[1, 1:]), + value=jnp.asarray(value_egm[1, 1:]), expected_value_zero_savings=value_egm[1, 0], - utility_function=utility_crra, - utility_kwargs=utility_kwargs, - disc_factor=params["beta"], + value_function=value_func, + value_function_args=(state_choice_vec["choice"], params), + n_constrained_points_to_add=len(policy_egm[0, 1:]) // 10, ) wealth_max_to_test = np.max(endog_grid_fues[~np.isnan(endog_grid_fues)]) + 100 - wealth_grid_to_test = np.linspace(endog_grid_fues[1], wealth_max_to_test, 1000) + wealth_grid_to_test = jnp.linspace(endog_grid_fues[1], wealth_max_to_test, 1000) value_expec_interp = linear_interpolation_with_extrapolation( x_new=wealth_grid_to_test, x=value_expected[0], y=value_expected[1] @@ -273,8 +279,8 @@ def test_fast_upper_envelope_against_fedor(period, setup_model): policy_interp, value_interp = interpolate_policy_and_value_on_wealth_grid( wealth_beginning_of_period=wealth_grid_to_test, endog_wealth_grid=endog_grid_fues, - policy=policy_fues, - value_grid=value_fues, + policy_grid=policy_fues, + value_function_grid=value_fues, ) aaae(value_interp, value_expec_interp) aaae(policy_interp, policy_expec_interp) @@ -290,8 +296,8 @@ def test_back_and_forward_scan_wrapper_direction_flag(): endog_grid_to_scan_from=1.2, policy_to_scan_from=0.7, endog_grid=1, - value=np.arange(2, 5), - policy=np.arange(1, 4), + value=jnp.arange(2, 5), + policy=jnp.arange(1, 4), idx_to_scan_from=2, n_points_to_scan=3, is_scan_needed=False, diff --git a/tests/test_upper_envelope_numba.py b/tests/test_fues_numba.py similarity index 86% rename from tests/test_upper_envelope_numba.py rename to tests/test_fues_numba.py index 461e892..6b4ea3f 100644 --- a/tests/test_upper_envelope_numba.py +++ b/tests/test_fues_numba.py @@ -3,9 +3,8 @@ import numpy as np import pytest +import upper_envelope as upenv from numpy.testing import assert_array_almost_equal as aaae -from upper_envelope.fues_numba.fues_numba import fast_upper_envelope -from upper_envelope.fues_numba.fues_numba import fast_upper_envelope_wrapper from tests.utils.fast_upper_envelope_org import fast_upper_envelope_wrapper_org from tests.utils.interpolation import interpolate_single_policy_and_value_on_wealth_grid @@ -30,7 +29,7 @@ def utility_crra(consumption: np.array, choice: int, params: dict) -> np.array: (ii) of shape (n_grid_wealth,) when called by :func:`~dcgm.call_egm_step.get_current_period_value`. choice (int): Choice of the agent, e.g. 0 = "retirement", 1 = "working". - params_dict (dict): Dictionary containing model parameters. + params (dict): Dictionary containing model parameters. Relevant here is the CRRA coefficient theta. Returns: @@ -89,20 +88,18 @@ def test_fast_upper_envelope_wrapper(period, setup_model): params, state_choice_vec, _exog_savings_grid = setup_model - utility_kwargs = { - "choice": state_choice_vec["choice"], - "params": params, - } + def value_func(consumption, choice, params): + return ( + utility_crra(consumption, choice, params) + params["beta"] * value_egm[1, 0] + ) - endog_grid_refined, policy_refined, value_refined = fast_upper_envelope_wrapper( + endog_grid_refined, policy_refined, value_refined = upenv.fues_numba( endog_grid=policy_egm[0, 1:], policy=policy_egm[1, 1:], value=value_egm[1, 1:], expected_value_zero_savings=value_egm[1, 0], - exog_grid=_exog_savings_grid, - utility_function=utility_crra, - utility_kwargs=utility_kwargs, - discount_factor=params["beta"], + value_function=value_func, + value_function_args=(state_choice_vec["choice"], params), ) wealth_max_to_test = np.max(endog_grid_refined[~np.isnan(endog_grid_refined)]) + 100 @@ -125,7 +122,7 @@ def test_fast_upper_envelope_wrapper(period, setup_model): wealth_beginning_of_period=wealth_grid_to_test, endog_wealth_grid=endog_grid_refined, policy_grid=policy_refined, - value_grid=value_refined, + value_function_grid=value_refined, ) aaae(value_calc_interp, value_expec_interp) @@ -142,11 +139,10 @@ def test_fast_upper_envelope_against_org_fues(setup_model): _params, state_choice_vec, exog_savings_grid = setup_model - endog_grid_refined, value_refined, policy_refined = fast_upper_envelope( + endog_grid_refined, value_refined, policy_refined = upenv.fues_numba_unconstrained( endog_grid=policy_egm[0], value=value_egm[1], policy=policy_egm[1], - exog_grid=np.append(0, exog_savings_grid), ) endog_grid_org, policy_org, value_org = fast_upper_envelope_wrapper_org( @@ -193,20 +189,19 @@ def test_fast_upper_envelope_against_fedor(period, setup_model): :, ~np.isnan(_value_fedor).any(axis=0), ] - utility_kwargs = { - "choice": state_choice_vec["choice"], - "params": params, - } - endog_grid_fues, policy_fues, value_fues = fast_upper_envelope_wrapper( + def value_func(consumption, choice, params): + return ( + utility_crra(consumption, choice, params) + params["beta"] * value_egm[1, 0] + ) + + endog_grid_fues, policy_fues, value_fues = upenv.fues_numba( endog_grid=policy_egm[0, 1:], policy=policy_egm[1, 1:], value=value_egm[1, 1:], - exog_grid=np.append(0, exog_savings_grid), expected_value_zero_savings=value_egm[1, 0], - utility_function=utility_crra, - utility_kwargs=utility_kwargs, - discount_factor=params["beta"], + value_function=value_func, + value_function_args=(state_choice_vec["choice"], params), ) wealth_max_to_test = np.max(endog_grid_fues[~np.isnan(endog_grid_fues)]) + 100 @@ -223,7 +218,7 @@ def test_fast_upper_envelope_against_fedor(period, setup_model): wealth_beginning_of_period=wealth_grid_to_test, endog_wealth_grid=endog_grid_fues, policy_grid=policy_fues, - value_grid=value_fues, + value_function_grid=value_fues, ) aaae(value_interp, value_expec_interp) aaae(policy_interp, policy_expec_interp) diff --git a/tests/utils/interpolation.py b/tests/utils/interpolation.py index 540c4fa..4cfdbcb 100644 --- a/tests/utils/interpolation.py +++ b/tests/utils/interpolation.py @@ -1,54 +1,59 @@ +from typing import Tuple + import jax.numpy as jnp import numpy as np def interpolate_policy_and_value_on_wealth_grid( - wealth_beginning_of_period: jnp.ndarray, - endog_wealth_grid: jnp.ndarray, - policy: jnp.ndarray, - value_grid: jnp.ndarray, + wealth_beginning_of_period: np.ndarray | jnp.ndarray, + endog_wealth_grid: np.ndarray | jnp.ndarray, + policy_grid: np.ndarray | jnp.ndarray, + value_function_grid: np.ndarray | jnp.ndarray, ): """Interpolate policy and value functions on the wealth grid. Args: - wealth_beginning_of_period (jnp.ndarray): 1d array of shape (n,) containing the + wealth_beginning_of_period (np.ndarray): 1d array of shape (n,) containing the begin of period wealth. - endog_wealth_grid (jnp.array): 1d array of shape (n,) containing the endogenous + endog_wealth_grid (np.array): 1d array of shape (n,) containing the endogenous wealth grid. - policy_left_grid (jnp.ndarray): 1d array of shape (n,) containing the - left policy function corresponding to the endogenous wealth grid. - policy_right_grid (jnp.ndarray): 1d array of shape (n,) containing the - left policy function corresponding to the endogenous wealth grid. - value_grid (jnp.ndarray): 1d array of shape (n,) containing the value function - values corresponding to the endogenous wealth grid. + policy_grid (np.ndarray): 1d array of shape (n,) containing the + policy function corresponding to the endogenous wealth grid. + value_function_grid (np.ndarray): 1d array of shape (n,) containing the value + function corresponding to the endogenous wealth grid. Returns: tuple: - - policy_new (jnp.ndarray): 1d array of shape (n,) containing the interpolated + - policy_new (np.ndarray): 1d array of shape (n,) containing the interpolated policy function values corresponding to the begin of period wealth. - - value_new (jnp.ndarray): 1d array of shape (n,) containing the interpolated + - value_new (np.ndarray): 1d array of shape (n,) containing the interpolated value function values corresponding to the begin of period wealth. """ + # Make sure that these are numpy arrays + wealth_beginning_of_period = np.asarray(wealth_beginning_of_period) + endog_wealth_grid = np.asarray(endog_wealth_grid) + policy_grid = np.asarray(policy_grid) + value_function_grid = np.asarray(value_function_grid) ind_high, ind_low = get_index_high_and_low( x=endog_wealth_grid, x_new=wealth_beginning_of_period ) - wealth_low = jnp.take(endog_wealth_grid, ind_low) - wealth_high = jnp.take(endog_wealth_grid, ind_high) + wealth_low = np.take(endog_wealth_grid, ind_low) + wealth_high = np.take(endog_wealth_grid, ind_high) policy_new = linear_interpolation_formula( - y_high=jnp.take(policy, ind_high), - y_low=jnp.take(policy, ind_low), + y_high=np.take(policy_grid, ind_high), + y_low=np.take(policy_grid, ind_low), x_high=wealth_high, x_low=wealth_low, x_new=wealth_beginning_of_period, ) value_new = linear_interpolation_formula( - y_high=jnp.take(value_grid, ind_high), - y_low=jnp.take(value_grid, ind_low), + y_high=np.take(value_function_grid, ind_high), + y_low=np.take(value_function_grid, ind_low), x_high=wealth_high, x_low=wealth_low, x_new=wealth_beginning_of_period, @@ -58,10 +63,10 @@ def interpolate_policy_and_value_on_wealth_grid( def interpolate_single_policy_and_value_on_wealth_grid( - wealth_beginning_of_period: jnp.ndarray, - endog_wealth_grid: jnp.ndarray, - policy_grid: jnp.ndarray, - value_grid: jnp.ndarray, + wealth_beginning_of_period: np.ndarray | jnp.ndarray, + endog_wealth_grid: np.ndarray | jnp.ndarray, + policy_grid: np.ndarray | jnp.ndarray, + value_function_grid: np.ndarray | jnp.ndarray, ): """Interpolate policy and value functions on the wealth grid. @@ -70,44 +75,47 @@ def interpolate_single_policy_and_value_on_wealth_grid( in fast_upper_envelope.py. Args: - wealth_beginning_of_period (jnp.ndarray): 1d array of shape (n,) containing the + wealth_beginning_of_period (np.ndarray): 1d array of shape (n,) containing the begin of period wealth. - endog_wealth_grid (jnp.array): 1d array of shape (n,) containing the endogenous + endog_wealth_grid (np.ndarray): 1d array of shape (n,) containing the endogenous wealth grid. - policy_left_grid (jnp.ndarray): 1d array of shape (n,) containing the - left policy function corresponding to the endogenous wealth grid. - policy_right_grid (jnp.ndarray): 1d array of shape (n,) containing the - left policy function corresponding to the endogenous wealth grid. - value_grid (jnp.ndarray): 1d array of shape (n,) containing the value function - values corresponding to the endogenous wealth grid. + policy_grid (np.ndarray): 1d array of shape (n,) containing the + policy function corresponding to the endogenous wealth grid. + value_function_grid (np.ndarray): 1d array of shape (n,) containing the value + function corresponding to the endogenous wealth grid. Returns: tuple: - - policy_new (jnp.ndarray): 1d array of shape (n,) containing the interpolated + - policy_new (np.ndarray): 1d array of shape (n,) containing the interpolated policy function values corresponding to the begin of period wealth. - - value_new (jnp.ndarray): 1d array of shape (n,) containing the interpolated + - value_new (np.ndarray): 1d array of shape (n,) containing the interpolated value function values corresponding to the begin of period wealth. """ + # Make sure that these are numpy arrays + wealth_beginning_of_period = np.asarray(wealth_beginning_of_period) + endog_wealth_grid = np.asarray(endog_wealth_grid) + policy_grid = np.asarray(policy_grid) + value_function_grid = np.asarray(value_function_grid) ind_high, ind_low = get_index_high_and_low( x=endog_wealth_grid, x_new=wealth_beginning_of_period ) - wealth_low = jnp.take(endog_wealth_grid, ind_low) - wealth_high = jnp.take(endog_wealth_grid, ind_high) + wealth_low = np.take(endog_wealth_grid, ind_low) + wealth_high = np.take(endog_wealth_grid, ind_high) policy_new = linear_interpolation_formula( - y_high=jnp.take(policy_grid, ind_high), - y_low=jnp.take(policy_grid, ind_low), + y_high=np.take(policy_grid, ind_high), + y_low=np.take(policy_grid, ind_low), x_high=wealth_high, x_low=wealth_low, x_new=wealth_beginning_of_period, ) value_new = linear_interpolation_formula( - y_high=jnp.take(value_grid, ind_high), - y_low=jnp.take(value_grid, ind_low), + y_high=np.take(value_function_grid, ind_high), + y_low=np.take(value_function_grid, ind_low), x_high=wealth_high, x_low=wealth_low, x_new=wealth_beginning_of_period, @@ -117,11 +125,11 @@ def interpolate_single_policy_and_value_on_wealth_grid( def linear_interpolation_formula( - y_high: float | jnp.ndarray, - y_low: float | jnp.ndarray, - x_high: float | jnp.ndarray, - x_low: float | jnp.ndarray, - x_new: float | jnp.ndarray, + y_high: float | np.ndarray, + y_low: float | np.ndarray, + x_high: float | np.ndarray, + x_low: float | np.ndarray, + x_new: float | np.ndarray, ): """Linear interpolation formula.""" interpolate_dist = x_new - x_low @@ -131,7 +139,9 @@ def linear_interpolation_formula( return interpol_res -def get_index_high_and_low(x, x_new): +def get_index_high_and_low( + x: np.ndarray, x_new: np.ndarray | float +) -> Tuple[np.ndarray, np.ndarray]: """Get index of the highest value in x that is smaller than x_new. Args: @@ -143,8 +153,8 @@ def get_index_high_and_low(x, x_new): case of extrapolation last or first index of not nan element. """ - ind_high = jnp.searchsorted(x, x_new).clip(max=(x.shape[0] - 1), min=1) - ind_high -= jnp.isnan(x[ind_high]).astype(int) + ind_high = np.searchsorted(x, x_new).clip(max=(x.shape[0] - 1), min=1) + ind_high -= np.isnan(x[ind_high]).astype(int) return ind_high, ind_high - 1 @@ -153,7 +163,7 @@ def get_index_high_and_low(x, x_new): # ===================================================================================== -def linear_interpolation_with_extrapolation(x, y, x_new): +def linear_interpolation_with_extrapolation(x, y, x_new) -> np.ndarray: """Linear interpolation with extrapolation. Args: diff --git a/tests/utils/upper_envelope_fedor.py b/tests/utils/upper_envelope_fedor.py index 0676125..47b96c9 100644 --- a/tests/utils/upper_envelope_fedor.py +++ b/tests/utils/upper_envelope_fedor.py @@ -20,7 +20,7 @@ def upper_envelope( policy: np.ndarray, value: np.ndarray, exog_grid: np.ndarray, - state_choice_vec: np.ndarray, + state_choice_vec: Dict, params: Dict[str, float], compute_utility: Callable, ) -> Tuple[np.ndarray, np.ndarray]: diff --git a/tox.ini b/tox.ini index a7aa642..fdea22c 100644 --- a/tox.ini +++ b/tox.ini @@ -59,4 +59,4 @@ markers = end_to_end: Flag for tests that cover the whole program. norecursedirs = .idea - .tox \ No newline at end of file + .tox