Skip to content

Commit

Permalink
Negative FRC was because optax was minimizing, switched single CTF to…
Browse files Browse the repository at this point in the history
… be a loss function and use optax properly. Layered still broken
  • Loading branch information
AnyaPorter committed Dec 17, 2024
1 parent f36e48f commit fe9b885
Showing 1 changed file with 39 additions and 5 deletions.
44 changes: 39 additions & 5 deletions programs/e3make3d_gauss.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,6 @@ def main():
elif options.ctf==2:
dsapix=apix*nxraw/ptclsfds.shape[1]
step0,qual0,shift0,sca0=gradient_step_layered_ctf(gaus,ptclsfds,orts,jax_downsample_2d(ctf_stack.jax,ptclsfds.shape[1]),tytx,dfrange,dfstep,dsapix,stage[3],stage[7],frc_Z)
jax.debug.print("{s}", s=step0)
step0=jnp.nan_to_num(step0)
if j==0:
step,qual,shift,sca=step0,qual0,shift0,sca0
Expand All @@ -227,13 +226,13 @@ def main():
shift+=shift0
sca+=sca
elif options.ctf==1:
step0,qual0,shift0,sca0=gradient_step_ctf(gaus,ptclsfds,orts,jax_downsample_2d(ctf_stack.jax,ptclsfds.shape[1]),tytx,dfrange,dfstep,stage[3],stage[7],frc_Z)
step0,qual0,shift0,sca0=gradient_step_ctf_optax(gaus,ptclsfds,orts,jax_downsample_2d(ctf_stack.jax,ptclsfds.shape[1]),tytx,dfrange,dfstep,stage[3],stage[7],frc_Z)
step0=jnp.nan_to_num(step0)
if j==0:
step,qual,shift,sca=step0,qual0,shift0,sca0
step,qual,shift,sca=step0,-qual0,shift0,sca0
else:
step+=step0
qual+=qual0
qual-=qual0
shift+=shift0
sca+=sca
# optimize gaussians and image shifts
Expand Down Expand Up @@ -272,7 +271,6 @@ def main():
update, optim_state = optim.update(step, optim_state)
gaus._data = optax.apply_updates(gaus._data, update)


if options.savesteps: from_numpy(gaus.numpy).write_image("steps.hdf",-1)

print(f"{i}: {qual:1.5f}\t{shift:1.5f}\t\t{sca:1.5f}\t{imshift:1.5f}")
Expand Down Expand Up @@ -530,6 +528,7 @@ def gradient_step_ctf(gaus,ptclsfds,orts,ctfaryds,tytx,dfrange,dfstep,weight=1.0
# print("mx ",mx.shape)

frcs,grad=gradvalfn_ctf(gausary,mx,ctfaryds,dfrange[0],dfrange[1],dfstep,tytx,ptcls,weight,frc_Z)

# qual=frcs.mean() # this is the average over all projections, not the average over frequency
qual=frcs # functions used in jax gradient can't return a list, so frcs is a single value now
shift=grad[:,:3].std() # translational std
Expand All @@ -553,6 +552,41 @@ def prj_frc_ctf(gausary,mx2d,ctfary,dfmin,dfmax,dfstep,tytx,ptcls,weight,frc_Z):

gradvalfn_ctf=jax.value_and_grad(prj_frc_ctf)


def gradient_step_ctf_optax(gaus,ptclsfds,orts,ctfaryds,tytx,dfrange,dfstep,weight=1.0,relstep=1.0,frc_Z=3.0):
"""Computes one gradient step on the Gaussian coordinates given a set of particle FFTs at the appropriate scale,
computing FRC to axial Nyquist, with specified linear weighting factor (def 1.0). Linear weight goes from
0-2. 1 is unweighted, >1 upweights low resolution, <1 upweights high resolution.
returns step, qual, shift, scale
step - one gradient step to be applied with (gaus.add_tensor)
qual - mean frc
shift - std of xyz shift gradient
scale - std of amplitude gradient"""
ny=ptclsfds.shape[1]
mx=orts.to_mx2d(swapxy=True)
gausary=gaus.jax
ptcls=ptclsfds.jax

frcs,grad=gradvalfnl_ctf(gausary,mx,ctfaryds,dfrange[0],dfrange[1],dfstep,tytx,ptcls,weight,frc_Z)

qual=frcs # functions used in jax gradient can't return a list, so frcs is a single value now
shift=grad[:,:3].std() # translational std
sca=grad[:,3].std() # amplitude std
xyzs=relstep/(shift*500) # xyz scale factor, 1000 heuristic, TODO: may change

return (grad,float(qual),float(shift),float(sca))

def prj_frc_loss_ctf(gausary,mx2d,ctfary,dfmin,dfmax,dfstep,tytx,ptcls,weight,frc_Z):
"""Aggregates the functions we need to calculate the gradient through. Computes the frc array resulting from the
comparison of the Gaussians in gaus to particles in known orientations."""

ny=ptcls.shape[1]
prj=gauss_project_ctf_fn(gausary,mx2d,ctfary,ny,dfmin,dfmax,dfstep,tytx)
return -jax_frc_jit(jax_fft2d(prj),ptcls,weight,2,frc_Z)

gradvalfnl_ctf=jax.value_and_grad(prj_frc_loss_ctf)


def gradient_step_layered_ctf(gaus,ptclsfds,orts,ctfaryds,tytx,dfrange,dfstep,dsapix,weight=1.0,relstep=1.0,frc_Z=3.0):
"""Computes one gradient step on the Gaussian coordinates given a set of particle FFTs at the appropriate scale,
computing FRC to axial Nyquist, with specified linear weighting factor (def 1.0). Linear weight goes from
Expand Down

0 comments on commit fe9b885

Please sign in to comment.