Skip to content

Commit

Permalink
Updated Mnist example
Browse files Browse the repository at this point in the history
  • Loading branch information
TimRoith committed Mar 5, 2024
1 parent 1da8a51 commit 71bae48
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 34 deletions.
27 changes: 25 additions & 2 deletions experiments/nns/cbx_torch_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from collections import OrderedDict
import torch
from torch import logsumexp
from torch.func import functional_call, stack_module_state, vmap
from cbx.scheduler import bisection_solve, eff_sample_size_gap
import numpy as np

def norm_torch(x, axis, **kwargs):
    """Return the norm of ``x`` computed by ``torch.linalg.norm``.

    Thin adapter that accepts a NumPy-style ``axis`` argument and forwards
    it to torch under its native keyword ``dim``.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.
    axis : int or tuple of int
        Dimension(s) over which the norm is taken.
    **kwargs
        Passed through unchanged to ``torch.linalg.norm``
        (e.g. ``ord``, ``keepdim``).

    Returns
    -------
    torch.Tensor
        The result of ``torch.linalg.norm(x, dim=axis, **kwargs)``.
    """
    return torch.linalg.norm(x, dim=axis, **kwargs)

def compute_consensus_torch(energy, x, alpha):
    """Compute the softmax-weighted average of ``x`` under ``-alpha * energy``.

    The weights are normalised along the last axis of ``energy`` and then
    broadcast over one trailing axis of ``x``; the weighted sum is taken over
    the second-to-last axis of ``x``.

    Parameters
    ----------
    energy : torch.Tensor
        Energy values; its last axis must line up with the second-to-last
        axis of ``x``.
    x : torch.Tensor
        Values to average; has one more trailing axis than ``energy``.
    alpha : float or torch.Tensor
        Scaling factor applied to the energy; anything that broadcasts
        against ``energy`` works.

    Returns
    -------
    tuple
        ``(consensus, energy_np)`` where ``consensus`` keeps the reduced axis
        with length one and ``energy_np`` is ``energy`` moved to the CPU and
        converted to a NumPy array.
    """
    weights = -alpha * energy
    # Subtracting the log-sum-exp before exponentiating normalises the
    # weights without overflowing/underflowing for large |alpha * energy|.
    log_norm = logsumexp(weights, dim=-1, keepdim=True)
    coeffs = torch.exp(weights - log_norm)[..., None]
    consensus = (x * coeffs).sum(dim=-2, keepdim=True)
    return consensus, energy.cpu().numpy()

def compute_polar_consensus_torch(energy, x, neg_log_eval, alpha = 1., kernel_factor = 1.):
Expand Down Expand Up @@ -59,4 +62,24 @@ def get_param_properties(models, pnames=None):
if len(pprop)>0:
a = pprop[next(reversed(pprop))][-1]
pprop[p] = (params[p][0,...].shape, a, a + params[p][0,...].numel())
return pprop
return pprop


class effective_sample_size:
    """Scheduler that re-solves a dynamic's parameter on every update.

    ``update`` runs ``bisection_solve`` on ``eff_sample_size_gap`` (both from
    ``cbx.scheduler``) and stores the solution on the dynamic as a torch
    tensor, so it can be used with torch-backed particle dynamics.

    Parameters
    ----------
    name : str
        Name of the attribute on the dynamic that is overwritten.
    eta : float
        Forwarded as the second argument of ``eff_sample_size_gap``.
    maximum : float
        Upper end of the bisection bracket.
    minimum : float
        Lower end of the bisection bracket.
    solve_max_it : int
        Iteration cap handed to ``bisection_solve`` as ``max_it``.
    """

    def __init__(self, name = 'alpha', eta=.5, maximum=1e5, minimum=1e-5, solve_max_it = 15):
        self.name = name
        self.eta = eta
        # Kept for interface compatibility; not read by ``update`` here.
        self.J_eff = 1.0
        self.solve_max_it = solve_max_it
        self.maximum = maximum
        self.minimum = minimum

    def update(self, dyn):
        """Solve for the parameter and write it back onto ``dyn``.

        Parameters
        ----------
        dyn : object
            The dynamic; must expose ``energy``, ``M`` and the attribute
            named ``self.name``.

        Returns
        -------
        None
            ``dyn`` is modified in place: the attribute is replaced by a
            tensor whose first axis has length ``dyn.M`` followed by a
            singleton axis.
        """
        current = getattr(dyn, self.name)
        # The parameter may still be a plain float (e.g. an initial
        # ``alpha=1.0``), which has no ``.device``; ``None`` makes
        # ``torch.tensor`` use the default device instead of crashing.
        device = getattr(current, 'device', None)

        # One bracket [minimum, maximum] per run, i.e. ``dyn.M`` entries.
        lower = self.minimum * np.ones((dyn.M,))
        upper = self.maximum * np.ones((dyn.M,))
        val = bisection_solve(
            eff_sample_size_gap(dyn.energy, self.eta),
            lower, upper,
            max_it = self.solve_max_it, thresh=1e-2
        )
        # Trailing singleton axis lets the value broadcast against
        # per-run, per-particle quantities.
        setattr(dyn, self.name, torch.tensor(val[:, None], device=device))
51 changes: 29 additions & 22 deletions experiments/nns/mnist.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 13,
"id": "705a76af-42bb-4ca6-9f8c-f1f6d9dbad6a",
"metadata": {},
"outputs": [],
"source": [
"from models import Perceptron\n",
"from cbx_torch_utils import flatten_parameters, get_param_properties, eval_losses, norm_torch, compute_consensus_torch, normal_torch, eval_acc\n",
"from cbx_torch_utils import flatten_parameters, get_param_properties, eval_losses, norm_torch, compute_consensus_torch, normal_torch, eval_acc, effective_sample_size\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"N = 50\n",
"models = [Perceptron(sizes=[784,100,10]) for _ in range(N)]\n",
Expand All @@ -72,7 +72,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 14,
"id": "cac7a15f-8445-4eb3-ad00-bfdd15e201d8",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -112,7 +112,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 15,
"id": "b40c7cf2-4295-4bf3-8328-e47ba7ba16d5",
"metadata": {},
"outputs": [],
Expand All @@ -130,7 +130,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 27,
"id": "9e32f6f7-7937-4ceb-a043-93eb8e85f9b1",
"metadata": {},
"outputs": [],
Expand All @@ -146,7 +146,7 @@
" compute_consensus=compute_consensus_torch,\n",
" post_process = lambda dyn: resampling(dyn),\n",
" **kwargs)\n",
"sched = cbx.scheduler.multiply(factor=1.03, name='alpha')"
"sched = effective_sample_size(maximum=1e7, name='alpha')"
]
},
{
Expand All @@ -161,27 +161,42 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 28,
"id": "7ed4c926-f239-42b0-93c2-fcf48171adaa",
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"------------------------------\n",
"Epoch: 1\n",
"Accuracy: 0.53329998254776\n",
"------------------------------\n",
"------------------------------\n",
"Epoch: 2\n",
"Accuracy: 0.644599974155426\n",
"------------------------------\n",
"------------------------------\n",
"Epoch: 3\n",
"Accuracy: 0.7069000005722046\n",
"------------------------------\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[10], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m e \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m f\u001b[38;5;241m.\u001b[39mepochs \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m10\u001b[39m:\n\u001b[0;32m----> 3\u001b[0m \u001b[43mdyn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4\u001b[0m sched\u001b[38;5;241m.\u001b[39mupdate(dyn)\n\u001b[1;32m 5\u001b[0m f\u001b[38;5;241m.\u001b[39mset_batch()\n",
"File \u001b[0;32m~/Documents/CBXpy/cbx/dynamics/pdyn.py:296\u001b[0m, in \u001b[0;36mParticleDynamic.step\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 281\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 282\u001b[0m \u001b[38;5;124;03mExecute a step in the dynamic.\u001b[39;00m\n\u001b[1;32m 283\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 293\u001b[0m \u001b[38;5;124;03m None\u001b[39;00m\n\u001b[1;32m 294\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 295\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpre_step()\n\u001b[0;32m--> 296\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minner_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 297\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpost_step()\n",
"File \u001b[0;32m~/Documents/CBXpy/cbx/dynamics/cbo.py:58\u001b[0m, in \u001b[0;36mCBO.inner_step\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdrift \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mx[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparticle_idx] \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconsensus\n\u001b[1;32m 57\u001b[0m \u001b[38;5;66;03m# compute noise\u001b[39;00m\n\u001b[0;32m---> 58\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39ms \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msigma \u001b[38;5;241m*\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnoise\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 60\u001b[0m \u001b[38;5;66;03m# update particle positions\u001b[39;00m\n\u001b[1;32m 61\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mx[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparticle_idx] \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 62\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mx[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparticle_idx] \u001b[38;5;241m-\u001b[39m\n\u001b[1;32m 63\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcorrection(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlamda \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdt \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdrift) \u001b[38;5;241m+\u001b[39m\n\u001b[1;32m 64\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39ms)\n",
"File \u001b[0;32m~/Documents/CBXpy/cbx/dynamics/pdyn.py:845\u001b[0m, in \u001b[0;36mCBXDynamic.noise\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 834\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mnoise\u001b[39m(\u001b[38;5;28mself\u001b[39m, ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m ArrayLike:\n\u001b[1;32m 835\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 836\u001b[0m \u001b[38;5;124;03m Calculate the noise vector. Here, we use the callable ``noise_callable``, which takes the dynamic as an input via ``self``.\u001b[39;00m\n\u001b[1;32m 837\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 843\u001b[0m \u001b[38;5;124;03m ndarray: The noise vector.\u001b[39;00m\n\u001b[1;32m 844\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 845\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnoise_callable\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Documents/CBXpy/cbx/noise.py:133\u001b[0m, in \u001b[0;36manisotropic_noise.__call__\u001b[0;34m(self, dyn)\u001b[0m\n\u001b[1;32m 132\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, dyn) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m ArrayLike:\n\u001b[0;32m--> 133\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m np\u001b[38;5;241m.\u001b[39msqrt(dyn\u001b[38;5;241m.\u001b[39mdt) \u001b[38;5;241m*\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msample\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdyn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdrift\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Documents/CBXpy/cbx/noise.py:156\u001b[0m, in \u001b[0;36manisotropic_noise.sample\u001b[0;34m(self, drift)\u001b[0m\n\u001b[1;32m 135\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msample\u001b[39m(\u001b[38;5;28mself\u001b[39m, drift: ArrayLike) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m ArrayLike:\n\u001b[1;32m 136\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 137\u001b[0m \n\u001b[1;32m 138\u001b[0m \u001b[38;5;124;03m This function implements the anisotropic noise model. From the drift :math:`d = x - c(x)`,\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 153\u001b[0m \u001b[38;5;124;03m which motivates the name **anisotropic**.\u001b[39;00m\n\u001b[1;32m 154\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 156\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msampler\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdrift\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;241m*\u001b[39m drift\n",
"File \u001b[0;32m~/Documents/CBXpy/experiments/nns/cbx_torch_utils.py:21\u001b[0m, in \u001b[0;36mnormal_torch.<locals>._normal_torch\u001b[0;34m(mean, std, size)\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_normal_torch\u001b[39m(mean, std, size):\n\u001b[0;32m---> 21\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnormal\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmean\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstd\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msize\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mto(device)\n",
"Cell \u001b[0;32mIn[28], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m f\u001b[38;5;241m.\u001b[39mepochs \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m10\u001b[39m:\n\u001b[1;32m 3\u001b[0m dyn\u001b[38;5;241m.\u001b[39mstep()\n\u001b[0;32m----> 4\u001b[0m \u001b[43msched\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mupdate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdyn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5\u001b[0m f\u001b[38;5;241m.\u001b[39mset_batch()\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m e \u001b[38;5;241m!=\u001b[39m f\u001b[38;5;241m.\u001b[39mepochs:\n",
"File \u001b[0;32m~/Documents/CBXpy/experiments/nns/cbx_torch_utils.py:80\u001b[0m, in \u001b[0;36meffective_sample_size.update\u001b[0;34m(self, dyn)\u001b[0m\n\u001b[1;32m 78\u001b[0m val \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mgetattr\u001b[39m(dyn, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mname)\n\u001b[1;32m 79\u001b[0m device \u001b[38;5;241m=\u001b[39m val\u001b[38;5;241m.\u001b[39mdevice\n\u001b[0;32m---> 80\u001b[0m val \u001b[38;5;241m=\u001b[39m \u001b[43mbisection_solve\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 81\u001b[0m \u001b[43m \u001b[49m\u001b[43meff_sample_size_gap\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdyn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menergy\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43meta\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m 82\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mminimum\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mones\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdyn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mM\u001b[49m\u001b[43m,\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmaximum\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mones\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdyn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mM\u001b[49m\u001b[43m,\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m 83\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_it\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msolve_max_it\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mthresh\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1e-2\u001b[39;49m\n\u001b[1;32m 84\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 85\u001b[0m \u001b[38;5;28msetattr\u001b[39m(dyn, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mname, torch\u001b[38;5;241m.\u001b[39mtensor(val[:, \u001b[38;5;28;01mNone\u001b[39;00m], device\u001b[38;5;241m=\u001b[39mdevice))\n",
"File \u001b[0;32m~/Documents/CBXpy/cbx/scheduler.py:206\u001b[0m, in \u001b[0;36mbisection_solve\u001b[0;34m(f, low, high, max_it, thresh, verbosity)\u001b[0m\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m term:\n\u001b[1;32m 205\u001b[0m x \u001b[38;5;241m=\u001b[39m (low \u001b[38;5;241m+\u001b[39m high)\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m2\u001b[39m\n\u001b[0;32m--> 206\u001b[0m fx \u001b[38;5;241m=\u001b[39m \u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 207\u001b[0m gtzero \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mwhere(fx[idx] \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m)[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 208\u001b[0m ltzero \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mwhere(fx[idx] \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m0\u001b[39m)[\u001b[38;5;241m0\u001b[39m]\n",
"File \u001b[0;32m~/Documents/CBXpy/cbx/scheduler.py:180\u001b[0m, in \u001b[0;36meff_sample_size_gap.__call__\u001b[0;34m(self, alpha)\u001b[0m\n\u001b[1;32m 178\u001b[0m nom \u001b[38;5;241m=\u001b[39m logsumexp(\u001b[38;5;241m-\u001b[39malpha[:, \u001b[38;5;28;01mNone\u001b[39;00m] \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39menergy, axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 179\u001b[0m denom \u001b[38;5;241m=\u001b[39m logsumexp(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m \u001b[38;5;241m*\u001b[39m alpha[:, \u001b[38;5;28;01mNone\u001b[39;00m] \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39menergy, axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m--> 180\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexp\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mnom\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mdenom\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43meta\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mN\u001b[49m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
Expand All @@ -200,14 +215,6 @@
" print('Accuracy: ' + str(acc.item()))\n",
" print(30*'-')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "683a67e7-745f-4ada-afee-7421c0921aaf",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
11 changes: 1 addition & 10 deletions tests/scheduler/test_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np
import cbx
from cbx.scheduler import multiply, scheduler, effective_sample_size
from cbx.scheduler import multiply, scheduler

def test_multiply_update():
'''Test if multiply scheduler updates params correctly'''
Expand All @@ -22,12 +21,4 @@ def test_multiply_maximum():
assert dyn.alpha == 2.0
assert dyn.sigma == 4.7

def test_effective_number_scheduler():
    '''Test if effective number scheduler updates params correctly'''
    # Ensemble of shape (6, 5, 7); the final assertion expects one alpha
    # per entry of the first axis.
    x = np.ones((6,5,7))
    dyn = cbx.dynamics.CBO(f=lambda x: np.sum(x**2), x=x, max_it=1, alpha=1.0, sigma=1.0)
    sched = scheduler([effective_sample_size(name='alpha', maximum=20.0)])

    dyn.optimize(sched=sched)
    # The scalar alpha passed in must have been replaced by a column
    # vector with one row per entry of the first axis of ``x``.
    assert dyn.alpha.shape == (6,1)

0 comments on commit 71bae48

Please sign in to comment.