Skip to content

Commit

Permalink
Allowing multi dim particles
Browse files Browse the repository at this point in the history
  • Loading branch information
TimRoith committed Feb 22, 2024
1 parent 0142031 commit 553897c
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions cbx/dynamics/pdyn.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,8 +533,7 @@ def compute_mat_sqrt(A):
return V@(np.sqrt(B)[...,None]*V.transpose(0,2,1))

class compute_consensus_default:
def __init__(self, logsumexp = None, check_coeffs = False):
self._logsumexp = logsumexp if logsumexp is not None else logsumexp_scp
def __init__(self, check_coeffs = False):
if check_coeffs:
self.check_coeffs = self._check_coeffs
else:
Expand All @@ -543,7 +542,7 @@ def __init__(self, logsumexp = None, check_coeffs = False):
def __call__(self, energy, x, alpha):
weights = - alpha * energy
coeff_expan = tuple([Ellipsis] + [None for i in range(x.ndim-2)])
coeffs = np.exp(weights - self._logsumexp(weights, axis=-1, keepdims=True))[coeff_expan]
coeffs = np.exp(weights - logsumexp_scp(weights, axis=-1, keepdims=True))[coeff_expan]
self.check_coeffs(coeffs)
return (x * coeffs).sum(axis=1, keepdims=True), energy

Expand Down Expand Up @@ -830,7 +829,7 @@ def set_noise(self, noise) -> None:
elif callable(noise):
self.noise_callable = noise
else:
raise ValueError('Invalid noise model! Choose from "isotropic", "anisotropic", "sampling", "covariance", or a callable.')
raise ValueError('Invalid noise model: ' +str(noise) + '! Choose from "isotropic", "anisotropic", "sampling", "covariance", or a callable.')

def noise(self, ) -> ArrayLike:
"""
Expand All @@ -857,7 +856,7 @@ def update_covariance(self,) -> None:
"""
weights = - self.alpha * self.energy
coeffs = np.exp(weights - logsumexp(weights, axis=(-1,), keepdims=True))
coeffs = np.exp(weights - logsumexp_scp(weights, axis=(-1,), keepdims=True))

D = self.drift[...,None] * self.drift[...,None,:]
D = np.sum(D * coeffs[..., None, None], axis = -3)
Expand Down Expand Up @@ -891,7 +890,7 @@ def reset(self,):
self.t = 0.

def eval_f(self, x):
self.num_f_eval[self.active_runs_idx] += x.shape[-2] # update number of function evaluations
self.num_f_eval[self.active_runs_idx] += x.shape[1] # update number of function evaluations
return self.f(x)

def print_cur_state(self,):
Expand Down

0 comments on commit 553897c

Please sign in to comment.