Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reproducible code for matrix exponential simulations #37

Merged
merged 8 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/matrix_exponentials/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.pkl
13 changes: 13 additions & 0 deletions examples/matrix_exponentials/matrix_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from jax import random, numpy as jnp
from scipy.stats import ortho_group


def wishart(d: int, key: random.PRNGKey) -> jnp.ndarray:
n = 2 * d # degrees of freedom
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comments seem a bit random in the scripts

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this one is ok haha have cleaned up run.py

G = random.normal(key, shape=(d, n))
A_wishart = (G @ G.T) / n
return A_wishart


def orthogonal(d: int, _) -> jnp.ndarray:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why the second _ argument?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unified signature with wishart but ortho_group is a scipy function that doesn't use jax random keys

return ortho_group.rvs(d)
Binary file added examples/matrix_exponentials/orthogonal_abs.pdf
Binary file not shown.
Binary file added examples/matrix_exponentials/orthogonal_rel.pdf
Binary file not shown.
138 changes: 138 additions & 0 deletions examples/matrix_exponentials/plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import pickle
import numpy as np
import matplotlib.pyplot as plt
import argparse


parser = argparse.ArgumentParser()
parser.add_argument("--save_dir", type=str)
args = parser.parse_args()


matrix_type = args.save_dir.split("/")[-1].split("_")[1].split(".")[0]


# use latex for plots
plt.rc("text", usetex=True)
# set font
plt.rc("font", family="serif")
# set font size
plt.rcParams.update({"font.size": 10})

colors = [
plt.cm.viridis(0.2),
plt.cm.viridis(0.4),
plt.cm.viridis(0.6),
plt.cm.viridis(0.8),
]


results = pickle.load(open(args.save_dir, "rb"))
dt = results["dt"]
NT = results["err_abs"].shape[-1]
D = results["D"]
e0_abs = 8.0 if matrix_type == "wishart" else 19.0
ylabel_abs = (
r"$|| \bar{C} - \exp(-A)||_F$"
if matrix_type == "wishart"
else r"$|| \bar{C} - \exp(-M)||_F$"
)
e0_rel = 0.9
ylabel_rel = (
r"$\frac{|| \bar{C} - \exp(-A)||_F}{||\exp(-A)||_F}$"
if matrix_type == "wishart"
else r"$\frac{|| \bar{C} - \exp(-M)||_F}{||\exp(-M)||_F}$"
)
fig_label = "(A)" if matrix_type == "wishart" else "(B)"


def plot(err, ylabel, e0, save_path, d=False, d_squared=False, fig_label=None):
T = np.arange(NT) * dt
err_mean = err.mean(axis=0)

# find time where error crosses threshold
TC = np.zeros(len(D))
for i in range(len(D)):
TC[i] = np.min(T[10:][err_mean[i, 10:] < e0])

plt.figure(figsize=(7, 4.5))

if fig_label is not None:
plt.gcf().text(0.02, 0.93, fig_label, fontsize=22)

for i in range(len(D)):
plt.plot(T, err_mean[i], color=colors[i])

# Add error bars
for i in range(len(D)):
plt.fill_between(
T,
err_mean[i] - err[:, i].std(axis=0),
err_mean[i] + err[:, i].std(axis=0),
color=colors[i],
alpha=0.3,
zorder=0,
)

plt.loglog()
plt.legend(["d = " + str(D[i]) for i in range(len(D))], loc="upper right")
plt.xlabel(r"Time ($\mu$s)", fontsize=18)
plt.ylabel(ylabel, fontsize=18)

# show threshold as horizontal line
plt.axhline(e0, color="k", linestyle="--")
# show crossing times as vertical lines
for i in range(len(D)):
plt.axvline(TC[i], color=colors[i], linestyle="--")

plt.xlim(30, T[-1])

# inset plot showing crossing time as a function of dimension
ax = plt.axes([0.17, 0.22, 0.3, 0.35])
ax.tick_params(axis="y", direction="in", pad=-22)
ax.tick_params(axis="x", direction="in", pad=-15)

for i in range(len(D)):
ax.scatter(D[i], TC[i], color=colors[i], zorder=10)

ts = np.array([10, 2000])

if d:
plt.plot(ts, 100 * ts, color="black", linestyle="--")
plt.text(600, 8e4, s=r"$t_C = d$", rotation=25)

if d_squared:
plt.plot(ts, 0.3 * ts**2, color="black", linestyle="--")
plt.text(550, 1.7e5, s=r"$t_C = d^2$", rotation=25)

plt.plot(D, TC, color="black", zorder=0)
plt.xlim(20, 1500)

plt.loglog()
plt.xlabel(r"$d$", fontsize=15)
plt.ylabel(r"$t_C$", fontsize=15)
plt.minorticks_off()

plt.tight_layout()
plt.savefig(save_path, dpi=300)
plt.show()


plot(
results["err_abs"],
ylabel_abs,
e0_abs,
f"examples/matrix_exponentials/{matrix_type}_abs.pdf",
d_squared=True,
fig_label=fig_label,
)


plot(
results["err_rel"],
ylabel_rel,
e0_rel,
f"examples/matrix_exponentials/{matrix_type}_rel.pdf",
d=True,
fig_label=fig_label,
)
114 changes: 114 additions & 0 deletions examples/matrix_exponentials/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from jax import random, jit, config, numpy as jnp
from jax.scipy.linalg import expm
from jax.lax import scan
import numpy as np
import argparse
from tqdm import tqdm
import thermox
import pickle

from examples.matrix_exponentials import matrix_generation

# Set the precision of the computation
config.update("jax_enable_x64", True)

# Set seed for orthogonal matrix generation
np.random.seed(42)

# Load n_repeats, matrix_type and alpha from the command line
parser = argparse.ArgumentParser()
parser.add_argument("--n_repeats", type=int, default=1)
parser.add_argument("--matrix_type", type=str, default="wishart")
parser.add_argument("--alpha", type=float, default=0.0)
args = parser.parse_args()
get_matrix = getattr(matrix_generation, args.matrix_type)
alpha = args.alpha

# Jit for speed (avoid recompilation)
sample = jit(thermox.sample)

# Hyperparameters shared across all experiments
NT = 10000
dt = 12
ts = jnp.arange(NT) * dt
N_burn = 0
keys = random.split(random.PRNGKey(42), args.n_repeats)
gamma = 1
beta = 1
D = [64, 128, 256, 512]


# Function to compute array of autocovariance errors from samples
@jit
def samps_to_autocovs_errs(samps, true_exp):
def body_func(prev_mat, n):
new_mat = prev_mat * n / (n + 1) + jnp.outer(samps[n], samps[n - 1]) / (n + 1)
err = jnp.linalg.norm(new_mat * jnp.exp(alpha) - true_exp)
return new_mat, err

return scan(
body_func,
jnp.zeros((samps.shape[1], samps.shape[1])),
jnp.arange(1, samps.shape[0]),
)[1]


# Initialize arrays to store errors
err_abs = np.zeros((args.n_repeats, len(D), NT))
err_rel = np.zeros_like(err_abs)

# Loop over repeats and dimensions
for repeat in tqdm(range(args.n_repeats)):
key = keys[repeat]
for i in range(len(D)):
d = D[i]
print(f"Repeat {repeat}/{args.n_repeats}, \t D = {d}")

A = get_matrix(d, key)
exact_exp_min_A = expm(-A)

# Shift and scale A and compute symmetrized B
A_shifted = (A + alpha * jnp.eye(A.shape[0])) / dt
B = A_shifted + A_shifted.T

# Print eigenvalues
A_shifted_lambda_min = jnp.min(jnp.linalg.eig(A_shifted / gamma)[0].real)
print("A Eig min: ", A_shifted_lambda_min)

D_lambda_min = jnp.min(jnp.linalg.eig(B / (gamma * beta))[0].real)
print("D Eig min: ", D_lambda_min)

# Initialize at zeros
x0 = np.zeros(d)

# Run the sampler
X = sample(
key,
ts,
x0,
A_shifted / gamma,
np.zeros(d),
B / (gamma * beta),
)

# Compute absolute error
err_abs = samps_to_autocovs_errs(X, exact_exp_min_A)
err_abs[repeat, i, 1:] = err_abs

# Compute relative error
err_rel[repeat, i, 1:] = err_abs / jnp.linalg.norm(exact_exp_min_A)

# Save results (overwrites after each repeat)
with open(
f"examples/matrix_exponentials/results_{args.matrix_type}.pkl", "wb"
) as f:
pickle.dump(
{
"D": D,
"dt": dt,
"alpha": alpha,
"err_abs": err_abs,
"err_rel": err_rel,
},
f,
)
Binary file added examples/matrix_exponentials/wishart_abs.pdf
Binary file not shown.
Binary file added examples/matrix_exponentials/wishart_rel.pdf
Binary file not shown.
6 changes: 1 addition & 5 deletions thermox/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,14 @@ def preprocess_drift_matrix(A: Array) -> ProcessedDriftMatrix:
"""

A_eigvals, A_eigvecs = eig(A + 0.0j)

A_eigvals = A_eigvals.real
A_eigvecs = A_eigvecs.real

A_eigvecs_inv = jnp.linalg.inv(A_eigvecs)

symA = 0.5 * (A + A.T)
symA_eigvals, symA_eigvecs = jnp.linalg.eigh(symA)

return ProcessedDriftMatrix(
A,
A_eigvals.real,
A_eigvals,
A_eigvecs,
A_eigvecs_inv,
symA_eigvals,
Expand Down