diff --git a/thermox/linalg.py b/thermox/linalg.py index dba4a72..0a25117 100644 --- a/thermox/linalg.py +++ b/thermox/linalg.py @@ -11,7 +11,7 @@ def solve( num_samples: int = 10000, dt: float = 1.0, burnin: int = 0, - key: Array = None, + key: Array | None = None, associative_scan: bool = True, ) -> Array: """ @@ -37,11 +37,12 @@ def solve( """ if key is None: key = random.PRNGKey(0) - ts = jnp.arange(burnin, burnin + num_samples) * dt + ts = jnp.arange(burnin, burnin + num_samples + 1) * dt + ts = jnp.concatenate([jnp.array([0]), ts]) x0 = jnp.zeros_like(b) samples = sample_identity_diffusion( key, ts, x0, A, jnp.linalg.solve(A, b), associative_scan - ) + )[1:] return jnp.mean(samples, axis=0) @@ -50,7 +51,7 @@ def inv( num_samples: int = 10000, dt: float = 1.0, burnin: int = 0, - key: Array = None, + key: Array | None = None, associative_scan: bool = True, ) -> Array: """ @@ -72,10 +73,11 @@ def inv( """ if key is None: key = random.PRNGKey(0) - ts = jnp.arange(burnin, burnin + num_samples) * dt + ts = jnp.arange(burnin, burnin + num_samples + 1) * dt + ts = jnp.concatenate([jnp.array([0]), ts]) b = jnp.zeros(A.shape[0]) x0 = jnp.zeros_like(b) - samples = sample(key, ts, x0, A, b, 2 * jnp.eye(A.shape[0]), associative_scan) + samples = sample(key, ts, x0, A, b, 2 * jnp.eye(A.shape[0]), associative_scan)[1:] return jnp.cov(samples.T) @@ -84,7 +86,7 @@ def expnegm( num_samples: int = 10000, dt: float = 1.0, burnin: int = 0, - key: Array = None, + key: Array | None = None, alpha: float = 0.0, associative_scan: bool = True, ) -> Array: @@ -113,10 +115,11 @@ def expnegm( A_shifted = (A + alpha * jnp.eye(A.shape[0])) / dt B = A_shifted + A_shifted.T - ts = jnp.arange(burnin, burnin + num_samples) * dt + ts = jnp.arange(burnin, burnin + num_samples + 1) * dt + ts = jnp.concatenate([jnp.array([0]), ts]) b = jnp.zeros(A.shape[0]) x0 = jnp.zeros_like(b) - samples = sample(key, ts, x0, A_shifted, b, B, associative_scan) + samples = sample(key, ts, x0, A_shifted, b, B, associative_scan)[1:] return autocovariance(samples) * jnp.exp(alpha) @@ -125,7 +128,7 @@ def expm( num_samples: int = 10000, dt: float = 1.0, burnin: int = 0, - key: Array = None, + key: Array | None = None, alpha: float = 1.0, associative_scan: bool = True, ) -> Array: diff --git a/thermox/sampler.py b/thermox/sampler.py index 53d0018..7b83610 100644 --- a/thermox/sampler.py +++ b/thermox/sampler.py @@ -26,7 +26,7 @@ def sample( by using exact diagonalization. - Preprocessing (diagonalisation) costs O(d^3) and sampling costs O(T * d^2), + Preprocessing (diagonalization) costs O(d^3) and sampling costs O(T * d^2), where T=len(ts). If associative_scan=True then jax.lax.associative_scan is used which will run in