Skip to content

利用tensorflow实现的一个增量学习的小Demo,并改为Estimator版本

Notifications You must be signed in to change notification settings

cui-z/Incremental-Learning

Repository files navigation

增量学习的一个Demo

使用流程

  • 使用create_train_file来创建训练样本
  • 利用train来训练
  • 使用Incre_train来增量学习

核心代码

恢复模型

目前主要有两种恢复模型的方式

  • 方式一
# 先构建网络结构
build_model()
# 初始化变量
sess.run(tf.global_variables_initializer())
# 最后从checkpoint中加载已训练好的参数
saver = tf.train.Saver()
saver.restore(self.sess, init_checkpoint)
  • 方式二
# 先构建网络结构
build_model()
# 调用init_from_checkpoint方法
tvars = tf.trainable_variables()
(assignment_map,
initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(
   tvars, init_checkpoint)
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
# 最后初始化变量
sess.run(tf.global_variables_initializer())

本文采用的

  • 保存模型
saver.save(sess, model_path + '/model.ckpt')
sess = tf.Session(graph=graph)
check_point_path = 'saved_model/'  # 保存好模型的文件路径
ckpt = tf.train.get_checkpoint_state(checkpoint_dir=check_point_path)
saver.restore(sess, ckpt.model_checkpoint_path)

参考的这个文章增量学习

训练过程修改为高级API estimator版本

流程(根据Bert代码仿写)

  • get_train_examples(train_file_path) 获取样本,输出结果为一个List,里面存的是一个类(包含feature和label)
  • file_based_convert_examples_to_features(examples, output_file) 将List里面的数据写入到tf_record格式的文件中,利用tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))或者int生成相对应的feature,然后定义一个tf.train.Example,写入文件。
  • file_based_input_fn_builder(input_file, batch_size, is_training) 利用tf.data.TFRecordDataset读取tf_record格式的文件,并且batch的生产出来,里面要注意:repeat函数只有在训练时使用
  • my_model(features, labels, mode, params) 自己定义的模型,输入参数feature里面既可以包含feature也可以包含label,mode是用来判断被哪个方法调用。里面需要对train、eval、predict分别进行定义,my_model函数的返回值均是tf.estimator.EstimatorSpec,特别主要在eval的时候不能使用常规的准确率,可以使用tf.metrics函数(他是根据batch来计算最终的)。

使用方法

直接用estimator_test文件下的train.py即可。

About

利用tensorflow实现的一个增量学习的小Demo,并改为Estimator版本

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages