-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add reproducible code for matrix exponential simulations #37
Changes from 6 commits
4c9947c
9e0690d
111036f
96e2c5d
a55ed08
a36cc1a
7c79170
3b33c7e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
*.pkl |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from jax import random, numpy as jnp | ||
from scipy.stats import ortho_group | ||
|
||
|
||
def wishart(d: int, key: random.PRNGKey) -> jnp.ndarray: | ||
n = 2 * d # degrees of freedom | ||
G = random.normal(key, shape=(d, n)) | ||
A_wishart = (G @ G.T) / n | ||
return A_wishart | ||
|
||
|
||
def orthogonal(d: int, _) -> jnp.ndarray: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why the second There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unified signature with |
||
return ortho_group.rvs(d) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
import pickle | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import argparse | ||
|
||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--save_dir", type=str) | ||
args = parser.parse_args() | ||
|
||
|
||
matrix_type = args.save_dir.split("/")[-1].split("_")[1].split(".")[0] | ||
|
||
|
||
# use latex for plots | ||
plt.rc("text", usetex=True) | ||
# set font | ||
plt.rc("font", family="serif") | ||
# set font size | ||
plt.rcParams.update({"font.size": 10}) | ||
|
||
colors = [ | ||
plt.cm.viridis(0.2), | ||
plt.cm.viridis(0.4), | ||
plt.cm.viridis(0.6), | ||
plt.cm.viridis(0.8), | ||
] | ||
|
||
|
||
results = pickle.load(open(args.save_dir, "rb")) | ||
dt = results["dt"] | ||
NT = results["ERR_abs"].shape[-1] | ||
D = results["D"] | ||
e0_abs = 8.0 if matrix_type == "wishart" else 19.0 | ||
ylabel_abs = ( | ||
r"$|| \bar{C} - \exp(-A)||_F$" | ||
if matrix_type == "wishart" | ||
else r"$|| \bar{C} - \exp(-M)||_F$" | ||
) | ||
e0_rel = 0.9 | ||
ylabel_rel = ( | ||
r"$\frac{|| \bar{C} - \exp(-A)||_F}{||\exp(-A)||_F}$" | ||
if matrix_type == "wishart" | ||
else r"$\frac{|| \bar{C} - \exp(-M)||_F}{||\exp(-M)||_F}$" | ||
) | ||
fig_label = "(A)" if matrix_type == "wishart" else "(B)" | ||
|
||
|
||
def plot(ERR, ylabel, e0, save_path, d=False, d_squared=False, fig_label=None): | ||
T = np.arange(NT) * dt | ||
ERR_mean = ERR.mean(axis=0) | ||
|
||
# find time where error crosses threshold | ||
TC = np.zeros(len(D)) | ||
for i in range(len(D)): | ||
TC[i] = np.min(T[10:][ERR_mean[i, 10:] < e0]) | ||
|
||
plt.figure(figsize=(7, 4.5)) | ||
|
||
if fig_label is not None: | ||
plt.gcf().text(0.02, 0.93, fig_label, fontsize=22) | ||
|
||
for i in range(len(D)): | ||
plt.plot(T, ERR_mean[i], color=colors[i]) | ||
|
||
# Add error bars | ||
for i in range(len(D)): | ||
plt.fill_between( | ||
T, | ||
ERR_mean[i] - ERR[:, i].std(axis=0), | ||
ERR_mean[i] + ERR[:, i].std(axis=0), | ||
color=colors[i], | ||
alpha=0.3, | ||
zorder=0, | ||
) | ||
|
||
plt.loglog() | ||
plt.legend(["d = " + str(D[i]) for i in range(len(D))], loc="upper right") | ||
plt.xlabel(r"Time ($\mu$s)", fontsize=18) | ||
plt.ylabel(ylabel, fontsize=18) | ||
|
||
# show threshold as horizontal line | ||
plt.axhline(e0, color="k", linestyle="--") | ||
# show crossing times as vertical lines | ||
for i in range(len(D)): | ||
plt.axvline(TC[i], color=colors[i], linestyle="--") | ||
|
||
plt.xlim(30, T[-1]) | ||
|
||
# inset plot showing crossing time as a function of dimension | ||
ax = plt.axes([0.17, 0.22, 0.3, 0.35]) | ||
ax.tick_params(axis="y", direction="in", pad=-22) | ||
ax.tick_params(axis="x", direction="in", pad=-15) | ||
|
||
for i in range(len(D)): | ||
ax.scatter(D[i], TC[i], color=colors[i], zorder=10) | ||
|
||
ts = np.array([10, 2000]) | ||
|
||
if d: | ||
plt.plot(ts, 100 * ts, color="black", linestyle="--") | ||
plt.text(600, 8e4, s=r"$t_C = d$", rotation=25) | ||
|
||
if d_squared: | ||
plt.plot(ts, 0.3 * ts**2, color="black", linestyle="--") | ||
plt.text(550, 1.7e5, s=r"$t_C = d^2$", rotation=25) | ||
|
||
plt.plot(D, TC, color="black", zorder=0) | ||
plt.xlim(20, 1500) | ||
|
||
plt.loglog() | ||
plt.xlabel(r"$d$", fontsize=15) | ||
plt.ylabel(r"$t_C$", fontsize=15) | ||
plt.minorticks_off() | ||
|
||
plt.tight_layout() | ||
plt.savefig(save_path, dpi=300) | ||
plt.show() | ||
|
||
|
||
plot( | ||
results["ERR_abs"], | ||
ylabel_abs, | ||
e0_abs, | ||
f"examples/matrix_exponentials/{matrix_type}_abs.pdf", | ||
d_squared=True, | ||
fig_label=fig_label, | ||
) | ||
|
||
|
||
plot( | ||
results["ERR_rel"], | ||
ylabel_rel, | ||
e0_rel, | ||
f"examples/matrix_exponentials/{matrix_type}_rel.pdf", | ||
d=True, | ||
fig_label=fig_label, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
from jax import random, jit, config, numpy as jnp | ||
from jax.scipy.linalg import expm | ||
from jax.lax import scan | ||
import numpy as np | ||
import argparse | ||
from tqdm import tqdm | ||
import thermox | ||
import pickle | ||
|
||
from examples.matrix_exponentials import matrix_generation | ||
|
||
# Set the precision of the computation | ||
config.update("jax_enable_x64", True) | ||
|
||
np.random.seed(42) | ||
|
||
|
||
# Load n_repeats, matrix_type and alpha from the command line | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--n_repeats", type=int, default=1) | ||
parser.add_argument("--matrix_type", type=str, default="wishart") | ||
parser.add_argument("--alpha", type=float, default=0.0) | ||
args = parser.parse_args() | ||
|
||
|
||
sample = jit(thermox.sample) | ||
get_matrix = getattr(matrix_generation, args.matrix_type) | ||
|
||
|
||
NT = 10000 | ||
dt = 12 | ||
ts = jnp.arange(NT) * dt | ||
N_burn = 0 | ||
keys = random.split(random.PRNGKey(42), args.n_repeats) | ||
|
||
SamDuffield marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
gamma = 1 | ||
beta = 1 | ||
alpha = args.alpha | ||
|
||
|
||
@jit | ||
def samps_to_autocovs_errs(samps, true_exp): | ||
def body_func(prev_mat, n): | ||
new_mat = prev_mat * n / (n + 1) + jnp.outer(samps[n], samps[n - 1]) / (n + 1) | ||
err = jnp.linalg.norm(new_mat * jnp.exp(alpha) - true_exp) | ||
return new_mat, err | ||
|
||
return scan( | ||
body_func, | ||
jnp.zeros((samps.shape[1], samps.shape[1])), | ||
jnp.arange(1, samps.shape[0]), | ||
)[1] | ||
|
||
|
||
D = [64, 128, 256, 512] | ||
ERR_abs = np.zeros((args.n_repeats, len(D), NT)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. mixing uppercase and lowercase not very clean There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Changed |
||
ERR_rel = np.zeros_like(ERR_abs) | ||
|
||
for repeat in tqdm(range(args.n_repeats)): | ||
key = keys[repeat] | ||
for i in range(len(D)): | ||
d = D[i] | ||
print(f"Repeat {repeat}/{args.n_repeats}, \t D = {d}") | ||
|
||
A = get_matrix(d, key) | ||
|
||
exact_exp_min_A = expm(-A) | ||
|
||
A_shifted = (A + alpha * jnp.eye(A.shape[0])) / dt | ||
B = A_shifted + A_shifted.T | ||
|
||
A_shifted_lambda_min = jnp.min(jnp.linalg.eig(A_shifted / gamma)[0].real) | ||
print("A Eig min: ", A_shifted_lambda_min) | ||
|
||
D_lambda_min = jnp.min(jnp.linalg.eig(B / (gamma * beta))[0].real) | ||
print("D Eig min: ", D_lambda_min) | ||
|
||
x0 = np.zeros(d) | ||
X = sample( | ||
key, | ||
ts, | ||
x0, | ||
A_shifted / gamma, | ||
np.zeros(d), | ||
B / (gamma * beta), | ||
) | ||
|
||
err_abs = samps_to_autocovs_errs(X, exact_exp_min_A) | ||
|
||
ERR_abs[repeat, i, 1:] = err_abs | ||
ERR_rel[repeat, i, 1:] = err_abs / jnp.linalg.norm(exact_exp_min_A) | ||
|
||
with open( | ||
f"examples/matrix_exponentials/results_{args.matrix_type}.pkl", "wb" | ||
) as f: | ||
pickle.dump( | ||
{ | ||
"D": D, | ||
"dt": dt, | ||
"alpha": alpha, | ||
"ERR_abs": ERR_abs, | ||
"ERR_rel": ERR_rel, | ||
}, | ||
f, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Comments seem a bit random in the scripts
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this one is ok haha have cleaned up run.py