Inputs to the model other than the system configuration #1546
-
I don't exactly understand what you want to do. Could you write it down in formulas or pseudocode?
-
Sorry about the brevity. Currently, a model can only take the configuration (e.g. the spins) as input. The variational-state and sampler interfaces then sample this function with Monte Carlo:

```python
import jax
import netket as nk
from netket.operator.spin import sigmaz

class MyRBMModel:
    def __call__(self, sigma):
        # do stuff, return log_amplitude
        return f(sigma)

model = MyRBMModel()
vstate = nk.vqs.MCState(model, sampler)
H = sum(sigmaz(hi, i) for i in range(L))
E, E_grad = vstate.expect_and_grad(H)
```

But I would like to be able to compute `expect_and_grad` while also providing other inputs, like this:

```python
class MyRBMModel:
    def __call__(self, sigma, other_params):
        # do stuff, return log_amplitude
        return f(sigma, other_params)

model = MyRBMModel()
vstate = nk.vqs.MCState(model, sampler)
for i in range(n_iters):
    other_params = 5 * i
    H = sum(other_params * sigmaz(hi, j) for j in range(L))
    E, E_grad = vstate.expect_and_grad(H, other_params)  # desired API, not an existing signature
    vstate.parameters = jax.tree_map(...E_grad...)
```

Right now, what I do is something like:

```python
class MyRBMModel:
    def __call__(self, sigma):
        # do stuff, return log_amplitude
        return f(sigma, self.params)

    def set_params(self, params):
        self.params = params

para = None
for i in range(n_iters):
    other_params = 0.5 * i
    model = MyRBMModel()
    model.set_params(other_params)
    vstate = nk.vqs.MCState(model, sampler)
    if para is not None:
        vstate.parameters = para  # carry the variational parameters over
    H = sum(other_params * sigmaz(hi, j) for j in range(L))
    E, E_grad = vstate.expect_and_grad(H)  # other_params is already stored on the model
    vstate.parameters = jax.tree_map(...E_grad...)
    para = vstate.parameters.copy()
    del vstate
    del model
```
If I don't do the delete step, something goes wrong; I think it is due to just-in-time compilation. The approach is also extremely slow compared to a normal optimization loop. I would like to provide input parameters to the model at each iteration, which does not seem possible right now. Let me know if this is still not clear.
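A minimal pure-JAX sketch (no NetKet) of why rebuilding the model every iteration is slow: `jax.jit` traces and compiles once per new function object, whereas a value passed in as an ordinary array argument reuses the already-compiled code as long as shapes and dtypes stay the same. The toy log-amplitude and the names `make_log_psi_closure` and `log_psi_with_input` are hypothetical stand-ins for the RBM in the workaround above.

```python
import jax
import jax.numpy as jnp

# Counts how often each function is traced (each trace is followed by a compile).
# The Python increment runs only during tracing, never on a cached call.
trace_count = {"closure": 0, "argument": 0}


def make_log_psi_closure(h):
    """Hypothetical: bakes h in as a Python constant, like rebuilding the model."""

    @jax.jit
    def log_psi(sigma):
        trace_count["closure"] += 1
        return h * jnp.sum(sigma)

    return log_psi


@jax.jit
def log_psi_with_input(sigma, h):
    """Hypothetical: the same toy log-amplitude, with h as a traced argument."""
    trace_count["argument"] += 1
    return h * jnp.sum(sigma)


sigma = jnp.array([1.0, -1.0, 1.0, 1.0])
n_iters = 5

for i in range(n_iters):
    h = 0.5 * i
    make_log_psi_closure(h)(sigma)  # new jitted function every time -> retraced
    log_psi_with_input(sigma, jnp.asarray(h))  # same function and shapes -> cached

print(trace_count)  # {'closure': 5, 'argument': 1}
```

The design point is that anything that changes every iteration should reach the jitted code as array data rather than as a new Python object or constant.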
-
Right now, there is no way to provide inputs to the model (say, an RBM) other than the spin configuration at each iteration.
I can sort of do it by hand, by creating a new model at each iteration, setting those values, then creating a new vstate, etc., and then doing SGD by hand, but because everything is re-jitted each time, it's very slow.
Is something like this possible?
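A minimal pure-JAX sketch of the requested pattern, not NetKet's API: the extra input `h` is an ordinary argument of one jitted function that returns the energy and its gradient, so `h` can change every iteration without recompiling. Assumptions: a toy RBM-like `log_psi` in which `h` rescales the visible bias, the diagonal Hamiltonian H = h Σ_j σ^z_j from the pseudocode above, real parameters, and exact summation over all 2^L configurations in place of Monte Carlo sampling. The gradient uses the covariance form 2⟨(E_loc - ⟨H⟩) ∂ log ψ⟩; all function names are hypothetical.

```python
import itertools

import jax
import jax.numpy as jnp


def log_psi(params, sigma, h):
    """Toy RBM log-amplitude; h is an extra, non-trainable input (hypothetical)."""
    theta = params["W"] @ sigma + params["b"]
    return h * jnp.dot(params["a"], sigma) + jnp.sum(jnp.log(jnp.cosh(theta)))


def all_configs(L):
    """All 2**L spin configurations with entries +-1 (exact sum, no sampling)."""
    return jnp.array(list(itertools.product([-1.0, 1.0], repeat=L)))


def exact_energy(params, sigmas, h):
    """<H> for H = h * sum_j sigma^z_j, used only as a reference check."""
    logp = 2.0 * jax.vmap(log_psi, in_axes=(None, 0, None))(params, sigmas, h)
    p = jax.nn.softmax(logp)  # normalized |psi|^2
    return jnp.sum(p * h * jnp.sum(sigmas, axis=1))


@jax.jit
def energy_and_grad(params, sigmas, h):
    """<H> and d<H>/dparams; h is a traced argument, so new values reuse the compile."""
    logp = 2.0 * jax.vmap(log_psi, in_axes=(None, 0, None))(params, sigmas, h)
    p = jax.nn.softmax(logp)  # normalized |psi|^2 over all configurations
    e_loc = h * jnp.sum(sigmas, axis=1)  # local energy of a diagonal operator
    energy = jnp.sum(p * e_loc)
    # Per-configuration log-derivatives O(sigma) = d log psi / d params.
    O = jax.vmap(jax.grad(log_psi), in_axes=(None, 0, None))(params, sigmas, h)
    weights = 2.0 * p * (e_loc - energy)
    grad = jax.tree_util.tree_map(lambda o: jnp.tensordot(weights, o, axes=1), O)
    return energy, grad


L, n_hidden = 3, 2
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "a": 0.1 * jax.random.normal(k1, (L,)),
    "b": 0.1 * jax.random.normal(k2, (n_hidden,)),
    "W": 0.1 * jax.random.normal(k3, (n_hidden, L)),
}
sigmas = all_configs(L)

lr = 0.05  # hypothetical learning rate for plain SGD
for i in range(5):
    h = 0.5 * i  # the extra input changes every iteration
    E, E_grad = energy_and_grad(params, sigmas, h)
    params = jax.tree_util.tree_map(lambda w, g: w - lr * g, params, E_grad)
```

In NetKet itself, a comparable effect may be reachable by keeping `h` in a non-trainable Flax variable collection and updating `vstate.model_state` between iterations instead of rebuilding the model; treat that as an assumption to verify against the NetKet documentation, not something this sketch shows.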