Ok, I think I've finally fixed the "training problem" with e2gmm and related programs.

The problem was 2-fold: the network adjacent to the latent space was expanding/contracting too slowly, allowing only a small number of weights to dominate. The major problem, however, is that the relu activation function _can_ get stuck with zero outputs and an untrainable state in certain situations, particularly when the network is small. Using leaky_relu instead seems to avoid these issues and make more efficient use of the network.
sludtke42 committed Nov 26, 2024
1 parent f7c09fc commit 4c07420
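
As an aside (not part of the commit, and with made-up layer sizes), the effect of the activation swap is easy to see on a single Dense layer: "leaky_relu" keeps a small negative slope, so units whose pre-activations go negative still pass gradient instead of locking at a constant zero output the way "relu" can.

import tensorflow as tf

# Illustrative sizes only; the real values come from build_encoder()/build_decoder()
ninp = 256
x = tf.random.normal((8, ninp))  # dummy batch of 8 inputs

relu_layer  = tf.keras.layers.Dense(ninp//2, activation="relu")        # outputs exactly 0 for negative pre-activations
leaky_layer = tf.keras.layers.Dense(ninp//2, activation="leaky_relu")  # small negative slope keeps gradients flowing

print(relu_layer(x).shape, leaky_layer(x).shape)  # both (8, 128)
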
Showing 2 changed files with 22 additions and 19 deletions.
38 changes: 19 additions & 19 deletions programs/e2gmm_refine_point.py
@@ -759,15 +759,15 @@ def build_encoder(ninp,nmid,grps=None):
print(f"Encoder (no groups) {max(ninp//2,nmid*8)},{max(ninp//4,nmid*4)},{max(ninp//8,nmid*4)},{max(ninp//16,nmid*4)},{max(ninp//32,nmid*2)}")
layers=[
tf.keras.layers.Flatten(),
- tf.keras.layers.Dense(max(ninp//2,nmid*8), activation="relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True,bias_initializer=binit),
+ tf.keras.layers.Dense(max(ninp//2,nmid*8), activation="leaky_relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True,bias_initializer=binit),
# tf.keras.layers.Dropout(.2),
- tf.keras.layers.Dense(max(ninp//4,nmid*4), activation="relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True),
+ tf.keras.layers.Dense(max(ninp//4,nmid*8), activation="leaky_relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True),
tf.keras.layers.Dropout(.4),
- tf.keras.layers.Dense(max(ninp//8,nmid*4), activation="relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True),
+ tf.keras.layers.Dense(max(ninp//8,nmid*4), activation="leaky_relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True),
tf.keras.layers.Dropout(.4),
- tf.keras.layers.Dense(max(ninp//16,nmid*4), activation="relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True),
+ tf.keras.layers.Dense(max(ninp//16,nmid*4), activation="leaky_relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True),
tf.keras.layers.Dropout(.4),
- tf.keras.layers.Dense(max(ninp//32,nmid*2), activation="relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True),
+ tf.keras.layers.Dense(max(ninp//32,nmid*4), activation="leaky_relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True),
tf.keras.layers.Dense(nmid, kernel_regularizer=l2, kernel_initializer=kinit,use_bias=True),
]

@@ -782,15 +782,15 @@ def build_encoder(ninp,nmid,grps=None):
in2s=[]
t=0
for i in ngrp:
- in2s.append(tf.keras.layers.Dense(i, activation="relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True,bias_initializer=binit)(in1[:,t:t+i]))
+ in2s.append(tf.keras.layers.Dense(i, activation="leaky_relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True,bias_initializer=binit)(in1[:,t:t+i]))
t+=i
# Add Dropout here?
# drop=[tf.keras.layers.Dropout(0.3)(in2s[i]) for i in range(len(ngrp))]
- mid=[tf.keras.layers.Dense(max(ngrp[i]//2,latpergrp*8), activation="relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True)(in2s[i]) for i in range(len(ngrp))]
+ mid=[tf.keras.layers.Dense(max(ngrp[i]//2,latpergrp*8), activation="leaky_relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True)(in2s[i]) for i in range(len(ngrp))]
# drop=[tf.keras.layers.Dropout(0.25)(mid[i]) for i in range(len(ngrp))]
- mid2=[tf.keras.layers.Dense(max(ngrp[i]//4,latpergrp*4), activation="relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True)(mid[i]) for i in range(len(ngrp))]
- mid3=[tf.keras.layers.Dense(max(ngrp[i]//8,latpergrp*2), activation="relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True)(mid2[i]) for i in range(len(ngrp))]
- mid4=[tf.keras.layers.Dense(max(ngrp[i]//16,latpergrp*2), activation="relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True)(mid3[i]) for i in range(len(ngrp))]
+ mid2=[tf.keras.layers.Dense(max(ngrp[i]//4,latpergrp*8), activation="leaky_relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True)(mid[i]) for i in range(len(ngrp))]
+ mid3=[tf.keras.layers.Dense(max(ngrp[i]//8,latpergrp*4), activation="leaky_relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True)(mid2[i]) for i in range(len(ngrp))]
+ mid4=[tf.keras.layers.Dense(max(ngrp[i]//16,latpergrp*4), activation="leaky_relu", kernel_initializer=kinit, kernel_regularizer=l2,use_bias=True)(mid3[i]) for i in range(len(ngrp))]
outs=[tf.keras.layers.Dense(latpergrp, kernel_regularizer=l2, kernel_initializer=kinit,use_bias=True)(mid4[i]) for i in range(len(ngrp))]
out=tf.keras.layers.Concatenate()(outs)
encode_model=tf.keras.Model(inputs=in1,outputs=out)
@@ -826,18 +826,18 @@ def build_decoder(nmid, pt ):

# print(f"Decoder {max(nout//32,nmid)} {max(nout//8,nmid)} {max(nout//2,nmid)}")
layers=[
- #tf.keras.layers.Dense(nmid*2,activation="relu",use_bias=True,bias_initializer=kinit,kernel_constraint=Localize1()),
- #tf.keras.layers.Dense(nmid*4,activation="relu",use_bias=True,kernel_constraint=Localize2()),
- #tf.keras.layers.Dense(nmid*8,activation="relu",use_bias=True,kernel_constraint=Localize3()),
- tf.keras.layers.Dense(min(nmid*2,nout//16),activation="relu",kernel_initializer=kinit,use_bias=True,bias_initializer=binit),
- tf.keras.layers.Dense(min(nmid*4,nout//16),activation="relu",kernel_initializer=kinit,use_bias=True),
- tf.keras.layers.Dense(min(nmid*8,nout//16),activation="relu",kernel_initializer=kinit,use_bias=True),
+ #tf.keras.layers.Dense(nmid*2,activation="leaky_relu",use_bias=True,bias_initializer=kinit,kernel_constraint=Localize1()),
+ #tf.keras.layers.Dense(nmid*4,activation="leaky_relu",use_bias=True,kernel_constraint=Localize2()),
+ #tf.keras.layers.Dense(nmid*8,activation="leaky_relu",use_bias=True,kernel_constraint=Localize3()),
+ tf.keras.layers.Dense(max(nmid*4,nout//64),activation="leaky_relu",kernel_initializer=kinit,use_bias=True,bias_initializer=binit),
+ tf.keras.layers.Dense(max(nmid*8,nout//32),activation="leaky_relu",kernel_initializer=kinit,use_bias=True),
+ tf.keras.layers.Dense(min(nmid*8,nout//16),activation="leaky_relu",kernel_initializer=kinit,use_bias=True),
tf.keras.layers.Dropout(.4),
- tf.keras.layers.Dense(min(nmid*16,nout//8),activation="relu",kernel_initializer=kinit,use_bias=True),
+ tf.keras.layers.Dense(min(nmid*16,nout//8),activation="leaky_relu",kernel_initializer=kinit,use_bias=True),
tf.keras.layers.Dropout(.4),
- tf.keras.layers.Dense(nout//4,activation="relu",kernel_initializer=kinit,use_bias=True),
+ tf.keras.layers.Dense(nout//8,activation="leaky_relu",kernel_initializer=kinit,use_bias=True),
tf.keras.layers.Dropout(.25),
- tf.keras.layers.Dense(nout//2,activation="relu",kernel_initializer=kinit,use_bias=True),
+ tf.keras.layers.Dense(nout//4,activation="leaky_relu",kernel_initializer=kinit,use_bias=True),
# tf.keras.layers.BatchNormalization(),
layer_output,
tf.keras.layers.Reshape((nout,4))
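
Aside (not part of the commit): a quick, purely illustrative calculation of what the min-to-max swap does to the decoder layers adjacent to the latent space; nmid and nout below are made up, not taken from a real run.

# Hypothetical example: small latent space (nmid) decoding to a large point set (nout)
nmid, nout = 8, 4096

# old sizing: min() can pin the latent-adjacent layers to a handful of units
old_widths = [min(nmid*2, nout//16), min(nmid*4, nout//16), min(nmid*8, nout//16)]
# new sizing: max() lets the first layers expand away from the latent space more quickly
new_widths = [max(nmid*4, nout//64), max(nmid*8, nout//32), min(nmid*8, nout//16)]

print(old_widths)  # [16, 32, 64]
print(new_widths)  # [64, 128, 64]
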
3 changes: 3 additions & 0 deletions programs/e3make3d_gauss.py
@@ -195,6 +195,7 @@ def main():
# if not options.tomo or sn<2:
if True:
step0,qual0,shift0,sca0=gradient_step(gaus,ptclsfds,orts,tytx,stage[3],stage[7],frc_Z)
+ step0=jnp.nan_to_num(step0)
if j==0:
step,qual,shift,sca=step0,qual0,shift0,sca0
else:
@@ -210,6 +211,7 @@
elif options.ctf:
dsapix=apix*nxraw/ptclsfds.shape[1]
step0,qual0,shift0,sca0=gradient_step_ctf(gaus,ptclsfds,orts,ctf_stack.downsample(ptclsfds.shape[1]),tytx,dfrange,dfstep,dsapix,stage[3],stage[7])
+ step0=jnp.nan_to_num(step0)
if j==0:
step,qual,shift,sca=step0,qual0,shift0,sca0
else:
@@ -220,6 +222,7 @@
# optimize gaussians and image shifts
else:
step0,stept0,qual0,shift0,sca0,imshift0=gradient_step_tytx(gaus,ptclsfds,orts,tytx,stage[3],stage[7])
+ step0=jnp.nan_to_num(step0)
if j==0:
step,stept,qual,shift,sca,imshift=step0,stept0,qual0,shift0,sca0,imshift0
caches[stage[1]].add_orts(nliststg[j:j+512],None,stept0*rstep) # we can immediately add the current 500 since it is per-particle
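
The e3make3d_gauss.py change is a small robustness guard: each gradient step is passed through jnp.nan_to_num before it is accumulated, so a NaN produced by one bad batch becomes zero instead of propagating into the running update. A minimal sketch of the idea (the function and variable names below are stand-ins, not the real gradient_step code):

import jax.numpy as jnp

def accumulate_step(step, step0):
    # nan_to_num maps NaN -> 0.0 (and +/-inf to large finite values), so a bad
    # batch contributes nothing instead of turning the whole update into NaN
    step0 = jnp.nan_to_num(step0)
    return step0 if step is None else step + step0

# Hypothetical usage with made-up numbers
step = None
for step0 in (jnp.array([0.1, jnp.nan, -0.2]), jnp.array([0.05, 0.05, 0.05])):
    step = accumulate_step(step, step0)
print(step)  # ~[0.15 0.05 -0.15]
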
