-
Notifications
You must be signed in to change notification settings - Fork 502
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Bures-Wasserstein Gradient Descent for Bures-Wasserstein Barycenters #680
base: master
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Small comments. I will let @antoinecollas do a proper review he is the expert in Riemannian optimization
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #680 +/- ##
==========================================
+ Coverage 97.05% 97.09% +0.04%
==========================================
Files 98 98
Lines 19955 20167 +212
==========================================
+ Hits 19367 19582 +215
+ Misses 588 585 -3 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is great. A few tests especialy about errors are missing
ot/gaussian.py
Outdated
# check convergence | ||
if batch_size is not None and batch_size < n: | ||
# TODO: criteria for SGD: on gradients? + test SGD | ||
diff = nx.norm(Cb - Cnew) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not tested
Types of changes
This PR aims to add the Bures-Wasserstein gradient descent solver to compute Bures-Wasserstein barycenters (see e.g. Gradient descent algorithms for Bures-Wasserstein barycenters or Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent).
ot.gaussian.bures_wasserstein_barycenter
to allow to use different methodsot.gaussian.bures_barycenter_fixpoint
ot.gaussian.bures_barycenter_gradient_descent
test_bures_wasserstein_barycenter
test_fixedpoint_vs_gradientdescent_bures_wasserstein_barycenter
ot.gaussian.bures_wasserstein_distance
Motivation and context / Related issue
The Bures-Wasserstein gradient descent comes with convergence guarantees to solve Bures-Wasserstein barycenters. Moreover, it can also be used in a stochastic way when there are too much Gaussian. Thus, it is a good alternative to the fixed-point algorithm currently implemented.
How has this been tested (if it applies)
I added a test
test_fixedpoint_vs_gradientdescent_bures_wasserstein_barycenter
to assess both methods returns the same barycenter. I also added the itertools totest_bures_wasserstein_barycenter
.PR checklist