Skip to content

Commit

Permalink
Fixed ESS scheduler and eta default value
Browse files Browse the repository at this point in the history
  • Loading branch information
TimRoith committed Feb 29, 2024
1 parent 6250072 commit d3e534e
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 86 deletions.
20 changes: 12 additions & 8 deletions cbx/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@ class param_update():
maximum : float
The maximum value of the parameter. The default is 1e5.
"""
def __init__(self,
name: str ='alpha',
maximum: float = 1e5,
minimum: float = 1e-5):
def __init__(
self,
name: str ='alpha',
maximum: float = 1e5,
minimum: float = 1e-5
):
self.name = name
self.maximum = maximum
self.minimum = minimum
Expand Down Expand Up @@ -144,7 +146,7 @@ class effective_sample_size(param_update):
"""
def __init__(self, name = 'alpha', eta=1.0, maximum=1e5, solve_max_it = 15):
def __init__(self, name = 'alpha', eta=.5, maximum=1e5, solve_max_it = 15):
super().__init__(name = name, maximum=maximum)
if self.name != 'alpha':
warnings.warn('effective_number scheduler only works for alpha parameter! You specified name = {}!'.format(self.name), stacklevel=2)
Expand Down Expand Up @@ -174,7 +176,7 @@ def __call__(self, alpha):
denom = logsumexp(-2 * alpha[:, None] * self.energy, axis=-1)
return np.exp(2 * nom - denom) - self.eta * self.N

def bisection_solve(f, low, high, max_it = 15, thresh = 1e-2):
def bisection_solve(f, low, high, max_it = 100, thresh = 1e-2, verbosity=0):
r"""simple bisection optimization to solve for roots
Parameters
Expand Down Expand Up @@ -202,10 +204,12 @@ def bisection_solve(f, low, high, max_it = 15, thresh = 1e-2):
gtzero = np.where(fx[idx] > 0)[0]
ltzero = np.where(fx[idx] < 0)[0]
# update low and high
high[idx[gtzero]] = x[idx[gtzero]]
low[idx[ltzero]] = x[idx[ltzero]]
low[idx[gtzero]] = x[idx[gtzero]]
high[idx[ltzero]] = x[idx[ltzero]]
# update running idx and iteration
idx = np.where(np.abs(fx) > thresh)[0]
it += 1
term = (it > max_it) | (len(idx) == 0)
if verbosity > 0:
print('Finishing after ' + str(it) + ' Iterations')
return x
Loading

0 comments on commit d3e534e

Please sign in to comment.