Lookahead optimizer ("Lookahead Optimizer: k steps forward, 1 step back") for TensorFlow
This code is implemented and tested with TensorFlow 1.11.0 and 1.13.0.
I didn't use any special operators, so it should also work with other versions of TensorFlow.
I didn't wrap the optimizer directly; instead, the lookahead strategy is kept independent.
This makes it more flexible to decide which variables should be optimized with lookahead.
- Instantiate the class after all variables have been initialized, and initialize BaseLookAhead with all trainable variables, as in the snippet below.
import tensorflow as tf
from lookahead_opt import BaseLookAhead
"""
Build your model here
Please also include any optimizer you need.
"""
model_vars = [v for v in tf.trainable_variables()]

# Initialize variables first (.run() here requires a default session),
# then create the lookahead ops for the chosen variables.
tf.global_variables_initializer().run()
lookahead = BaseLookAhead(model_vars, k=5, alpha=0.5)
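Since the strategy is decoupled from any particular optimizer, lookahead can also be applied to only part of the model, for instance the variables under one scope. A small sketch (the scope name "decoder" is just an example, not something defined by this repository):

```python
# Apply lookahead only to the variables under an example scope named "decoder".
decoder_vars = [v for v in tf.trainable_variables() if v.name.startswith("decoder/")]
lookahead = BaseLookAhead(decoder_vars, k=5, alpha=0.5)
```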
Arguments are defined as follows:
- model_vars: the variables to apply lookahead to. [list]
- k: the number of steps the fast weights run forward before each synchronization. [int]
- alpha: the step size for moving the slow weights toward the fast weights. [float]
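For reference, k and alpha follow the update rule from the paper: the fast weights take k ordinary optimizer steps, then the slow weights are moved toward them by a factor of alpha and the fast weights restart from the merged point. A small NumPy illustration (the function and its arguments are only for illustration, not part of this repository):

```python
import numpy as np

def lookahead_sketch(weights, grad_step, k=5, alpha=0.5, num_outer=10):
    """Illustration of the schedule controlled by k and alpha:
    k fast steps, then one interpolation of the slow weights."""
    fast = weights.copy()
    slow = weights.copy()
    for _ in range(num_outer):
        for _ in range(k):
            fast = grad_step(fast)              # inner optimizer step (e.g. SGD/Adam)
        slow = slow + alpha * (fast - slow)     # pull slow weights toward fast weights
        fast = slow.copy()                      # "1 step back": restart from slow weights
    return slow

# Toy usage: minimize (x - 1)^2 with plain gradient steps.
final = lookahead_sketch(np.zeros(3), grad_step=lambda x: x - 0.1 * 2.0 * (x - 1.0))
```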
- Add the assign ops to the training operation, or run them directly in a session.
# Add the lookahead ops to train_op (train_op is assumed to be a list of ops)
train_op += lookahead.get_ops()

# Or run the lookahead ops directly in a session
with tf.Session() as sess:
    _ = sess.run(lookahead.get_ops())
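Putting it together, a typical loop just fetches the combined ops every iteration. A minimal sketch, assuming the graph and lookahead object from the setup snippet above, and that sess, num_steps, loss, and feed_dict come from your own training code:

```python
# train_op is assumed to be a list of ops here; wrap a single op as [train_op].
train_ops = train_op + lookahead.get_ops()

for step in range(num_steps):                        # num_steps from your own setup
    _, loss_value = sess.run([train_ops, loss],      # sess, loss, feed_dict are yours
                             feed_dict=feed_dict)
    if step % 100 == 0:
        print("step %d, loss %.4f" % (step, loss_value))
```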
The lookahead ops are wrapped in the default variable_scope "lookahead".
After calling BaseLookAhead with specific variables, those variables are registered for lookahead updates.
Note that the lookahead class is completely separate from the optimizer; remember to add your own optimizer when building the training graph.
BaseLookAhead creates duplicate tf.Variables to store the slow weights, and a counter is created automatically to perform the "k steps forward, 1 step back" schedule.
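For intuition, a stripped-down version of such a mechanism could look like the sketch below. This is only an illustration of the idea, not the code in lookahead_opt.py, and names like make_lookahead_ops are made up:

```python
import tensorflow as tf

def make_lookahead_ops(model_vars, k=5, alpha=0.5):
    """Sketch of the mechanism described above (not the repository's code):
    slow-weight copies plus a counter that triggers the 'one step back' every k runs."""
    with tf.variable_scope("lookahead"):
        counter = tf.get_variable("counter", shape=[], dtype=tf.int32,
                                  initializer=tf.zeros_initializer(),
                                  trainable=False)
        slow_vars = [tf.get_variable("slow_%d" % i,
                                     initializer=v.initialized_value(),
                                     trainable=False)
                     for i, v in enumerate(model_vars)]

    count = tf.assign_add(counter, 1)  # one increment per run of the returned op

    def sync():
        # "1 step back": slow <- slow + alpha * (fast - slow), then fast <- slow.
        updates = []
        for fast, slow in zip(model_vars, slow_vars):
            merged = slow + alpha * (fast - slow)
            updates.append(tf.assign(slow, merged))
            updates.append(tf.assign(fast, merged))
        with tf.control_dependencies(updates):
            return tf.assign(counter, 0)

    def skip():
        return tf.identity(count)

    # Every k-th run, merge the slow and fast weights; otherwise do nothing.
    return [tf.cond(tf.equal(tf.mod(count, k), 0), sync, skip)]
```

Running the returned op once per training step (for example by appending it to train_op, as shown earlier) increments the counter; every k-th run it interpolates the slow weights and writes the result back into the model variables.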
I have conducted experiments on a many-to-many recursive task with a stacked weight-dropped LSTM, as proposed in "Regularizing and Optimizing LSTM Language Models".
Using lookahead with Adam, the training loss is higher than without lookahead, but the validation loss is slightly better.
Code by Jia-Yau Shiau ([email protected]).