Which matrix? Are you talking about the matrix in your ansatz or something else?
It really depends on your ansatz.
Again, without a small snippet it is hard to say anything. But if you are using VMC + SR, the default value of …

The regularisation of the S matrix essentially interpolates between the "natural gradient" (or "SR gradient") and the standard "stochastic gradient". Large regularisations give you essentially a plain "stochastic gradient", while small regularisations give you a more "natural gradient". In a sense, a large regularisation of the S matrix is very close to not using SR at all (it might be that …
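A minimal NumPy sketch of this interpolation, assuming the usual SR form in which the update x solves (S + diag_shift·I) x = g, with g the energy gradient. The toy S below is a deliberately ill-conditioned, covariance-like matrix made up for illustration; this is not NetKet code. For a tiny shift the update aligns with the natural gradient S⁻¹g, while for a huge shift it becomes approximately g / diag_shift, i.e. a rescaled plain gradient.

```python
import numpy as np


def sr_update(S, g, diag_shift):
    """Solve (S + diag_shift * I) x = g for the parameter update x."""
    n = S.shape[0]
    return np.linalg.solve(S + diag_shift * np.eye(n), g)


def cosine(a, b):
    """Cosine of the angle between two vectors (1.0 means same direction)."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


rng = np.random.default_rng(0)
# Toy stand-in for S: covariance of fake log-derivatives whose columns have very
# different scales, so that S is ill-conditioned and S^-1 g differs from g.
O = rng.normal(size=(200, 5)) * np.array([10.0, 3.0, 1.0, 0.3, 0.1])
S = np.cov(O, rowvar=False, bias=True)
g = rng.normal(size=5)

natural = np.linalg.solve(S, g)  # pure "natural gradient" direction
for shift in [1e-8, 1e-2, 1.0, 1e6]:
    x = sr_update(S, g, shift)
    print(f"diag_shift={shift:g}: cos(x, S^-1 g)={cosine(x, natural):.4f}, "
          f"cos(x, g)={cosine(x, g):.4f}")
```

Since a large shift also shrinks the step to roughly g / diag_shift, its size interacts with the learning rate as well as with the direction of the update.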
Due to the way we compute the S matrix, we can't easily do that. In short, we never fully compute the S matrix; we only compute the Jacobian.

```python
import jax
import jax.scipy.sparse.linalg  # makes jax.scipy.sparse.linalg.cg available
import netket as nk


def special_solver(S, v, *args, **kwargs):
    # NetKet may pass extra arguments (e.g. an initial guess `x0`); they are ignored here.
    # convert from our "lazy" format to the dense format
    Sdense = S.to_dense()
    # set the matrix elements according to what you want.
    Sdense = Sdense.at[0, 0].set(0.0)
    # ...
    # S is now dense, but the vector is a pytree: convert it from "pytree" format to dense as well.
    v, unravel = nk.jax.tree_ravel(v)
    # compute the solution
    sol, info = jax.scipy.sparse.linalg.cg(Sdense, v)
    # convert the solution back from dense to pytree
    sol = unravel(sol)
    return sol, info


# try this function out to check that it works...
# (assumes `vstate` is your variational state and `ha` your Hamiltonian)
# first generate a mock vector:
_, vec = vstate.expect_and_grad(ha)
S = vstate.quantum_geometric_tensor()
sol, info = special_solver(S, vec)
# you should check that sol has the same pytree structure as vec:
assert jax.tree_util.tree_structure(sol) == jax.tree_util.tree_structure(vec)
# Important! It must be jittable, otherwise it won't work
sol, info = jax.jit(special_solver)(S, vec)
# if this works as well, you can pass it to NetKet
SR = nk.optimizer.SR(diag_shift=..., solver=special_solver)
vmc = nk.VMC(..., preconditioner=SR)
```

However, in general, if you want to exploit the symmetries of your system, you should just encode them in the neural network/variational ansatz. The S matrix will then have the right properties...
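The remark that S is never formed explicitly can be illustrated generically. With O the matrix of per-sample log-derivatives of the wavefunction, centred and divided by √N, one has S = Ōᵀ Ō, so S·v can be computed as Ōᵀ(Ō v) from the Jacobian alone. This is a NumPy sketch with random real-valued toy data (for complex parameters the transpose becomes a conjugate transpose); it is not NetKet's implementation. Editing single entries of S has no counterpart in this factorised form, which is why `special_solver` first materialises a dense matrix with `to_dense()`.

```python
import numpy as np

rng = np.random.default_rng(1)
n_samples, n_params = 500, 8
# Toy data: O[k, i] plays the role of d log psi(sigma_k) / d theta_i (real-valued here)
O = rng.normal(size=(n_samples, n_params))

# Centre and rescale so that S = Obar^T @ Obar is the sample covariance of O
Obar = (O - O.mean(axis=0)) / np.sqrt(n_samples)


def s_matvec(v, diag_shift=0.0):
    """(S + diag_shift * I) @ v computed as Obar^T (Obar v); S is never built."""
    return Obar.T @ (Obar @ v) + diag_shift * v


def s_dense(diag_shift=0.0):
    """Explicit (S + diag_shift * I), for comparison only."""
    return Obar.T @ Obar + diag_shift * np.eye(n_params)


v = rng.normal(size=n_params)
print(np.allclose(s_matvec(v, 0.01), s_dense(0.01) @ v))  # True
```

The lazy product only needs the N_samples × N_params Jacobian, whereas the dense S needs N_params × N_params memory, which matters for ansätze with many parameters.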
Hi everyone,
I'm working on a periodic system whose Hamiltonian (with a Z_2 symmetry) acts on a nonhomogeneous Hilbert space, so my matrix and S-matrix should have non-zero elements close to the diagonal and at both ends of the anti-diagonal. But if I use stochastic reconfiguration, the `diag_shift` parameter discards all the elements that are not close enough to the diagonal. I ran both simulations on a small system, in order to get the exact ground-state energy for comparison. When I use `diag_shift=0.1`, the variational ground-state energy does not match the exact solution, but it does when I don't specify the `diag_shift` parameter. My question is: is there a way to consider only the diagonal and anti-diagonal, in order to capture the periodicity of the system in the S-matrix without considering all the zero elements?
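For reference, the masking the question asks about could be written on a dense S as below, for example inside a custom solver such as `special_solver` in the reply, after `S.to_dense()`. `periodic_band_mask` and `masked_sr_update` are hypothetical helpers, not NetKet API, and the band width is an assumption. Note that zeroing entries of a positive semi-definite matrix does not in general keep it positive semi-definite, so a diagonal shift may still be needed for the linear solve to be well-behaved.

```python
import numpy as np


def periodic_band_mask(n, width):
    """Boolean mask keeping entries whose cyclic distance |i - j| is at most `width`.

    The wrap-around keeps the top-right and bottom-left corners of the matrix,
    i.e. the couplings between the first and last sites of a periodic system.
    """
    i, j = np.indices((n, n))
    d = np.abs(i - j)
    cyclic_distance = np.minimum(d, n - d)
    return cyclic_distance <= width


def masked_sr_update(S, g, width, diag_shift):
    """Hypothetical SR variant: zero S outside the periodic band, then solve
    (S_masked + diag_shift * I) x = g."""
    n = S.shape[0]
    S_masked = np.where(periodic_band_mask(n, width), S, 0.0)
    return np.linalg.solve(S_masked + diag_shift * np.eye(n), g)


mask = periodic_band_mask(6, width=1)
print(mask.astype(int))
# [[1 1 0 0 0 1]
#  [1 1 1 0 0 0]
#  [0 1 1 1 0 0]
#  [0 0 1 1 1 0]
#  [0 0 0 1 1 1]
#  [1 0 0 0 1 1]]

rng = np.random.default_rng(2)
S = np.cov(rng.normal(size=(300, 6)), rowvar=False, bias=True)  # toy PSD matrix
g = rng.normal(size=6)
x = masked_sr_update(S, g, width=1, diag_shift=0.1)
```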