-
Notifications
You must be signed in to change notification settings - Fork 661
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft: DroQ and TD3+TQC jax implementation #272
base: master
Are you sure you want to change the base?
Conversation
The latest updates on your projects. Learn more about Vercel for Git ↗︎
|
This reverts commit d5704b3.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👀 how does Adan perform?
Results are very preliminary, ADAN performs on par or slightly better than ADAM, but nothing significant yet. |
FYI https://github.com/deepmind/distrax might be a better replacement for tensorflow probability |
fwiw you can also use tensorflow_probability with a jax backend and then you don't need to use tensorflow at all (in one of their tutorials they even explicitly unninstall tf) |
Fyi, I converted that single file to a proof of concept of SB3 + Jax (SBX): https://github.com/araffin/sbx |
Description
FYI: unpolished jax implementation of TD3+DroQ and TD3+TQC implementations.
Related to #262 #258
My plan is to try to have sac in jax, but currently jax rely on tensorflow for probability distributions :/
So I adapted TD3 instead.
I also want to make it even faster but would need to tweak a bit the way the replay buffer is used.
EDIT: apparently tfd doesn't depends on tf anymore for latest version: https://www.tensorflow.org/probability/examples/TensorFlow_Probability_on_JAX
Reference:
EDIT: SBX = SB3 + JAX: https://github.com/araffin/sbx
Known difference with original implementation: qf are updated at the same time of the actor instead of after each gradient step.Types of changes
Checklist:
pre-commit run --all-files
passes (required).mkdocs serve
.If you are adding new algorithms or your change could result in performance difference, you may need to (re-)run tracked experiments. See #137 as an example PR.
--capture-video
flag toggled on (required).mkdocs serve
.width=500
andheight=300
).