-
Notifications
You must be signed in to change notification settings - Fork 0
/
norm_update_manager.py
146 lines (127 loc) · 5.33 KB
/
norm_update_manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import os
import re
import shutil
from typing import Dict, Generator, List, Set, Tuple
import tempfile
import pynini
import norm_rule_finetuning
from norm_rule_finetuning import load_whitelist_file
import norm_config
from pydantic import BaseModel
def create_updated_case_sensitive_whitelist(
    words_to_update: Dict[str, str],
    whitelist_file: str,
    out_file: str,
) -> None:
    """Write a copy of a case-sensitive whitelist with entries added or replaced.

    Every ``(key, replacement)`` pair yielded by ``load_whitelist_file`` for
    ``whitelist_file`` is copied to ``out_file`` in its original order. When a
    key is present in ``words_to_update`` (exact, case-sensitive comparison),
    the replacement from ``words_to_update`` is written instead. Keys of
    ``words_to_update`` that never appeared in the original whitelist are
    appended at the end.

    Args:
        words_to_update: Mapping of whitelist key -> new replacement text.
        whitelist_file: Path of the existing whitelist to read.
        out_file: Path to write the updated whitelist to (overwritten). It is
            written as UTF-8, one ``key<TAB>replacement`` pair per line.
    """
    updated_keys = set()
    # ``with`` guarantees the output file is closed even if reading the
    # source whitelist or writing fails part-way through.
    with open(out_file, "w", encoding="UTF-8") as fw:
        for word_ori, word_repl in load_whitelist_file(whitelist_file):
            if word_ori in words_to_update:
                word_repl = words_to_update[word_ori]
                updated_keys.add(word_ori)
            fw.write(f"{word_ori}\t{word_repl}\n")
        # Append new words that did not already exist in the whitelist.
        for w_ori, w_repl in words_to_update.items():
            if w_ori not in updated_keys:
                fw.write(f"{w_ori}\t{w_repl}\n")
def create_updated_case_insensitive_whitelist(
    words_to_update: Dict[str, str],
    whitelist_file: str,
    out_file: str,
) -> None:
    """Write a copy of a case-insensitive whitelist with entries added or replaced.

    Keys are compared after ``str.lower()``. Each ``(key, replacement)`` pair
    from ``load_whitelist_file(whitelist_file)`` is written to ``out_file``
    with its key lower-cased. If the lower-cased key is in
    ``words_to_update``, the new replacement is used instead. Entries of
    ``words_to_update`` not found in the original whitelist are appended at
    the end, also with lower-cased keys.

    Args:
        words_to_update: Mapping of whitelist key -> new replacement text. If
            two keys differ only by case, the later one wins.
        whitelist_file: Path of the existing whitelist to read.
        out_file: Path to write the updated whitelist to (overwritten). It is
            written as UTF-8, one ``key<TAB>replacement`` pair per line.
    """
    # Normalize the update keys so they match the lower-cased file keys.
    word_dict = {w_ori.lower(): w_repl for w_ori, w_repl in words_to_update.items()}
    updated_keys = set()
    # ``with`` guarantees the output file is closed even on error.
    with open(out_file, "w", encoding="UTF-8") as fw:
        for word_ori, word_repl in load_whitelist_file(whitelist_file):
            key = word_ori.lower()
            if key in word_dict:
                word_repl = word_dict[key]
                updated_keys.add(key)
            # NOTE(review): every existing key is rewritten lower-cased here,
            # while create_words_removed_whitelist keeps the original casing;
            # confirm this normalization is intended.
            fw.write(f"{key}\t{word_repl}\n")
        # Append new words that did not already exist in the whitelist.
        for w_ori, w_repl in word_dict.items():
            if w_ori not in updated_keys:
                fw.write(f"{w_ori}\t{w_repl}\n")
def create_words_removed_whitelist(
    words_tobe_removed: Set[str],
    whitelist_file: str,
    case_sensitive: bool,
    out_file: str,
) -> None:
    """Write a copy of a whitelist with the given keys removed.

    Each ``(key, replacement)`` pair from ``load_whitelist_file(whitelist_file)``
    is written unchanged to ``out_file``, unless its key is in
    ``words_tobe_removed``. When ``case_sensitive`` is False, both sides are
    compared after ``str.lower()``. Kept entries retain their original casing.

    Args:
        words_tobe_removed: Keys to drop from the whitelist. Any iterable of
            strings is accepted.
        whitelist_file: Path of the existing whitelist to read.
        case_sensitive: Whether key comparison is exact or lower-cased.
        out_file: Path to write the filtered whitelist to (overwritten). It is
            written as UTF-8, one ``key<TAB>replacement`` pair per line.
    """
    # Build a real set either way so membership tests stay O(1) even if the
    # caller passes a list or other iterable.
    if case_sensitive:
        word_set = set(words_tobe_removed)
    else:
        word_set = {w.lower() for w in words_tobe_removed}
    # ``with`` guarantees the output file is closed even on error.
    with open(out_file, "w", encoding="UTF-8") as fw:
        for word_ori, word_repl in load_whitelist_file(whitelist_file):
            word_check = word_ori if case_sensitive else word_ori.lower()
            if word_check not in word_set:
                fw.write(f"{word_ori}\t{word_repl}\n")
def handle_update_command(
    words_to_update: Dict[str, str], case_sensitive: bool
) -> Tuple[str, str]:
    """Build a temporary whitelist and FAR model that include updated words.

    The configured whitelist (case-sensitive or case-insensitive, depending on
    ``case_sensitive``) is copied to a temp file with ``words_to_update``
    applied. That updated temp whitelist is then compiled together with
    ``norm_config.NORM_MODEL_ORIGIN`` into a temporary FAR file via
    ``norm_rule_finetuning.apply_whitelist_files``. Neither the configured
    whitelist nor the running model is modified here.

    Args:
        words_to_update: Mapping of whitelist key -> new replacement text.
        case_sensitive: Selects which whitelist is updated.

    Returns:
        ``(tmp_whitelist_file, tmp_far_file)``: paths of the generated files.
        They can be installed with ``apply_whitelist`` and
        ``apply_updated_norm_model``.
    """
    kind = "case_sensitive" if case_sensitive else "case_insensitive"
    # ``tempfile.tempdir`` is None until gettempdir() has been called once, so
    # always resolve the temp directory through gettempdir().
    tmp_dir = tempfile.gettempdir()
    tmp_whitelist_file = os.path.join(tmp_dir, f"tts-norm-whitelist-{kind}.tmp")

    if case_sensitive:
        create_updated_case_sensitive_whitelist(
            words_to_update,
            norm_config.CASE_SENSITIVE_WHITELIST_FILE,
            out_file=tmp_whitelist_file,
        )
    else:
        create_updated_case_insensitive_whitelist(
            words_to_update,
            norm_config.CASE_INSENSITIVE_WHITELIST_FILE,
            out_file=tmp_whitelist_file,
        )

    # Compile the *updated* temp whitelist into the model. Passing the
    # original config whitelist here would silently drop the update. The
    # other whitelist argument stays None, as before.
    case_sensitive_whitelist_file = tmp_whitelist_file if case_sensitive else None
    case_insensitive_whitelist_file = None if case_sensitive else tmp_whitelist_file

    ori_far_file = norm_config.NORM_MODEL_ORIGIN
    ori_far_filename = os.path.basename(ori_far_file)
    tmp_far_file = os.path.join(tmp_dir, f"tts-norm-{ori_far_filename}.tmp")
    norm_rule_finetuning.apply_whitelist_files(
        ori_far_file,
        case_sensitive_whitelist_file=case_sensitive_whitelist_file,
        case_insensitive_whitelist_file=case_insensitive_whitelist_file,
        out_far_file=tmp_far_file,
    )
    return tmp_whitelist_file, tmp_far_file
def handle_remove_command(
    words_tobe_removed: Set[str], case_sensitive: bool
) -> Tuple[str, str]:
    """Build a temporary whitelist and FAR model with some words removed.

    The configured whitelist (case-sensitive or case-insensitive, depending on
    ``case_sensitive``) is copied to a temp file without the keys in
    ``words_tobe_removed``. That filtered temp whitelist is then compiled
    together with ``norm_config.NORM_MODEL_ORIGIN`` into a temporary FAR file
    via ``norm_rule_finetuning.apply_whitelist_files``. Neither the configured
    whitelist nor the running model is modified here.

    Args:
        words_tobe_removed: Whitelist keys to drop.
        case_sensitive: Selects which whitelist is filtered and whether key
            comparison is exact or lower-cased.

    Returns:
        ``(tmp_whitelist_file, tmp_far_file)``: paths of the generated files.
        They can be installed with ``apply_whitelist`` and
        ``apply_updated_norm_model``.
    """
    kind = "case_sensitive" if case_sensitive else "case_insensitive"
    # ``tempfile.tempdir`` is None until gettempdir() has been called once, so
    # always resolve the temp directory through gettempdir().
    tmp_dir = tempfile.gettempdir()
    tmp_whitelist_file = os.path.join(tmp_dir, f"tts-norm-whitelist-{kind}.tmp")

    whitelist_file = (
        norm_config.CASE_SENSITIVE_WHITELIST_FILE
        if case_sensitive
        else norm_config.CASE_INSENSITIVE_WHITELIST_FILE
    )
    create_words_removed_whitelist(
        words_tobe_removed, whitelist_file, case_sensitive, tmp_whitelist_file
    )

    # Compile the *filtered* temp whitelist into the model. Passing the
    # original config whitelist here would keep the removed words. The other
    # whitelist argument stays None, as before.
    case_sensitive_whitelist_file = tmp_whitelist_file if case_sensitive else None
    case_insensitive_whitelist_file = None if case_sensitive else tmp_whitelist_file

    ori_far_file = norm_config.NORM_MODEL_ORIGIN
    ori_far_filename = os.path.basename(ori_far_file)
    tmp_far_file = os.path.join(tmp_dir, f"tts-norm-{ori_far_filename}.tmp")
    norm_rule_finetuning.apply_whitelist_files(
        ori_far_file,
        case_sensitive_whitelist_file=case_sensitive_whitelist_file,
        case_insensitive_whitelist_file=case_insensitive_whitelist_file,
        out_far_file=tmp_far_file,
    )
    return tmp_whitelist_file, tmp_far_file
def apply_updated_norm_model(new_norm_model_file: str) -> None:
    """Install a newly built normalization model as the running model.

    ``new_norm_model_file`` is moved with ``shutil.move`` to the path in
    ``norm_config.NORM_MODEL_RUNNING``, so the source path no longer exists
    afterwards.
    """
    running_model_path = norm_config.NORM_MODEL_RUNNING
    shutil.move(new_norm_model_file, running_model_path)
def apply_whitelist(new_whitelist_file: str, case_sensitive: bool) -> None:
    """Replace the configured whitelist file with ``new_whitelist_file``.

    Args:
        new_whitelist_file: Path of the whitelist to install. It is moved,
            not copied.
        case_sensitive: If True, the target is
            ``norm_config.CASE_SENSITIVE_WHITELIST_FILE``; otherwise it is
            ``norm_config.CASE_INSENSITIVE_WHITELIST_FILE``.
    """
    target_whitelist_file = (
        norm_config.CASE_SENSITIVE_WHITELIST_FILE
        if case_sensitive
        else norm_config.CASE_INSENSITIVE_WHITELIST_FILE
    )
    shutil.move(new_whitelist_file, target_whitelist_file)