@partial function not working in custom operator definition #1063
Answered by PhilipVinc · singhavishek asked this question in Q&A
I am trying to understand how to define a custom operator by following the NetKet documentation given in this link, but running the cell raises the following error:

```
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-10-4871a11b9c4d> in <module>()
     11     return eta, jnp.ones(1)
     12 
---> 13 @partial(jax.vmap, in_axes=(None, None, 0, 0, 0))
     14 def e_loc(logpsi, pars, sigma, eta, mels):
     15     return jnp.sum(mels * jnp.exp(logpsi(pars, eta) - logpsi(pars, eta)), axis=-1)

NameError: name 'partial' is not defined
```
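The cell is:

```python
# assumes jax, jax.numpy as jnp and numpy as np were imported earlier in the notebook

@jax.vmap
def get_conns_and_mels(sigma):
    # get number of spins
    N = sigma.shape[-1]
    # repeat sigma N times (one row per connected configuration eta)
    eta = jnp.tile(sigma, (N, 1))
    # diagonal indices
    ids = np.diag_indices(N)
    # flip those indices: row i of eta is sigma with spin i flipped
    eta = eta.at[ids].set(-eta.at[ids].get())
    # matrix elements <sigma|H|eta>, all equal to 1
    return eta, jnp.ones(1)

@partial(jax.vmap, in_axes=(None, None, 0, 0, 0))
def e_loc(logpsi, pars, sigma, eta, mels):
    # sum over eta of <sigma|H|eta> * psi(eta) / psi(sigma), ratio taken in log space
    # (the cell as posted subtracted logpsi(pars, eta) from itself, which is always 0)
    return jnp.sum(mels * jnp.exp(logpsi(pars, eta) - logpsi(pars, sigma)), axis=-1)
```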
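To make clear what the two vmapped functions compute, here is a plain NumPy sketch of the same local energy for one configuration at a time. Flipping each spin with matrix element 1 corresponds to the operator H = Σ_i σ^x_i in the σ^z basis. The toy `logpsi` (log ψ(s) = a·Σ s) and the `batched_e_loc` loop are hypothetical stand-ins for a real NetKet model and for `jax.vmap`; this is a sketch for checking the logic, not NetKet's implementation.

```python
import numpy as np


def get_conns_and_mels(sigma):
    """Connected configurations of H = sum_i sigma^x_i for ONE configuration.

    sigma: 1D array of +-1 spins with shape (N,).
    Returns eta with shape (N, N), where row i is sigma with spin i flipped,
    and the N matrix elements <sigma|H|eta_i>, all equal to 1.
    """
    n = sigma.shape[-1]
    eta = np.tile(sigma, (n, 1))
    ids = np.diag_indices(n)
    eta[ids] = -eta[ids]  # flip the diagonal: one spin per row
    return eta, np.ones(n)


def e_loc(logpsi, sigma):
    """Local energy sum_eta <sigma|H|eta> psi(eta) / psi(sigma)."""
    eta, mels = get_conns_and_mels(sigma)
    # amplitude ratio computed in log space, as in the JAX cell
    return np.sum(mels * np.exp(logpsi(eta) - logpsi(sigma)))


def batched_e_loc(logpsi, sigmas):
    # explicit loop over the batch axis, standing in for jax.vmap
    return np.array([e_loc(logpsi, s) for s in sigmas])


def logpsi(s, a=0.3):
    # hypothetical toy log-amplitude: log psi(s) = a * sum(s)
    return a * np.sum(s, axis=-1)
```

For this toy amplitude, flipping spin i changes log ψ by -2a·s_i, so E_loc(σ) = Σ_i exp(-2a·s_i); with a = 0 every ratio is 1 and E_loc equals the number of spins.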
Accepted answer by PhilipVinc, Jan 13, 2022:
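`partial` is not a Python built-in: it lives in the standard-library `functools` module, so the `NameError` goes away once it is imported before the cell runs (for example at the top of the notebook):

```python
from functools import partial
```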
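Why `partial` shows up in the decorator at all: `@partial(jax.vmap, in_axes=...)` pre-fills the keyword arguments of `jax.vmap`, so decorating `f` is the same as writing `f = jax.vmap(f, in_axes=...)`. The sketch below shows that equivalence with a hypothetical toy `map_over` function standing in for `jax.vmap`, so it runs with the standard library only; it mimics only the `in_axes=None` / `0` behaviour and is not how JAX implements `vmap`.

```python
from functools import partial


def map_over(fn, in_axes):
    """Hypothetical toy stand-in for jax.vmap.

    Arguments whose entry in in_axes is 0 are iterated over;
    arguments whose entry is None are passed unchanged to every call.
    """
    def mapped(*args):
        # batch size comes from the first argument that is mapped
        size = next(len(a) for a, ax in zip(args, in_axes) if ax == 0)
        results = []
        for i in range(size):
            call_args = [a[i] if ax == 0 else a for a, ax in zip(args, in_axes)]
            results.append(fn(*call_args))
        return results
    return mapped


# decorator form, as in the NetKet cell
@partial(map_over, in_axes=(None, 0))
def scale(factor, x):
    return factor * x


# the same thing written without partial
def scale_plain(factor, x):
    return factor * x


scale_plain = map_over(scale_plain, in_axes=(None, 0))

print(scale(10, [1, 2, 3]))        # [10, 20, 30]
print(scale_plain(10, [1, 2, 3]))  # [10, 20, 30]
```

Without the import, Python fails while evaluating the decorator expression `partial(...)`, before `e_loc` is even defined, which is why the traceback points at line 13.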