diff --git a/textsum_data_convert.py b/textsum_data_convert.py index 732d3d8..73f0397 100644 --- a/textsum_data_convert.py +++ b/textsum_data_convert.py @@ -16,6 +16,10 @@ import tensorflow as tf from tensorflow.core.example import example_pb2 +from numpy.random import seed as random_seed +from numpy.random import shuffle as random_shuffle + +random_seed(123) # Reproducibility FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_string('command', 'text_to_binary', @@ -28,6 +32,8 @@ def _text_to_binary(input_directories, output_filenames, split_fractions): filenames = _get_filenames(input_directories) + random_shuffle(filenames) + start_from_index = 0 for index, output_filename in enumerate(output_filenames): sample_count = int(len(filenames) * split_fractions[index])