forked from microsoft/CNTK
-
Notifications
You must be signed in to change notification settings - Fork 0
/
txt2ctf.py
149 lines (127 loc) · 6.73 KB
/
txt2ctf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#!/usr/bin/env python
# This script takes a list of dictionary files and a plain text utf-8 file and converts this text input file to CNTK text format.
#
# The input text file must contain N streams per line (N TAB-separated "columns") and should be accompanied by N dictionary files.
# The input text file must be in the following form:
# text1 TAB text2 TAB ... TAB textN
# .....
# where each line represents one sequence across all N input streams.
# Each text consists of one or more space-separated word tokens (samples).
#
# Dictionary files are text files that are required to be specified for all streams,
# so the #dictionaries = #columns in the input file.
# A dictionary contains a single token per line. The zero-based line number becomes the numeric index
# of the token in the output CNTK text format file.
# Example usage (e.g. for PennTreebank files):
# 1)
# sed -e 's/^<\/s> //' -e 's/ <\/s>$//' < en.txt > en.txt1
# sed -e 's/^<\/s> //' -e 's/ <\/s>$//' < fr.txt > fr.txt1
# paste en.txt1 fr.txt1 | txt2ctf.py --map en.dict fr.dict > en-fr.ctf
#
# 2) (assuming that the current dir is [cntk root]/Examples/SequenceToSequence/CMUDict/Data/)
# sed -e 's/<s\/>/<\/s>\t<s>/' < cmudict-0.7b.train-dev-1-21.txt `#this will replace every '<s/>' with '</s>[tab]<s>'` |\
# python ../../../../Scripts/txt2ctf.py --map cmudict-0.7b.mapping cmudict-0.7b.mapping > cmudict-0.7b.train-dev-1-21.ctf
#
import sys
import argparse
import re
def convert(dictionaryStreams, inputs, output, unk, annotated):
    """Convert TAB-separated text streams into CNTK text format.

    Args:
        dictionaryStreams: one iterable of lines per stream (column). Each
            line holds a single token; its zero-based line number becomes
            the token's numeric index.
        inputs: iterables of lines. Every line is one sequence and must
            contain exactly one TAB-separated column per dictionary.
        output: object with a ``write`` method that receives the result.
        unk: fallback token substituted for tokens missing from a
            dictionary, or None to fail on such tokens.
        annotated: when true, every index is followed by a ``|#`` comment
            carrying the token it was produced from.

    Raises:
        Exception: if a line's column count differs from the number of
            dictionaries, or if a token (after the optional fallback) is
            not present in the dictionary of its stream.
    """
    # Build the in-memory token -> index dictionaries.
    dictionaries = [
        {line.rstrip('\r\n').strip(): index for index, line in enumerate(dic)}
        for dic in dictionaryStreams
    ]

    # The rows of one sequence are marked by repeating its id, so the counter
    # must keep running across inputs: restarting at 0 for each input would
    # give the only line of a one-line input the same id as the first line
    # of the next input, on adjacent output rows.
    sequenceId = 0
    for lines in inputs:
        # ``index`` is the zero-based line number within the current input
        # and is only used for the error message.
        for index, line in enumerate(lines):
            line = line.rstrip('\r\n')
            columns = line.split("\t")
            if len(columns) != len(dictionaries):
                raise Exception("Number of dictionaries {0} does not correspond to the number of streams in line {1}:'{2}'"
                                .format(len(dictionaries), index, line))
            _convertSequence(dictionaries, columns, sequenceId, output, unk, annotated)
            sequenceId += 1


def _convertSequence(dictionaries, streams, sequenceId, output, unk, annotated):
    """Write one sequence (one input line) as CNTK text format rows.

    Each column in ``streams`` is split on spaces into tokens. One output row
    is written per sample position, up to the length of the longest column;
    columns that have run out of tokens contribute an empty TAB-separated
    field.

    Args:
        dictionaries: token -> index mappings, one per stream.
        streams: the column texts of the line, one per stream.
        sequenceId: id written at the start of every row of this sequence.
        output: object with a ``write`` method.
        unk: fallback token for unknown tokens, or None.
        annotated: whether to append the ``|#`` token comment.

    Raises:
        Exception: if a token (after the optional fallback) is not in the
            dictionary of its stream.
    """
    tokensPerStream = [[t for t in s.strip(' ').split(' ') if t != ""] for s in streams]
    if not tokensPerStream:
        # No streams at all: nothing to write (and max() would fail on an
        # empty sequence).
        return
    maxLen = max(len(tokens) for tokens in tokensPerStream)

    for sampleIndex in range(maxLen):
        output.write(str(sequenceId))
        for streamIndex, tokens in enumerate(tokensPerStream):
            if sampleIndex >= len(tokens):
                # This stream is shorter than the longest one: keep the
                # column position with an empty field.
                output.write("\t")
                continue
            dictionary = dictionaries[streamIndex]
            token = tokens[sampleIndex]
            if unk is not None and token not in dictionary:  # try unk symbol if specified
                token = unk
            if token not in dictionary:
                raise Exception("Token '{0}' cannot be found in the dictionary for stream {1}".format(token, streamIndex))
            output.write("\t|S" + str(streamIndex) + " " + str(dictionary[token]) + ":1")
            if annotated:
                # A '|' inside the comment is rewritten to '|#' unless it is
                # already followed by '#'.
                output.write(" |# " + re.sub(r'(\|(?!#))|(\|$)', r'|#', token))
        output.write("\n")
if __name__ == "__main__":
    # Script entry point: parse the command line, open the dictionary, input
    # and output files, and run the conversion.
    from contextlib import ExitStack

    parser = argparse.ArgumentParser(description="Transforms text file given dictionaries into CNTK text format.")
    parser.add_argument('--map', help='List of dictionaries, given in the same order as streams in the input files',
                        nargs="+", required=True)
    parser.add_argument('--annotated', help='Whether to annotate indices with tokens. Default is false',
                        choices=["True", "False"], default="False", required=False)
    parser.add_argument('--output', help='Name of the output file, stdout if not given', default="", required=False)
    parser.add_argument('--input', help='Name of the inputs files, stdin if not given', default="", nargs="*", required=False)
    parser.add_argument('--unk', help='Name fallback symbol for tokens not in dictionary (same for all columns)', default=None, required=False)
    args = parser.parse_args()

    # Every file opened here is registered on the stack so that it is closed
    # (and the output flushed) even if the conversion raises.
    with ExitStack() as stack:
        dictionaryStreams = [stack.enter_context(open(d, encoding="utf-8")) for d in args.map]

        # creating inputs
        inputs = [sys.stdin]
        if len(args.input) != 0:
            inputs = [stack.enter_context(open(i, encoding="utf-8")) for i in args.input]

        # creating output; opened last so that a missing dictionary or input
        # file does not truncate an existing output file. Written as UTF-8
        # to match how the inputs are read: annotated output echoes tokens.
        output = sys.stdout
        if args.output != "":
            output = stack.enter_context(open(args.output, "w", encoding="utf-8"))

        convert(dictionaryStreams, inputs, output, args.unk, args.annotated == "True")
#####################################################################################################
# Tests
#####################################################################################################
try:
import StringIO
stringio = StringIO.StringIO
except ImportError:
from io import StringIO
stringio = StringIO
try:
import pytest
except ImportError:
pass
def test_simpleSanityCheck():
    """Two streams without annotations map every token to its dictionary line index.

    The second input line has more tokens in its second column, so the last
    expected row carries an empty field for the first stream.
    """
    firstDictionary = stringio("hello\nmy\nworld\nof\nnothing\n")
    secondDictionary = stringio("let\nme\nbe\nclear\nabout\nit\n")
    source = stringio("hello my\tclear about\nworld of\tit let clear\n")
    produced = stringio()

    convert([firstDictionary, secondDictionary], [source], produced, None, False)

    expectedRows = [
        "0\t|S0 0:1\t|S1 3:1\n",
        "0\t|S0 1:1\t|S1 4:1\n",
        "1\t|S0 2:1\t|S1 5:1\n",
        "1\t|S0 3:1\t|S1 0:1\n",
        "1\t\t|S1 3:1\n",
    ]
    assert "".join(expectedRows) == produced.getvalue()
def test_thatPipeSymbolIsEscaped():
    """Annotated output rewrites '|' in tokens to '|#' unless it already precedes '#'."""
    dictionary1 = stringio("|hello\nm|y\nworl|d\nof\nnothing|\n")
    dictionary2 = stringio("let|\nm|e\nb|#e\nclear\n||about\ni||#t\n")
    source = stringio("|hello m|y\tclear ||about\nworl|d of\ti||#t let| clear\n")
    output = stringio()

    convert([dictionary1, dictionary2], [source], output, None, True)

    expectedOutput = stringio()
    expectedOutput.write("0\t|S0 0:1 |# |#hello\t|S1 3:1 |# clear\n")
    expectedOutput.write("0\t|S0 1:1 |# m|#y\t|S1 4:1 |# |#|#about\n")
    expectedOutput.write("1\t|S0 2:1 |# worl|#d\t|S1 5:1 |# i|#|#t\n")
    expectedOutput.write("1\t|S0 3:1 |# of\t|S1 0:1 |# let|#\n")
    expectedOutput.write("1\t\t|S1 3:1 |# clear\n")

    actualLines = output.getvalue().split('\n')
    expectedLines = expectedOutput.getvalue().split('\n')
    # zip() stops at the shorter list, so missing or extra rows would go
    # unnoticed without an explicit length check.
    assert len(actualLines) == len(expectedLines)
    # Compare line by line so that a failure points at the offending row.
    for actual, expected in zip(actualLines, expectedLines):
        assert actual == expected
def test_nonExistingWord():
    """A token absent from the dictionary makes convert raise, naming the token and stream."""
    vocabulary = stringio("hello\nmy\nworld\nof\nnothing\n")
    source = stringio("hello my\nworld of nonexistent\n")
    sink = stringio()
    expectedMessage = "Token 'nonexistent' cannot be found in the dictionary for stream 0"

    with pytest.raises(Exception) as info:
        convert([vocabulary], [source], sink, None, False)

    assert str(info.value) == expectedMessage