Skip to content

Commit

Permalink
Changes in consensus exposure
Browse files Browse the repository at this point in the history
  • Loading branch information
TimRoith committed Jan 23, 2024
1 parent cbe0694 commit 90b2838
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 24 deletions.
14 changes: 7 additions & 7 deletions cbx/dynamics/pdyn.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,8 +531,7 @@ def compute_mat_sqrt(A):
B = np.maximum(B,0.)
return V@(np.sqrt(B)[...,None]*V.transpose(0,2,1))

def compute_consensus_default(f, x, alpha):
energy = f(x) # update energy
def compute_consensus_default(energy, x, alpha):
weights = - alpha * energy
coeffs = np.exp(weights - logsumexp(weights, axis=(-1,), keepdims=True))[...,None]
problem_idx = np.where(np.abs(coeffs.sum(axis=-2)-1) > 0.1)[0]
Expand Down Expand Up @@ -835,8 +834,9 @@ def reset(self,):
self.init_history()
self.t = 0.

def eval_energy(self,):
self.energy = self.f(self.x)
def eval_f(self, x):
self.num_f_eval += np.ones(self.M, dtype=int) * x.shape[-2] # update number of function evaluations
return self.f(x)

def print_cur_state(self,):
if self.verbosity > 0:
Expand All @@ -846,7 +846,7 @@ def print_cur_state(self,):
if self.verbosity > 1:
print('Current alpha: ' + str(self.alpha))

def compute_consensus(self, x_batch) -> None:
def compute_consensus(self, x) -> None:
r"""Updates the weighted mean of the particles.
Parameters
Expand All @@ -860,8 +860,8 @@ def compute_consensus(self, x_batch) -> None:
"""
# evaluation of objective function on batch

self.num_f_eval += np.ones(self.M,dtype=int) * x_batch.shape[-2] # update number of function evaluations
return self._compute_consensus(self.f, x_batch, self.alpha)
energy = self.eval_f(x) # update energy
return self._compute_consensus(energy, x, self.alpha)



3 changes: 2 additions & 1 deletion cbx/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .particle_init import init_particles
from . import resampling


__all__ = ['init_particles']
__all__ = ['init_particles', 'resampling']
3 changes: 1 addition & 2 deletions experiments/nns/cbx_torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
def norm_torch(x, axis, **kwargs):
return torch.linalg.norm(x, dim=axis, **kwargs)

def compute_consensus_torch(f, x, alpha):
energy = f(x) # update energy
def compute_consensus_torch(energy, x, alpha):
weights = - alpha * energy
coeffs = torch.exp(weights - torch.logsumexp(weights, axis=(-1,), keepdims=True))[...,None]
return (x * coeffs).sum(axis=-2, keepdims=True), energy.cpu().numpy()
Expand Down
58 changes: 44 additions & 14 deletions experiments/nns/mnist.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,19 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 9,
"id": "1b37649d-ae72-49dc-855e-5109b13d9433",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
Expand All @@ -16,7 +25,8 @@
"import torch\n",
"import torch.nn as nn\n",
"import torchvision\n",
"from cbx.noise import anisotropic_noise"
"from cbx.noise import anisotropic_noise\n",
"import cbx.utils.resampling as rsmp"
]
},
{
Expand Down Expand Up @@ -129,56 +139,76 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 14,
"id": "9e32f6f7-7937-4ceb-a043-93eb8e85f9b1",
"metadata": {},
"outputs": [],
"source": [
"f = objective(train_loader, N, device, model, pprop)\n",
"resamplings = [cbx.utils.resampling.loss_update_resampling(M=1, wait_thresh=40)]\n",
"resampling = rsmp.resampling([rsmp.loss_update_resampling(M=1, wait_thresh=40)], 1)\n",
"noise = anisotropic_noise(norm = norm_torch, sampler = normal_torch(device))\n",
"dyn = CBO(f, f_dim='3D', x=w[None,...], noise=noise, \n",
" resamplings=resamplings, \n",
"\n",
"dyn = CBO(f, f_dim='3D', x=w[None,...], noise=noise,\n",
" norm=norm_torch,\n",
" copy=torch.clone,\n",
" normal=normal_torch(device),\n",
" compute_consensus=compute_consensus_torch,\n",
" post_process = lambda *args: None,\n",
" post_process = lambda dyn: resampling(dyn),\n",
" **kwargs)\n",
"sched = cbx.scheduler.multiply(factor=1.03, name='alpha')"
]
},
{
"cell_type": "markdown",
"id": "8e59d1f1-c8b7-4d97-b8f5-32d41658495f",
"metadata": {},
"metadata": {
"tags": []
},
"source": [
"# Train the network"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 15,
"id": "7ed4c926-f239-42b0-93c2-fcf48171adaa",
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"------------------------------\n",
"Epoch: 1\n",
"Accuracy: 0.5214999914169312\n",
"------------------------------\n",
"------------------------------\n",
"Epoch: 2\n",
"Accuracy: 0.6847000122070312\n",
"------------------------------\n",
"------------------------------\n",
"Epoch: 3\n",
"Accuracy: 0.7095999717712402\n",
"------------------------------\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[7], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m e \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m f\u001b[38;5;241m.\u001b[39mepochs \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m10\u001b[39m:\n\u001b[0;32m----> 3\u001b[0m \u001b[43mdyn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4\u001b[0m sched\u001b[38;5;241m.\u001b[39mupdate(dyn)\n\u001b[1;32m 5\u001b[0m f\u001b[38;5;241m.\u001b[39mset_batch()\n",
"File \u001b[0;32m~/Documents/CBXpy/cbx/dynamics/pdyn.py:295\u001b[0m, in \u001b[0;36mParticleDynamic.step\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 280\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 281\u001b[0m \u001b[38;5;124;03mExecute a step in the dynamic.\u001b[39;00m\n\u001b[1;32m 282\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 292\u001b[0m \u001b[38;5;124;03m None\u001b[39;00m\n\u001b[1;32m 293\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 294\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpre_step()\n\u001b[0;32m--> 295\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minner_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 296\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpost_step()\n",
"Cell \u001b[0;32mIn[15], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m e \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m f\u001b[38;5;241m.\u001b[39mepochs \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m10\u001b[39m:\n\u001b[0;32m----> 3\u001b[0m \u001b[43mdyn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4\u001b[0m sched\u001b[38;5;241m.\u001b[39mupdate(dyn)\n\u001b[1;32m 5\u001b[0m f\u001b[38;5;241m.\u001b[39mset_batch()\n",
"File \u001b[0;32m~/Documents/CBXpy/cbx/dynamics/pdyn.py:292\u001b[0m, in \u001b[0;36mParticleDynamic.step\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 277\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 278\u001b[0m \u001b[38;5;124;03mExecute a step in the dynamic.\u001b[39;00m\n\u001b[1;32m 279\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 289\u001b[0m \u001b[38;5;124;03m None\u001b[39;00m\n\u001b[1;32m 290\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 291\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpre_step()\n\u001b[0;32m--> 292\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minner_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 293\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpost_step()\n",
"File \u001b[0;32m~/Documents/CBXpy/cbx/dynamics/cbo.py:57\u001b[0m, in \u001b[0;36mCBO.inner_step\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39menergy[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconsensus_idx] \u001b[38;5;241m=\u001b[39m energy\n\u001b[1;32m 56\u001b[0m \u001b[38;5;66;03m# compute noise\u001b[39;00m\n\u001b[0;32m---> 57\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39ms \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msigma \u001b[38;5;241m*\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnoise\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 59\u001b[0m \u001b[38;5;66;03m# update particle positions\u001b[39;00m\n\u001b[1;32m 60\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mx[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparticle_idx] \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 61\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mx[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparticle_idx] \u001b[38;5;241m-\u001b[39m\n\u001b[1;32m 62\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcorrection(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlamda \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdt \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdrift) \u001b[38;5;241m+\u001b[39m\n\u001b[1;32m 63\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39ms)\n",
"File \u001b[0;32m~/Documents/CBXpy/cbx/dynamics/pdyn.py:801\u001b[0m, in \u001b[0;36mCBXDynamic.noise\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mnoise\u001b[39m(\u001b[38;5;28mself\u001b[39m, ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m ArrayLike:\n\u001b[1;32m 791\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 792\u001b[0m \u001b[38;5;124;03m Calculate the noise vector. Here, we use the callable ``noise_callable``, which takes the dynamic as an input via ``self``.\u001b[39;00m\n\u001b[1;32m 793\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 799\u001b[0m \u001b[38;5;124;03m ndarray: The noise vector.\u001b[39;00m\n\u001b[1;32m 800\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 801\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnoise_callable\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Documents/CBXpy/cbx/dynamics/pdyn.py:790\u001b[0m, in \u001b[0;36mCBXDynamic.noise\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 779\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mnoise\u001b[39m(\u001b[38;5;28mself\u001b[39m, ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m ArrayLike:\n\u001b[1;32m 780\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 781\u001b[0m \u001b[38;5;124;03m Calculate the noise vector. Here, we use the callable ``noise_callable``, which takes the dynamic as an input via ``self``.\u001b[39;00m\n\u001b[1;32m 782\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 788\u001b[0m \u001b[38;5;124;03m ndarray: The noise vector.\u001b[39;00m\n\u001b[1;32m 789\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 790\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnoise_callable\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Documents/CBXpy/cbx/noise.py:133\u001b[0m, in \u001b[0;36manisotropic_noise.__call__\u001b[0;34m(self, dyn)\u001b[0m\n\u001b[1;32m 132\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, dyn) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m ArrayLike:\n\u001b[0;32m--> 133\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m np\u001b[38;5;241m.\u001b[39msqrt(dyn\u001b[38;5;241m.\u001b[39mdt) \u001b[38;5;241m*\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msample\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdyn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdrift\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Documents/CBXpy/cbx/noise.py:156\u001b[0m, in \u001b[0;36manisotropic_noise.sample\u001b[0;34m(self, drift)\u001b[0m\n\u001b[1;32m 135\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msample\u001b[39m(\u001b[38;5;28mself\u001b[39m, drift: ArrayLike) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m ArrayLike:\n\u001b[1;32m 136\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 137\u001b[0m \n\u001b[1;32m 138\u001b[0m \u001b[38;5;124;03m This function implements the anisotropic noise model. From the drift :math:`d = x - c(x)`,\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 153\u001b[0m \u001b[38;5;124;03m which motivates the name **anisotropic**.\u001b[39;00m\n\u001b[1;32m 154\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 156\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msampler\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdrift\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;241m*\u001b[39m drift\n",
"File \u001b[0;32m~/Documents/CBXpy/experiments/nns/cbx_torch_utils.py:16\u001b[0m, in \u001b[0;36mnormal_torch.<locals>._normal_torch\u001b[0;34m(mean, std, size)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_normal_torch\u001b[39m(mean, std, size):\n\u001b[0;32m---> 16\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnormal\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmean\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstd\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msize\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mto(device)\n",
"File \u001b[0;32m~/Documents/CBXpy/experiments/nns/cbx_torch_utils.py:15\u001b[0m, in \u001b[0;36mnormal_torch.<locals>._normal_torch\u001b[0;34m(mean, std, size)\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_normal_torch\u001b[39m(mean, std, size):\n\u001b[0;32m---> 15\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnormal\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmean\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstd\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msize\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mto(device)\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
Expand Down

0 comments on commit 90b2838

Please sign in to comment.