Skip to content

Commit

Permalink
added functionality to plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
TimRoith committed Feb 4, 2024
1 parent cc8b2c6 commit fc4bf32
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 15 deletions.
21 changes: 16 additions & 5 deletions cbx/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ def __init__(self,
eval_energy_1d = True,
objective_args = None,
particle_args = None,
cosensus_args = None):
cosensus_args = None,
drift_args = None):

self.dyn = dyn
self.d = dyn.d
Expand All @@ -87,7 +88,13 @@ def __init__(self,
self.objective_args = objective_args if objective_args is not None else {}
self.particle_args = particle_args if particle_args is not None else {}
self.cosensus_args = cosensus_args if cosensus_args is not None else {}

self.drift_args = {
'scale':1, 'scale_units':'xy', 'angles':'xy',
'width':0.001, 'color':'orange'}
if drift_args is not None:
for k in drift_args.keys():
self.drift_args[k] = drift_args[k]

xmin = self.objective_args.get('x_min', -1.)
xmax = self.objective_args.get('x_max', 1.)
ax.set_xlim([xmin, xmax])
Expand Down Expand Up @@ -215,8 +222,7 @@ def init_drift(self, x, dr, pidx):
x[pidx][..., self.dims[1]][self.num_run,:],
-dr[self.num_run, :, self.dims[0]],
-dr[self.num_run, :, self.dims[1]],
scale=1., scale_units='xy', angles='xy',
width=0.001,color='orange')
**self.drift_args)

def update(self, wait=0.1):
"""
Expand Down Expand Up @@ -415,7 +421,7 @@ def decorate_at_ind(self,i):
"""
self.ax.set_title('Iteration: ' + str(i))

def run_plots(self, freq=5, wait=0.1):
def run_plots(self, freq=5, wait=0.1, save_args=None):
"""
Visualizes the evolution of the dynamic over time, using the history
Expand All @@ -431,6 +437,11 @@ def run_plots(self, freq=5, wait=0.1):
self.plot_at_ind(i)
self.decorate_at_ind(i)
plt.pause(wait)
if save_args is not None:
plt.savefig(
save_args['fname'] + str(i) + '.png',
**{k:save_args[k] for k in save_args.keys() if k not in ['fname']}
)



22 changes: 12 additions & 10 deletions experiments/lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,21 @@
np.random.seed(420)
#%%
conf = {'alpha': 40.0,
'dt': 0.01,
'sigma': 8.1,#8,#5.1,#8.0,
'dt': 0.1,
'sigma': 1.,#8,#5.1,#8.0,
'lamda': 1.0,
'batch_args':{
'batch_size':200,
'batch_partial': False},
'd': 20,
'term_args':{'max_it': 1000},
'N': 1000,
'd': 2,
'term_args':{'max_it': 50},
'N': 50,
'M': 3,
'track_args': {'names':
['update_norm',
'energy','x',
'consensus',
'drift']},
'resampling': False,
'update_thresh': 0.002}

#%% Define the objective function
Expand All @@ -38,10 +37,10 @@ def f(x):
return np.linalg.norm(x, axis=-1)

#%% Define the initial positions of the particles
x = cbx.utils.init_particles(shape=(conf['M'], conf['N'], conf['d']), x_min=-3., x_max = 3.)
x = cbx.utils.init_particles(shape=(conf['M'], conf['N'], conf['d']), x_min=-2., x_max = 1.)

#%% Define the CBO algorithm
dyn = CBO(f, x=x, noise='anisotropic', f_dim='3D',
dyn = CBO(f, x=x, noise='isotropic', f_dim='3D',
**conf)
sched = scheduler([multiply(name='alpha', factor=1.1, maximum=1e15),
#multiply(name='sigma', factor=1.005, maximum=6.)
Expand All @@ -67,7 +66,10 @@ def f(x):
plt.close('all')
plotter = plot_dynamic_history(
dyn, dims=[0,1],
objective_args={'x_min':-3, 'x_max':3},
objective_args={'x_min':-3, 'x_max':3, 'cmap':'viridis',
'num_pts':300},
particle_args = {'s':50, 'c':'xkcd:sky', 'marker':'o'},
drift_args = {'color':'pink', 'width':0.003},
plot_consensus=True,
plot_drift=True)
plotter.run_plots(wait=0.05, freq=1)
plotter.run_plots(wait=0.5, freq=1,)

0 comments on commit fc4bf32

Please sign in to comment.