-
Notifications
You must be signed in to change notification settings - Fork 11
/
train.py
44 lines (31 loc) · 1.2 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2017 Hiroaki Santo
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf
import params
from dataset import DpsnDataset
from model import DpsnModel
# Command-line flags consumed by main(). Only "gpu" carries a help string.
tf.app.flags.DEFINE_string("dataset_path", params.DATASET_PATH, "")  # root containing train/ and test/
tf.app.flags.DEFINE_string("output_path", "./model", "")  # forwarded to DpsnModel(output_path=...)
tf.app.flags.DEFINE_integer("batch_size", 1000, "")
tf.app.flags.DEFINE_integer("steps", 40000, "")
tf.app.flags.DEFINE_integer("gpu", 0, "gpu id")
FLAGS = tf.app.flags.FLAGS

# Session configuration used by main(): make only the GPU chosen by --gpu
# visible and let TensorFlow grow its GPU memory allocation on demand.
# NOTE(review): FLAGS.gpu is read here at import time, before tf.app.run()
# parses argv. This relies on the TF 1.x flags wrapper parsing sys.argv on
# first attribute access -- confirm against the TensorFlow version in use.
tf_config = tf.ConfigProto(
    gpu_options=tf.GPUOptions(
        allow_growth=True,
        visible_device_list=str(FLAGS.gpu),
    ),
)
def main(_):
    """Build the train/test datasets and train a DpsnModel on them.

    Datasets are loaded from the "train" and "test" subdirectories of
    --dataset_path. The model is created inside a tf.Session configured with
    the module-level tf_config, and trained for --steps steps with
    --batch_size samples per batch.

    Args:
      _: the argument list tf.app.run passes to main; unused here.
    """
    root = FLAGS.dataset_path
    train_set = DpsnDataset(dataset_path=os.path.join(root, "train"))
    test_set = DpsnDataset(dataset_path=os.path.join(root, "test"))

    with tf.Session(config=tf_config) as sess:
        # The model's input width comes from the training split's light count.
        DpsnModel(
            sess=sess,
            output_path=FLAGS.output_path,
            light_num=train_set.light_num,
        ).train(train_set, test_set, FLAGS.batch_size, FLAGS.steps)
if "__main__" == __name__:
    # Hand control to TensorFlow's runner, which parses the flags and calls main().
    tf.app.run(main=main)