Skip to content

Commit

Permalink
add underscore to function name and
Browse files Browse the repository at this point in the history
remove docstringÒ
  • Loading branch information
KaelanDt committed Jun 4, 2024
1 parent 3d8460f commit dcaaadd
Showing 1 changed file with 2 additions and 26 deletions.
28 changes: 2 additions & 26 deletions thermox/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,38 +11,14 @@
)


def sample_identity_diffusion(
def _sample_identity_diffusion(
key: Array,
ts: Array,
x0: Array,
A: Array | ProcessedDriftMatrix,
b: Array,
associative_scan: bool = True,
) -> Array:
"""Collects samples from the Ornstein-Uhlenbeck process, defined as:
dx = - A * (x - b) dt + dW
by using exact diagonalization.
Preprocessing (diagonalisation) costs O(d^3) and sampling costs O(T * d^2)
where T=len(ts).
If associative_scan = True then jax.lax.associative_scan is used, so will run in
time O(log(T) * d^2) on a GPU/TPU with O(T) cores.
Args:
key: Jax PRNGKey.
ts: Times at which samples are collected. Includes time for x0.
x0: Initial state of the process.
A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
b: Drift displacement vector.
associative_scan: If True, uses jax.lax.associative_scan.
Returns:
Array-like, desired samples.
shape: (len(ts), ) + x0.shape
"""
if associative_scan:
return _sample_identity_diffusion_associative_scan(key, ts, x0, A, b)
else:
Expand Down Expand Up @@ -186,5 +162,5 @@ def sample(

y0 = D.sqrt_inv @ x0
b_y = D.sqrt_inv @ b
ys = sample_identity_diffusion(key, ts, y0, A_y, b_y, associative_scan)
ys = _sample_identity_diffusion(key, ts, y0, A_y, b_y, associative_scan)
return jax.vmap(jnp.matmul, in_axes=(None, 0))(D.sqrt, ys)

0 comments on commit dcaaadd

Please sign in to comment.