Skip to content

Commit

Permalink
Fixed bug in resamplings
Browse files Browse the repository at this point in the history
  • Loading branch information
TimRoith committed Jan 19, 2024
1 parent 2dbdbdc commit cbe0694
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 45 deletions.
28 changes: 2 additions & 26 deletions cbx/dynamics/pdyn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
from ..utils.termination import Termination, check_energy, check_max_it, check_diff_tol, check_max_eval, check_max_time
from ..utils.history import track_x, track_energy, track_update_norm, track_consensus, track_drift, track_drift_mean
from ..utils.particle_init import init_particles
from ..utils.resampling import apply_resamplings
from cbx.utils.objective_handling import _promote_objective

#%%
from typing import Callable, Union, List
from typing import Callable, Union
from numpy.typing import ArrayLike
import numpy as np
from numpy.random import Generator, MT19937
Expand Down Expand Up @@ -125,8 +124,6 @@ class ParticleDynamic:
* 'x': The positions of the particles.
* 'update_norm': The norm of the particle update.
* 'energy': The energy of the system.
save_int : int, optional
The frequency of the saving of the data. The default is 1.
verbosity : int, optional
The verbosity level. The default is 1.
Expand Down Expand Up @@ -572,9 +569,6 @@ class CBXDynamic(ParticleDynamic):
The correction method. Default: 'no_correction'. One of 'no_correction', 'heavi_side', 'heavi_side_reg' or a Callable.
correction_eps: float, optional
The parameter :math:`\epsilon` for the regularized correction. Default: 1e-3.
resamplings: list, optional
List of callables that return indices of runs to resample.
Default: [].
Returns:
None
Expand All @@ -588,7 +582,6 @@ def __init__(self, f,
lamda: float = 1.0,
correction: Union[str, None] = 'no_correction',
correction_eps: float = 1e-3,
resamplings: List[Callable] = None,
update_thresh: float = 0.1,
compute_consensus: Callable = None,
**kwargs) -> None:
Expand All @@ -608,9 +601,6 @@ def __init__(self, f,

self.init_batch_idx(batch_args)

self.resamplings = resamplings
self.num_resampling = np.zeros((self.M,), dtype=int)

self.consensus = None #consensus point
self._compute_consensus = compute_consensus if compute_consensus is not None else compute_consensus_default

Expand Down Expand Up @@ -817,17 +807,6 @@ def update_covariance(self,) -> None:
D = self.drift[...,None] * self.drift[...,None,:]
D = np.sum(D * coeffs[..., None, None], axis = -3)
self.Cov_sqrt = compute_mat_sqrt(D)


def resample(self,) -> None:
idx = apply_resamplings(self, self.resamplings)
if len(idx)>0:
z = self.normal(0, 1., size=(len(idx), self.N, self.d))
self.x[idx, ...] += self.sigma * np.sqrt(self.dt) * z

self.num_resampling[idx] += 1
if self.verbosity > 0:
print('Resampled in runs ' + str(idx))

def pre_step(self,):
# save old positions
Expand All @@ -845,11 +824,8 @@ def post_step(self):

self.update_best_cur_particle()
self.update_best_particle()
self.post_process()
self.post_process(self)
self.track()

if self.resamplings is not None:
self.resample()

self.t += self.dt
self.it+=1
Expand Down
10 changes: 7 additions & 3 deletions cbx/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,19 +143,23 @@ class effective_number(param_update):
"""
def __init__(self, name = 'alpha', eta=1.0, maximum=1e5, factor=1.05):
def __init__(self, name = 'alpha', eta=1.0, maximum=1e5, factor=1.05,reeval=False):
super(effective_number, self).__init__(name = name, maximum=maximum)
if self.name != 'alpha':
warnings.warn('effective_number scheduler only works for alpha parameter! You specified name = {}!'.format(self.name), stacklevel=2)
self.eta = eta
self.J_eff = 1.0
self.factor=factor
self.reeval=reeval

def update(self, dyn):
val = getattr(dyn, self.name)

energy = dyn.f(dyn.x)
dyn.num_f_eval += np.ones(dyn.M, dtype=int) * dyn.x.shape[-1]
if self.reeval:
energy = dyn.f(dyn.x)
dyn.num_f_eval += np.ones(dyn.M, dtype=int) * dyn.x.shape[-1]
else:
energy = dyn.energy

term1 = logsumexp(-val * energy)
term2 = logsumexp(-2 * val * energy)
Expand Down
56 changes: 42 additions & 14 deletions cbx/utils/resampling.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,52 @@
import numpy as np
from typing import Callable, List

def apply_resamplings(dyn, resamplings: list):
def apply_resampling_default(dyn, idx):
    """
    Perturb the selected runs of a dynamic in place with scaled random noise.

    Parameters
    ----------
    dyn
        The dynamic object that is modified. This function reads ``dyn.normal``,
        ``dyn.N``, ``dyn.d``, ``dyn.sigma`` and ``dyn.dt`` and updates ``dyn.x`` in place.
    idx
        Indices along the first axis of ``dyn.x`` that are perturbed. Must support ``len``.

    Returns
    -------
    None
    """
    # One noise sample per selected run, with trailing shape (dyn.N, dyn.d).
    noise_shape = (len(idx), dyn.N, dyn.d)
    # presumably dyn.normal(mean, std, size=...) draws Gaussian samples -- confirm against the sampler in use
    perturbation = dyn.normal(0, 1., size=noise_shape)

    # Scale the noise by sigma * sqrt(dt) before adding it to the selected runs.
    step_scale = dyn.sigma * np.sqrt(dyn.dt)
    dyn.x[idx, ...] += step_scale * perturbation

class resampling:
"""
Apply resamplings to a given dynamic
Resamplings from a list of callables
Parameters
----------
dyn
The dynamic object to apply resamplings to
resamplings
resamplings: list
The list of resamplings to apply. Each entry should be a callable that accepts exactly one argument (the dynamic object) and returns a one-dimensional
numpy array of indices.
Returns
-------
The indices of the runs to resample as a numpy array
apply: Callable
- ``dyn``: The dynamic which the resampling is applied to.
- ``idx``: List of indices that are resampled.
The function that should be performed on a given dynamic for selected indices. This function has to have the signature apply(dyn,idx).
"""
def __init__(self, resamplings: List[Callable], M: int, apply:Callable = None):
self.resamplings = resamplings
self.M = M
self.num_resampling = np.zeros(M)
self.apply = apply if apply is not None else apply_resampling_default

return np.unique(np.concatenate([r(dyn) for r in resamplings]))

def __call__(self, dyn):
"""
Applies the resamplings to a given dynamic
Parameters
----------
dyn
The dynamic object to apply resamplings to
Returns
-------
None
"""
idx = np.unique(np.concatenate([r(dyn) for r in self.resamplings]))
if len(idx)>0:
self.apply(dyn, idx)
self.num_resampling[idx] += 1
if dyn.verbosity > 0:
print('Resampled in runs ' + str(idx))

class ensemble_update_resampling:
"""
Expand Down Expand Up @@ -69,8 +96,9 @@ def __init__(self, M:int, wait_thresh:int = 5):

def __call__(self,dyn):
self.wait += 1
self.wait[self.best_energy > dyn.best_energy] = 0
self.best_energy = dyn.best_energy
u_idx = self.best_energy > dyn.best_energy
self.wait[u_idx] = 0
self.best_energy[u_idx] = dyn.best_energy[u_idx]
idx = np.where(self.wait >= self.wait_thresh)[0]
self.wait = np.mod(self.wait, self.wait_thresh)
return idx
22 changes: 20 additions & 2 deletions experiments/nns/mnist.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,31 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"id": "7ed4c926-f239-42b0-93c2-fcf48171adaa",
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"outputs": [
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[7], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m e \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m f\u001b[38;5;241m.\u001b[39mepochs \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m10\u001b[39m:\n\u001b[0;32m----> 3\u001b[0m \u001b[43mdyn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4\u001b[0m sched\u001b[38;5;241m.\u001b[39mupdate(dyn)\n\u001b[1;32m 5\u001b[0m f\u001b[38;5;241m.\u001b[39mset_batch()\n",
"File \u001b[0;32m~/Documents/CBXpy/cbx/dynamics/pdyn.py:295\u001b[0m, in \u001b[0;36mParticleDynamic.step\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 280\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 281\u001b[0m \u001b[38;5;124;03mExecute a step in the dynamic.\u001b[39;00m\n\u001b[1;32m 282\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 292\u001b[0m \u001b[38;5;124;03m None\u001b[39;00m\n\u001b[1;32m 293\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 294\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpre_step()\n\u001b[0;32m--> 295\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minner_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 296\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpost_step()\n",
"File \u001b[0;32m~/Documents/CBXpy/cbx/dynamics/cbo.py:57\u001b[0m, in \u001b[0;36mCBO.inner_step\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39menergy[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconsensus_idx] \u001b[38;5;241m=\u001b[39m energy\n\u001b[1;32m 56\u001b[0m \u001b[38;5;66;03m# compute noise\u001b[39;00m\n\u001b[0;32m---> 57\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39ms \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msigma \u001b[38;5;241m*\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnoise\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 59\u001b[0m \u001b[38;5;66;03m# update particle positions\u001b[39;00m\n\u001b[1;32m 60\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mx[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparticle_idx] \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 61\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mx[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparticle_idx] \u001b[38;5;241m-\u001b[39m\n\u001b[1;32m 62\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcorrection(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlamda \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdt \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdrift) \u001b[38;5;241m+\u001b[39m\n\u001b[1;32m 63\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39ms)\n",
"File \u001b[0;32m~/Documents/CBXpy/cbx/dynamics/pdyn.py:801\u001b[0m, in \u001b[0;36mCBXDynamic.noise\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mnoise\u001b[39m(\u001b[38;5;28mself\u001b[39m, ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m ArrayLike:\n\u001b[1;32m 791\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 792\u001b[0m \u001b[38;5;124;03m Calculate the noise vector. Here, we use the callable ``noise_callable``, which takes the dynamic as an input via ``self``.\u001b[39;00m\n\u001b[1;32m 793\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 799\u001b[0m \u001b[38;5;124;03m ndarray: The noise vector.\u001b[39;00m\n\u001b[1;32m 800\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 801\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnoise_callable\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Documents/CBXpy/cbx/noise.py:133\u001b[0m, in \u001b[0;36manisotropic_noise.__call__\u001b[0;34m(self, dyn)\u001b[0m\n\u001b[1;32m 132\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, dyn) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m ArrayLike:\n\u001b[0;32m--> 133\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m np\u001b[38;5;241m.\u001b[39msqrt(dyn\u001b[38;5;241m.\u001b[39mdt) \u001b[38;5;241m*\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msample\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdyn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdrift\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Documents/CBXpy/cbx/noise.py:156\u001b[0m, in \u001b[0;36manisotropic_noise.sample\u001b[0;34m(self, drift)\u001b[0m\n\u001b[1;32m 135\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msample\u001b[39m(\u001b[38;5;28mself\u001b[39m, drift: ArrayLike) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m ArrayLike:\n\u001b[1;32m 136\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 137\u001b[0m \n\u001b[1;32m 138\u001b[0m \u001b[38;5;124;03m This function implements the anisotropic noise model. From the drift :math:`d = x - c(x)`,\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 153\u001b[0m \u001b[38;5;124;03m which motivates the name **anisotropic**.\u001b[39;00m\n\u001b[1;32m 154\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 156\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msampler\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdrift\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;241m*\u001b[39m drift\n",
"File \u001b[0;32m~/Documents/CBXpy/experiments/nns/cbx_torch_utils.py:16\u001b[0m, in \u001b[0;36mnormal_torch.<locals>._normal_torch\u001b[0;34m(mean, std, size)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_normal_torch\u001b[39m(mean, std, size):\n\u001b[0;32m---> 16\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnormal\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmean\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstd\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msize\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mto(device)\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"e = 0\n",
"while f.epochs < 10:\n",
Expand Down

0 comments on commit cbe0694

Please sign in to comment.