diff --git a/tests/sandbox/time_functions_jax.ipynb b/tests/sandbox/time_functions_jax.ipynb index 4c8b0a66..73244a84 100644 --- a/tests/sandbox/time_functions_jax.ipynb +++ b/tests/sandbox/time_functions_jax.ipynb @@ -2,15 +2,13 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, "id": "35ab80e3", "metadata": { "ExecuteTime": { - "end_time": "2024-06-20T14:03:27.083891148Z", - "start_time": "2024-06-20T14:03:26.534171812Z" + "end_time": "2024-11-25T15:44:09.704943Z", + "start_time": "2024-11-25T15:44:09.292524Z" } }, - "outputs": [], "source": [ "from jax import vmap, jit\n", "import pickle\n", @@ -19,8 +17,120 @@ "import yaml\n", "from functools import partial\n", "import jax.numpy as jnp\n", - "import numpy as np" - ] + "import numpy as np\n", + "from tests.utils.markov_simulator import markov_simulator" + ], + "outputs": [], + "execution_count": 1 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-25T16:03:54.707014Z", + "start_time": "2024-11-25T16:03:54.700726Z" + } + }, + "cell_type": "code", + "source": [ + "n_periods = 10\n", + "init_dist = np.array([0.5, 0.5])\n", + "trans_mat = np.array([[0.8, 0.2], [0.1, 0.9]])\n", + "\n", + "markov_simulator(n_periods, init_dist, trans_mat)" + ], + "id": "c2b7c16010b9ba85", + "outputs": [ + { + "data": { + "text/plain": [ + "array([[0.5 , 0.5 ],\n", + " [0.45 , 0.55 ],\n", + " [0.415 , 0.585 ],\n", + " [0.3905 , 0.6095 ],\n", + " [0.37335 , 0.62665 ],\n", + " [0.361345 , 0.638655 ],\n", + " [0.3529415 , 0.6470585 ],\n", + " [0.34705905, 0.65294095],\n", + " [0.34294134, 0.65705866],\n", + " [0.34005893, 0.65994107]])" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 34 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-25T16:08:11.928721Z", + "start_time": "2024-11-25T16:08:11.908425Z" + } + }, + "cell_type": "code", + "source": [ + "n_agents = 100_000\n", + "current_agents_in_states = (np.ones(2) * n_agents / 2).astype(int)\n", + "for period in range(n_periods):\n", + " print(current_agents_in_states / n_agents)\n", + " next_period_agents_states = np.zeros(2, dtype=int)\n", + " for state in range(2):\n", + " agents_in_state = current_agents_in_states[state]\n", + " transition_draws = np.random.choice(\n", + " a=[0, 1], size=agents_in_state, p=trans_mat[state, :]\n", + " )\n", + " next_period_agents_states[1] += transition_draws.sum()\n", + " next_period_agents_states[0] += agents_in_state - transition_draws.sum()\n", + " current_agents_in_states = next_period_agents_states" + ], + "id": "ae676759dd2627d2", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0.5 0.5]\n", + "[0.4502 0.5498]\n", + "[0.4189 0.5811]\n", + "[0.39164 0.60836]\n", + "[0.37405 0.62595]\n", + "[0.35994 0.64006]\n", + "[0.35166 0.64834]\n", + "[0.34544 0.65456]\n", + "[0.34263 0.65737]\n", + "[0.34015 0.65985]\n" + ] + } + ], + "execution_count": 47 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-25T16:02:05.105098Z", + "start_time": "2024-11-25T16:02:05.100262Z" + } + }, + "cell_type": "code", + "source": [ + "trans_mat[0, :]" + ], + "id": "28dac7ec90b5d015", + "outputs": [ + { + "data": { + "text/plain": [ + "Array([0.8, 0.2], dtype=float32)" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 29 }, { "cell_type": "code",