Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reproducible code for matrix exponential simulations #37

Merged
merged 8 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/matrix_exponentials/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.pkl
13 changes: 13 additions & 0 deletions examples/matrix_exponentials/matrix_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from jax import random, numpy as jnp
from scipy.stats import ortho_group


def wishart(d: int, key: random.PRNGKey) -> jnp.ndarray:
n = 2 * d # degrees of freedom
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comments seem a bit random in the scripts

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this one is ok haha have cleaned up run.py

G = random.normal(key, shape=(d, n))
A_wishart = (G @ G.T) / n
return A_wishart


def orthogonal(d: int, _) -> jnp.ndarray:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why the second _ argument?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unified signature with wishart but ortho_group is a scipy function that doesn't use jax random keys

return ortho_group.rvs(d)
Binary file added examples/matrix_exponentials/orthogonal_abs.pdf
Binary file not shown.
Binary file added examples/matrix_exponentials/orthogonal_rel.pdf
Binary file not shown.
138 changes: 138 additions & 0 deletions examples/matrix_exponentials/plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import pickle
import numpy as np
import matplotlib.pyplot as plt
import argparse


parser = argparse.ArgumentParser()
parser.add_argument("--save_dir", type=str)
args = parser.parse_args()


matrix_type = args.save_dir.split("/")[-1].split("_")[1].split(".")[0]


# use latex for plots
plt.rc("text", usetex=True)
# set font
plt.rc("font", family="serif")
# set font size
plt.rcParams.update({"font.size": 10})

colors = [
plt.cm.viridis(0.2),
plt.cm.viridis(0.4),
plt.cm.viridis(0.6),
plt.cm.viridis(0.8),
]


results = pickle.load(open(args.save_dir, "rb"))
dt = results["dt"]
NT = results["ERR_abs"].shape[-1]
D = results["D"]
e0_abs = 8.0 if matrix_type == "wishart" else 19.0
ylabel_abs = (
r"$|| \bar{C} - \exp(-A)||_F$"
if matrix_type == "wishart"
else r"$|| \bar{C} - \exp(-M)||_F$"
)
e0_rel = 0.9
ylabel_rel = (
r"$\frac{|| \bar{C} - \exp(-A)||_F}{||\exp(-A)||_F}$"
if matrix_type == "wishart"
else r"$\frac{|| \bar{C} - \exp(-M)||_F}{||\exp(-M)||_F}$"
)
fig_label = "(A)" if matrix_type == "wishart" else "(B)"


def plot(ERR, ylabel, e0, save_path, d=False, d_squared=False, fig_label=None):
T = np.arange(NT) * dt
ERR_mean = ERR.mean(axis=0)

# find time where error crosses threshold
TC = np.zeros(len(D))
for i in range(len(D)):
TC[i] = np.min(T[10:][ERR_mean[i, 10:] < e0])

plt.figure(figsize=(7, 4.5))

if fig_label is not None:
plt.gcf().text(0.02, 0.93, fig_label, fontsize=22)

for i in range(len(D)):
plt.plot(T, ERR_mean[i], color=colors[i])

# Add error bars
for i in range(len(D)):
plt.fill_between(
T,
ERR_mean[i] - ERR[:, i].std(axis=0),
ERR_mean[i] + ERR[:, i].std(axis=0),
color=colors[i],
alpha=0.3,
zorder=0,
)

plt.loglog()
plt.legend(["d = " + str(D[i]) for i in range(len(D))], loc="upper right")
plt.xlabel(r"Time ($\mu$s)", fontsize=18)
plt.ylabel(ylabel, fontsize=18)

# show threshold as horizontal line
plt.axhline(e0, color="k", linestyle="--")
# show crossing times as vertical lines
for i in range(len(D)):
plt.axvline(TC[i], color=colors[i], linestyle="--")

plt.xlim(30, T[-1])

# inset plot showing crossing time as a function of dimension
ax = plt.axes([0.17, 0.22, 0.3, 0.35])
ax.tick_params(axis="y", direction="in", pad=-22)
ax.tick_params(axis="x", direction="in", pad=-15)

for i in range(len(D)):
ax.scatter(D[i], TC[i], color=colors[i], zorder=10)

ts = np.array([10, 2000])

if d:
plt.plot(ts, 100 * ts, color="black", linestyle="--")
plt.text(600, 8e4, s=r"$t_C = d$", rotation=25)

if d_squared:
plt.plot(ts, 0.3 * ts**2, color="black", linestyle="--")
plt.text(550, 1.7e5, s=r"$t_C = d^2$", rotation=25)

plt.plot(D, TC, color="black", zorder=0)
plt.xlim(20, 1500)

plt.loglog()
plt.xlabel(r"$d$", fontsize=15)
plt.ylabel(r"$t_C$", fontsize=15)
plt.minorticks_off()

plt.tight_layout()
plt.savefig(save_path, dpi=300)
plt.show()


plot(
results["ERR_abs"],
ylabel_abs,
e0_abs,
f"examples/matrix_exponentials/{matrix_type}_abs.pdf",
d_squared=True,
fig_label=fig_label,
)


plot(
results["ERR_rel"],
ylabel_rel,
e0_rel,
f"examples/matrix_exponentials/{matrix_type}_rel.pdf",
d=True,
fig_label=fig_label,
)
106 changes: 106 additions & 0 deletions examples/matrix_exponentials/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from jax import random, jit, config, numpy as jnp
from jax.scipy.linalg import expm
from jax.lax import scan
import numpy as np
import argparse
from tqdm import tqdm
import thermox
import pickle

from examples.matrix_exponentials import matrix_generation

# Set the precision of the computation
config.update("jax_enable_x64", True)

np.random.seed(42)


# Load n_repeats, matrix_type and alpha from the command line
parser = argparse.ArgumentParser()
parser.add_argument("--n_repeats", type=int, default=1)
parser.add_argument("--matrix_type", type=str, default="wishart")
parser.add_argument("--alpha", type=float, default=0.0)
args = parser.parse_args()


sample = jit(thermox.sample)
get_matrix = getattr(matrix_generation, args.matrix_type)


NT = 10000
dt = 12
ts = jnp.arange(NT) * dt
N_burn = 0
keys = random.split(random.PRNGKey(42), args.n_repeats)

SamDuffield marked this conversation as resolved.
Show resolved Hide resolved

gamma = 1
beta = 1
alpha = args.alpha


@jit
def samps_to_autocovs_errs(samps, true_exp):
def body_func(prev_mat, n):
new_mat = prev_mat * n / (n + 1) + jnp.outer(samps[n], samps[n - 1]) / (n + 1)
err = jnp.linalg.norm(new_mat * jnp.exp(alpha) - true_exp)
return new_mat, err

return scan(
body_func,
jnp.zeros((samps.shape[1], samps.shape[1])),
jnp.arange(1, samps.shape[0]),
)[1]


D = [64, 128, 256, 512]
ERR_abs = np.zeros((args.n_repeats, len(D), NT))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mixing uppercase and lowercase not very clean

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed ERR to err, I think the rest is ok

ERR_rel = np.zeros_like(ERR_abs)

for repeat in tqdm(range(args.n_repeats)):
key = keys[repeat]
for i in range(len(D)):
d = D[i]
print(f"Repeat {repeat}/{args.n_repeats}, \t D = {d}")

A = get_matrix(d, key)

exact_exp_min_A = expm(-A)

A_shifted = (A + alpha * jnp.eye(A.shape[0])) / dt
B = A_shifted + A_shifted.T

A_shifted_lambda_min = jnp.min(jnp.linalg.eig(A_shifted / gamma)[0].real)
print("A Eig min: ", A_shifted_lambda_min)

D_lambda_min = jnp.min(jnp.linalg.eig(B / (gamma * beta))[0].real)
print("D Eig min: ", D_lambda_min)

x0 = np.zeros(d)
X = sample(
key,
ts,
x0,
A_shifted / gamma,
np.zeros(d),
B / (gamma * beta),
)

err_abs = samps_to_autocovs_errs(X, exact_exp_min_A)

ERR_abs[repeat, i, 1:] = err_abs
ERR_rel[repeat, i, 1:] = err_abs / jnp.linalg.norm(exact_exp_min_A)

with open(
f"examples/matrix_exponentials/results_{args.matrix_type}.pkl", "wb"
) as f:
pickle.dump(
{
"D": D,
"dt": dt,
"alpha": alpha,
"ERR_abs": ERR_abs,
"ERR_rel": ERR_rel,
},
f,
)
Binary file added examples/matrix_exponentials/wishart_abs.pdf
Binary file not shown.
Binary file added examples/matrix_exponentials/wishart_rel.pdf
Binary file not shown.
6 changes: 1 addition & 5 deletions thermox/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,14 @@ def preprocess_drift_matrix(A: Array) -> ProcessedDriftMatrix:
"""

A_eigvals, A_eigvecs = eig(A + 0.0j)

A_eigvals = A_eigvals.real
A_eigvecs = A_eigvecs.real

A_eigvecs_inv = jnp.linalg.inv(A_eigvecs)

symA = 0.5 * (A + A.T)
symA_eigvals, symA_eigvecs = jnp.linalg.eigh(symA)

return ProcessedDriftMatrix(
A,
A_eigvals.real,
A_eigvals,
A_eigvecs,
A_eigvecs_inv,
symA_eigvals,
Expand Down