- 使用create_train_file来创建训练样本
- 利用train来训练
- 使用Incre_train来增量学习
目前主要有两种恢复模型的方式
- 方式一
# 先构建网络结构
build_model()
# 初始化变量
sess.run(tf.global_variables_initializer())
# 最后从checkpoint中加载已训练好的参数
saver = tf.train.Saver()
saver.restore(self.sess, init_checkpoint)
- 方式二
# 先构建网络结构
build_model()
# 调用init_from_checkpoint方法
tvars = tf.trainable_variables()
(assignment_map,
initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(
tvars, init_checkpoint)
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
# 最后初始化变量
sess.run(tf.global_variables_initializer())
- 保存模型
saver.save(sess, model_path + '/model.ckpt')
sess = tf.Session(graph=graph)
check_point_path = 'saved_model/' # 保存好模型的文件路径
ckpt = tf.train.get_checkpoint_state(checkpoint_dir=check_point_path)
saver.restore(sess, ckpt.model_checkpoint_path)
参考的这个文章增量学习
- get_train_examples(train_file_path) 获取样本,输出结果为一个List,里面存的是一个类(包含feature和label)
- file_based_convert_examples_to_features(examples, output_file) 将List里面的数据写入到tf_record格式的文件中,利用tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))或者int生成相对应的feature,然后定义一个tf.train.Example,写入文件。
- file_based_input_fn_builder(input_file, batch_size, is_training) 利用tf.data.TFRecordDataset读取tf_record格式的文件,并且batch的生产出来,里面要注意:repeat函数只有在训练时使用
- my_model(features, labels, mode, params) 自己定义的模型,输入参数feature里面既可以包含feature也可以包含label,mode是用来判断被哪个方法调用。里面需要对train、eval、predict分别进行定义,my_model函数的返回值均是tf.estimator.EstimatorSpec,特别主要在eval的时候不能使用常规的准确率,可以使用tf.metrics函数(他是根据batch来计算最终的)。
直接用estimator_test文件下的train.py即可。