-
For debugging purposes, that's easily achieved (see jax-ml/jax#252).
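Here is a minimal sketch of the debugging route, assuming a reasonably recent JAX release. `jax.disable_jit()` is a context manager that makes `jax.jit`-decorated functions run eagerly, so `print` and Python control flow see concrete values. This is not necessarily the exact snippet from jax-ml/jax#252, and `log_amplitude` is a hypothetical stand-in for a model, not a NetKet function. For a global switch, JAX also reads the `JAX_DISABLE_JIT` environment variable (it has to be set before `jax` is imported).

```python
import jax
import jax.numpy as jnp


@jax.jit
def log_amplitude(x):
    # Hypothetical model function. Branching in Python on an array value is
    # not traceable: under jit this raises a ConcretizationTypeError
    # (or a subclass of it).
    if jnp.sum(x) > 0:
        return jnp.sum(x ** 2)
    return -jnp.sum(x ** 2)


x = jnp.array([1.0, 2.0, 3.0])

# Inside this context every jax.jit-decorated function runs eagerly,
# so print() shows concrete values and Python control flow works.
with jax.disable_jit():
    value = log_amplitude(x)
    print("eager result:", value)

# Outside the context the same call is traced again and fails.
try:
    log_amplitude(x)
    traced_ok = True
except jax.errors.ConcretizationTypeError:
    traced_ok = False
print("traced call succeeded:", traced_ok)
```

This switches jitting off for everything inside the `with` block, so it helps with debugging but does not give you "jit only this subroutine".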
However, for proper usage (jitting only part of the model) that is going to be quite hard: the whole of NetKet is written assuming a functional model. Why do you want to jit only parts of the model?
-
Jitting part of the model is mainly relevant when you try something new, e.g. when you use external packages or custom Python implementations. However, today I came across jax-ml/jax#766, which might be a good enough hack to accomplish something like that.
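One way to call non-jittable Python from inside otherwise jitted code in recent JAX versions is `jax.pure_callback`. This may or may not be what jax-ml/jax#766 describes. In the sketch below, `external_energy` is a hypothetical stand-in for an external package or a custom Python routine, and `model` is a toy function, not a NetKet model.

```python
import jax
import jax.numpy as jnp
import numpy as np


def external_energy(x):
    # Hypothetical non-jittable routine (external package, plain NumPy, ...).
    # It runs on the host with concrete array values.
    return np.asarray(np.sum(np.cos(x)), dtype=x.dtype)


@jax.jit
def model(x):
    y = jnp.tanh(x)  # traced and compiled by XLA
    # JAX cannot look inside the callback, so declare its output shape/dtype.
    out_spec = jax.ShapeDtypeStruct((), y.dtype)
    e = jax.pure_callback(external_energy, out_spec, y)
    return 2.0 * e  # back in traced code


x = jnp.linspace(-1.0, 1.0, 5)
print(model(x))
```

Two caveats: JAX treats the callback as pure, so it may be called more than once or skipped if its result is unused. It also does not support automatic differentiation on its own, so gradients through it would need a `jax.custom_jvp` or `jax.custom_vjp` rule.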
-
Hi,
I'm wondering whether it would be feasible to switch off JAX jitting in certain sections of NetKet, e.g. via some environment variables.
For example, as I understand it, any model we develop currently needs to be jittable.
This also makes debugging rather challenging.
It would also be useful to allow people to jit only subroutines of their model.
Is there currently an option to do this without switching off JAX jitting globally?
What would be necessary to implement such a thing, and is it feasible at all?
Jannes
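In standalone JAX code, "jitting only a subroutine" just means applying `jax.jit` to the inner numeric kernel and leaving the caller as plain Python. The sketch below uses hypothetical `dense_layer` and `log_psi` functions. As the first reply points out, NetKet calls the model from inside its own jitted functions, so this pattern would not carry over to a NetKet model directly.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def dense_layer(W, b, x):
    # Only this numeric kernel is traced and compiled.
    return jnp.tanh(x @ W + b)


def log_psi(params, x):
    # The caller stays plain Python, so eager NumPy code, print() and
    # data-dependent branches all work here.
    h = np.asarray(dense_layer(params["W"], params["b"], x))
    if h.max() > 0.5:
        print("large activation:", h.max())
    return float(np.sum(h))


params = {"W": jnp.eye(2), "b": jnp.zeros(2)}
x = jnp.array([0.1, 0.2])
result = log_psi(params, x)
print(result)
```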