Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints to linalg matrix inputs #27

Merged
merged 2 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "thermox"
version = "0.0.1"
description = "OU Processes and Linear Algebra with JAX"
description = "Exact OU processes with JAX"
readme = "README.md"
requires-python =">=3.9"
license = {text = "Apache-2.0"}
Expand Down
22 changes: 11 additions & 11 deletions thermox/linalg.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import jax
import jax.numpy as jnp
from thermox.sampler import sample, sample_identity_diffusion
from jax.lax import fori_loop
from jax import Array
from jax import Array, random
from thermox.sampler import sample, sample_identity_diffusion
from thermox.utils import ProcessedDriftMatrix


def solve(
A,
A: Array | ProcessedDriftMatrix,
b,
num_samples: int = 10000,
dt: float = 1.0,
Expand Down Expand Up @@ -34,15 +34,15 @@ def solve(
Approximate solution, x, of the linear system.
"""
if key is None:
key = jax.random.PRNGKey(0)
key = random.PRNGKey(0)
ts = jnp.arange(burnin, burnin + num_samples) * dt
x0 = jnp.zeros_like(b)
samples = sample_identity_diffusion(key, ts, x0, A, jnp.linalg.solve(A, b))
return jnp.mean(samples, axis=0)


def inv(
A,
A: Array,
num_samples: int = 10000,
dt: float = 1.0,
burnin: int = 0,
Expand All @@ -65,16 +65,16 @@ def inv(
Approximate inverse of A.
"""
if key is None:
key = jax.random.PRNGKey(0)
key = random.PRNGKey(0)
ts = jnp.arange(burnin, burnin + num_samples) * dt
b = jnp.zeros(A.shape[0])
x0 = jnp.zeros_like(b)
samples = sample(key, ts, x0, A, b, 2 * jnp.eye(A.shape[0]))
samples = sample_identity_diffusion(key, ts, x0, A, b, 2 * jnp.eye(A.shape[0]))
return jnp.cov(samples.T)
SamDuffield marked this conversation as resolved.
Show resolved Hide resolved


def expnegm(
A,
A: Array,
num_samples: int = 10000,
dt: float = 1.0,
burnin: int = 0,
Expand All @@ -100,7 +100,7 @@ def expnegm(
Approximate negative matrix exponential, exp(-A).
"""
if key is None:
key = jax.random.PRNGKey(0)
key = random.PRNGKey(0)

A_shifted = (A + alpha * jnp.eye(A.shape[0])) / dt
B = A_shifted + A_shifted.T
Expand All @@ -113,7 +113,7 @@ def expnegm(


def expm(
A,
A: Array,
num_samples: int = 10000,
dt: float = 1.0,
burnin: int = 0,
Expand Down
Loading