Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add associative scan #30

Merged
merged 19 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 251 additions & 0 deletions examples/associative_scan.ipynb

Large diffs are not rendered by default.

15 changes: 12 additions & 3 deletions tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,27 @@


def test_sample_array_input():
    """Check sample statistics and that both scan back-ends agree.

    Draws a long trajectory on an irregular time grid with the sequential
    sampler, compares its empirical mean and covariance against the expected
    values, then checks that the associative-scan sampler reproduces the same
    trajectory from the same key.
    """
    # Double precision so the two back-ends can be compared at atol=1e-6.
    jax.config.update("jax_enable_x64", True)
    key = jax.random.PRNGKey(0)
    dim = 2
    dt = 0.1
    ts = jnp.arange(0, 10_000, dt)

    # Add some noise to the time points to make the timesteps different
    ts += jax.random.uniform(key, (ts.shape[0],)) * dt
    ts = ts.sort()

    A = jnp.array([[3, 2.5], [2, 4.0]])
    b = jax.random.normal(jax.random.PRNGKey(1), (dim,))
    x0 = jax.random.normal(jax.random.PRNGKey(2), (dim,))
    D = 2 * jnp.eye(dim)

    samples = thermox.sample(key, ts, x0, A, b, D, associative_scan=False)

    samp_cov = jnp.cov(samples.T)
    samp_mean = jnp.mean(samples.T, axis=1)
    assert jnp.allclose(A @ samp_cov, jnp.eye(dim), atol=1e-1)
    assert jnp.allclose(samp_mean, b, atol=1e-1)

    # Same key => same Gaussian draws, so the trajectories must match.
    samples_as = thermox.sample(key, ts, x0, A, b, D, associative_scan=True)
    assert jnp.allclose(samples, samples_as, atol=1e-6)
116 changes: 47 additions & 69 deletions thermox/prob.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,74 +8,7 @@
ProcessedDriftMatrix,
ProcessedDiffusionMatrix,
)


def log_prob_identity_diffusion(
    ts: Array,
    xs: Array,
    A: Array | ProcessedDriftMatrix,
    b: Array,
) -> float:
    """Calculates log probability of samples from the Ornstein-Uhlenbeck process,
    defined as:

        dx = - A * (x - b) dt + dW

    by using exact diagonalization.

    Assumes x(t_0) is given deterministically.

    Preprocessing (diagonalisation) costs O(d^3) and evaluation then costs O(T * d^2).

    Args:
        ts: Times at which samples are collected. Includes time for x0.
        xs: States of the process at the times in ts; xs[0] is the initial
            state and only conditions the first transition.
        A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
        b: Drift displacement vector.
    Returns:
        Scalar log probability of given xs.
    """
    if isinstance(A, Array):
        A = preprocess_drift_matrix(A)

    def expm_vp(v, dt):
        out = A.eigvecs_inv @ v
        out = jnp.exp(-A.eigvals * dt) * out
        out = A.eigvecs @ out
        return out.real

    def transition_mean(y, dt):
        return b + expm_vp(y - b, dt)

    def transition_var(dt):
        # -expm1(-x) == 1 - exp(-x), but stays accurate when x is tiny; the
        # naive form rounds to exactly 0 in float32 once x < ~6e-8.
        return -jnp.expm1(-2 * A.sym_eigvals * dt) / (2 * A.sym_eigvals)

    def transition_cov_sqrt_inv_vp(v, dt):
        diag = transition_var(dt) ** 0.5
        # Floor the scale so a zero timestep does not divide by zero.
        diag = jnp.where(diag < 1e-20, 1e-20, diag)
        out = A.sym_eigvecs.T @ v
        out = out / diag
        return out.real

    def transition_cov_log_det(dt):
        diag = transition_var(dt)
        # NOTE(review): the floor here is on the variance, whereas
        # transition_cov_sqrt_inv_vp floors the standard deviation.
        diag = jnp.where(diag < 1e-20, 1e-20, diag)
        return jnp.sum(jnp.log(diag))

    def logpt(yt, y0, dt):
        # Gaussian log density of yt given y0 after a step of length dt.
        mean = transition_mean(y0, dt)
        diff_val = transition_cov_sqrt_inv_vp(yt - mean, dt)
        return (
            -jnp.dot(diff_val, diff_val) / 2
            - transition_cov_log_det(dt) / 2
            - jnp.log(2 * jnp.pi) * (yt.shape[0] / 2)
        )

    # Sum the transition log densities over consecutive pairs of samples.
    log_prob_val = fori_loop(
        1,
        len(ts),
        lambda i, val: val + logpt(xs[i], xs[i - 1], ts[i] - ts[i - 1]),
        0.0,
    )

    return log_prob_val.real
from thermox.sampler import expm_vp


def log_prob(
Expand Down Expand Up @@ -105,7 +38,7 @@ def log_prob(
ts: Times at which samples are collected. Includes time for x0.
xs: Initial state of the process.
A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
Note : If a thermox.ProcessedDriftMatrix instance is used as input,
Note: If a thermox.ProcessedDriftMatrix instance is used as input,
must be the transformed drift matrix, A_y, given by thermox.preprocess,
not thermox.utils.preprocess_drift_matrix.
b: Drift displacement vector.
Expand All @@ -122,3 +55,48 @@ def log_prob(

D_sqrt_inv_log_det = jnp.log(jnp.linalg.det(D.sqrt_inv))
return log_prob_ys + D_sqrt_inv_log_det * (len(ts) - 1)


def transition_cov_sqrt_inv_vp(A, v, dt):
    """Rotates v by A.sym_eigvecs.T and divides by the transition std deviations.

    The per-direction variance is (1 - exp(-2 * lam * dt)) / (2 * lam), with
    lam taken from A.sym_eigvals.

    Args:
        A: Object exposing sym_eigvals and sym_eigvecs
            (e.g. thermox.ProcessedDriftMatrix).
        v: Vector to transform.
        dt: Length of the time step.

    Returns:
        Real part of (A.sym_eigvecs.T @ v) / std.
    """
    # -expm1(-x) == 1 - exp(-x), but stays accurate when x is tiny; the naive
    # form rounds to exactly 0 in float32 once x < ~6e-8.
    var = -jnp.expm1(-2 * A.sym_eigvals * dt) / (2 * A.sym_eigvals)
    diag = var**0.5
    # Floor the scale so a zero timestep does not divide by zero.
    diag = jnp.where(diag < 1e-20, 1e-20, diag)
    out = A.sym_eigvecs.T @ v
    out = out / diag
    return out.real


def transition_cov_log_det(A, dt):
    """Returns the sum of the logs of the per-direction transition variances.

    The per-direction variance is (1 - exp(-2 * lam * dt)) / (2 * lam), with
    lam taken from A.sym_eigvals.

    Args:
        A: Object exposing sym_eigvals (e.g. thermox.ProcessedDriftMatrix).
        dt: Length of the time step.

    Returns:
        Scalar sum of log variances.
    """
    # -expm1(-x) == 1 - exp(-x), but stays accurate when x is tiny; the naive
    # form rounds to exactly 0 in float32 once x < ~6e-8.
    diag = -jnp.expm1(-2 * A.sym_eigvals * dt) / (2 * A.sym_eigvals)
    # Floor the variance so a zero timestep gives a finite log.
    diag = jnp.where(diag < 1e-20, 1e-20, diag)
    return jnp.sum(jnp.log(diag))


def log_prob_identity_diffusion(
    ts: Array,
    xs: Array,
    A: Array | ProcessedDriftMatrix,
    b: Array,
) -> float:
    """Log probability of a trajectory of the process dx = - A * (x - b) dt + dW.

    The first state xs[0] is treated as given; the result is the sum of the
    Gaussian transition log densities between consecutive samples.

    Args:
        ts: Times at which samples are collected. Includes time for x0.
        xs: States of the process at the times in ts.
        A: Drift matrix (Array or thermox.ProcessedDriftMatrix). A raw Array is
            preprocessed here.
        b: Drift displacement vector.

    Returns:
        Scalar log probability of given xs.
    """
    if isinstance(A, Array):
        A = preprocess_drift_matrix(A)

    def step_log_density(x_next, x_prev, delta_t):
        # Conditional mean of x_next given x_prev after delta_t.
        cond_mean = b + expm_vp(A, x_prev - b, delta_t)
        whitened = transition_cov_sqrt_inv_vp(A, x_next - cond_mean, delta_t)
        return (
            -jnp.dot(whitened, whitened) / 2
            - transition_cov_log_det(A, delta_t) / 2
            - jnp.log(2 * jnp.pi) * (x_next.shape[0] / 2)
        )

    def accumulate(i, running_total):
        return running_total + step_log_density(xs[i], xs[i - 1], ts[i] - ts[i - 1])

    total = fori_loop(1, len(ts), accumulate, 0.0)
    return total.real
151 changes: 95 additions & 56 deletions thermox/sampler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from functools import partial
import jax
import jax.numpy as jnp
from jax.lax import scan
from jax import Array

from thermox.utils import (
Expand All @@ -11,108 +11,147 @@
)


def sample(
    key: Array,
    ts: Array,
    x0: Array,
    A: Array | ProcessedDriftMatrix,
    b: Array,
    D: Array | ProcessedDiffusionMatrix,
    associative_scan: bool = True,
) -> Array:
    """Collects samples from the Ornstein-Uhlenbeck process, defined as:

        dx = - A * (x - b) dt + sqrt(D) dW

    by using exact diagonalization.

    Preprocessing (diagonalisation) costs O(d^3) and sampling costs O(T * d^2),
    where T=len(ts).

    If associative_scan=True then jax.lax.associative_scan is used which will run in
    time O((T/p + log(T)) * d^2) on a GPU/TPU with p cores, still with
    O(d^3) preprocessing.

    By default, this function does the preprocessing on A and D before the evaluation.
    However, the preprocessing can be done externally using thermox.preprocess,
    the output of which can be used as A and D here; this will skip the preprocessing.

    Args:
        key: Jax PRNGKey.
        ts: Times at which samples are collected. Includes time for x0.
        x0: Initial state of the process.
        A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
            Note: If a thermox.ProcessedDriftMatrix instance is used as input,
            must be the transformed drift matrix, A_y, given by thermox.preprocess,
            not thermox.utils.preprocess_drift_matrix.
        b: Drift displacement vector.
        D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).
        associative_scan: If True, uses jax.lax.associative_scan.

    Returns:
        Array-like, desired samples.
            shape: (len(ts), ) + x0.shape
    """
    A_y, D = handle_matrix_inputs(A, D)

    # Change variables with D.sqrt_inv, sample the identity-diffusion process,
    # then map every sample back with D.sqrt.
    y0 = D.sqrt_inv @ x0
    b_y = D.sqrt_inv @ b
    ys = sample_identity_diffusion(key, ts, y0, A_y, b_y, associative_scan)
    return jax.vmap(jnp.matmul, in_axes=(None, 0))(D.sqrt, ys)


def sample_identity_diffusion(
    key: Array,
    ts: Array,
    x0: Array,
    A: Array | ProcessedDriftMatrix,
    b: Array,
    associative_scan: bool = True,
) -> Array:
    """Samples the process dx = - A * (x - b) dt + dW at the times ts.

    Dispatches to the associative-scan implementation when associative_scan is
    truthy and to the sequential scan implementation otherwise.

    Args:
        key: Jax PRNGKey.
        ts: Times at which samples are collected. Includes time for x0.
        x0: Initial state of the process.
        A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
        b: Drift displacement vector.
        associative_scan: If True, uses jax.lax.associative_scan.

    Returns:
        Array of samples, shape (len(ts), ) + x0.shape.
    """
    sampler = (
        _sample_identity_diffusion_associative_scan
        if associative_scan
        else _sample_identity_diffusion_scan
    )
    return sampler(key, ts, x0, A, b)


def expm_vp(A, v, dt):
    """Applies eigvecs @ diag(exp(-eigvals * dt)) @ eigvecs_inv to the vector v.

    When A.eigvecs, A.eigvals and A.eigvecs_inv hold an eigendecomposition of a
    drift matrix, this is the action of exp(-drift * dt) on v.

    Args:
        A: Object exposing eigvals, eigvecs and eigvecs_inv
            (e.g. thermox.ProcessedDriftMatrix).
        v: Vector to propagate.
        dt: Length of the time step.

    Returns:
        Real part of the propagated vector.
    """
    coeffs = A.eigvecs_inv @ v
    decayed = jnp.exp(-A.eigvals * dt) * coeffs
    return (A.eigvecs @ decayed).real


def transition_cov_sqrt_vp(A, v, dt):
    """Scales v by the transition std deviations and rotates by A.sym_eigvecs.

    The per-direction variance is (1 - exp(-2 * lam * dt)) / (2 * lam), with
    lam taken from A.sym_eigvals. Applied to a standard normal draw, this gives
    the noise term of one transition.

    Args:
        A: Object exposing sym_eigvals and sym_eigvecs
            (e.g. thermox.ProcessedDriftMatrix).
        v: Vector to transform.
        dt: Length of the time step.

    Returns:
        Real part of A.sym_eigvecs @ (std * v).
    """
    # -expm1(-x) == 1 - exp(-x), but stays accurate when x is tiny; the naive
    # form rounds to exactly 0 in float32 once x < ~6e-8, which would silently
    # remove the noise for very small timesteps.
    var = -jnp.expm1(-2 * A.sym_eigvals * dt) / (2 * A.sym_eigvals)
    out = var**0.5 * v
    out = A.sym_eigvecs @ out
    return out.real


def _sample_identity_diffusion_scan(
    key: Array,
    ts: Array,
    x0: Array,
    A: Array | ProcessedDriftMatrix,
    b: Array,
) -> Array:
    """Samples dx = - A * (x - b) dt + dW sequentially with jax.lax.scan.

    All Gaussian draws are generated up front from key, then each state is
    computed from the previous one as transition mean plus transition noise.

    Args:
        key: Jax PRNGKey.
        ts: Times at which samples are collected. Includes time for x0.
        x0: Initial state of the process.
        A: Drift matrix (Array or thermox.ProcessedDriftMatrix). A raw Array is
            preprocessed here.
        b: Drift displacement vector.

    Returns:
        Array of samples with x0 first, shape (len(ts), ) + x0.shape.
    """
    if isinstance(A, Array):
        A = preprocess_drift_matrix(A)

    def transition_mean(x, dt):
        return b + expm_vp(A, x - b, dt)

    def next_x(x, dt, rv):
        return transition_mean(x, dt) + transition_cov_sqrt_vp(A, rv, dt)

    def scan_body(carry, dt_and_rv):
        x = carry
        dt, rv = dt_and_rv
        new_x = next_x(x, dt, rv)
        # The new state is both the next carry and the emitted sample.
        return new_x, new_x

    dts = jnp.diff(ts)
    gauss_samps = jax.random.normal(key, (len(dts),) + x0.shape)

    # Pair each timestep with its noise draw; scan walks both along axis 0.
    dt_and_rv = (dts, gauss_samps)

    _, xs = jax.lax.scan(scan_body, x0, dt_and_rv)
    xs = jnp.concatenate([jnp.expand_dims(x0, axis=0), xs], axis=0)
    return xs


def _sample_identity_diffusion_associative_scan(
    key: Array,
    ts: Array,
    x0: Array,
    A: Array | ProcessedDriftMatrix,
    b: Array,
) -> Array:
    """Samples dx = - A * (x - b) dt + dW with jax.lax.associative_scan.

    Works in coordinates shifted by b. Each scan element is a pair
    (elapsed time, state contribution), and combining two elements propagates
    the left contribution through the right element's elapsed time before
    adding the right contribution.

    Args:
        key: Jax PRNGKey.
        ts: Times at which samples are collected. Includes time for x0.
        x0: Initial state of the process.
        A: Drift matrix (Array or thermox.ProcessedDriftMatrix). A raw Array is
            preprocessed here.
        b: Drift displacement vector.

    Returns:
        Array of samples with x0 first, shape (len(ts), ) + x0.shape.
    """
    if isinstance(A, Array):
        A = preprocess_drift_matrix(A)

    dts = jnp.diff(ts)

    # transition_mean(x, dt) = b + expm_vp(A, x - b, dt)

    # One noise vector per transition, each scaled for its own timestep.
    gauss_samps = jax.random.normal(key, (len(dts),) + x0.shape)
    noise_terms = jax.vmap(lambda v, dt: transition_cov_sqrt_vp(A, v, dt))(
        gauss_samps, dts
    )

    # associative_scan passes batches of elements, so vmap over the leading axis.
    @partial(jax.vmap, in_axes=(0, 0))
    def binary_associative_operator(elem_a, elem_b):
        t_a, x_a = elem_a
        t_b, x_b = elem_b
        return t_a + t_b, expm_vp(A, x_a, t_b) + x_b

    scan_times = jnp.concatenate([ts[:1], dts], dtype=float)  # [t0, dt1, dt2, ...]
    scan_input_values = jnp.concatenate(
        [x0[None] - b, noise_terms], axis=0
    )  # Shift input by b
    scan_elems = (scan_times, scan_input_values)

    scan_output = jax.lax.associative_scan(binary_associative_operator, scan_elems)
    return scan_output[1] + b  # Shift back by b