Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add associative scan #30

Merged
merged 19 commits into from
Jun 5, 2024
Merged

Add associative scan #30

merged 19 commits into from
Jun 5, 2024

Conversation

SamDuffield
Copy link
Contributor

First attempt at using jax.lax.associative_scan #14 , but it's throwing a matmul contracting dimensions error and I'm not sure why.

Copy link

@AdrienCorenflos AdrienCorenflos left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Associative scan takes vectorised operators (given that the operations across the leaves of the computational tree are batched).

thermox/sampler.py Outdated Show resolved Hide resolved
@SamDuffield
Copy link
Contributor Author

Update: associative_scan now working but seems like something is wrong with the calculations so I need to check the maths again

@SamDuffield
Copy link
Contributor Author

Ok I fixed the maths! At the cost of doubling the number of expm_vp calls, we might be able to halve it again with further thought although I'm not sure.

Next step is to add associative_scan for log_prob

@SamDuffield SamDuffield marked this pull request as ready for review June 3, 2024 12:50
tests/test_sampler.py Outdated Show resolved Hide resolved
thermox/sampler.py Outdated Show resolved Hide resolved
@KaelanDt
Copy link
Contributor

KaelanDt commented Jun 4, 2024

Finished the speedup comparison with and without associative scan and adapted the handling of random keys in _sample_identity_diffusion, should be good to review

@SamDuffield
Copy link
Contributor Author

Be sure to add the underscore to sample_identity_diffusion to become _sample_identity_diffusion and maybe remove the docstring too

Copy link
Contributor

@KaelanDt KaelanDt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm!

@KaelanDt KaelanDt merged commit 3446bd4 into main Jun 5, 2024
2 checks passed
@KaelanDt KaelanDt deleted the associative_scan branch June 5, 2024 12:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants