forked from unitaryai/detoxify
-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_prediction.py
101 lines (86 loc) · 2.89 KB
/
run_prediction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import argparse
import os
import pandas as pd
from detoxify import Detoxify
def load_input_text(input_obj):
    """Return the text to classify from a literal string or a ``.txt`` file path.

    If ``input_obj`` is a string that names an existing file, the file must
    have a ``.txt`` extension; it is read and returned as a list containing
    one string per line (line terminators removed). Any other string is
    treated as the text itself and returned unchanged.

    Args:
        input_obj: a text string, or the path to an existing ``.txt`` file.

    Returns:
        The input string itself, or a list of the lines of the file.

    Raises:
        ValueError: if ``input_obj`` is not a string, or if it names an
            existing file whose name does not end in ``.txt``.
    """
    if not isinstance(input_obj, str):
        raise ValueError("Invalid input type: input type must be a string or a txt file.")
    if not os.path.isfile(input_obj):
        # Not a path to an existing file: the string is the text to classify.
        return input_obj
    if not input_obj.endswith(".txt"):
        raise ValueError("Invalid file type: only txt files supported.")
    # The context manager guarantees the handle is closed even if reading
    # fails; the explicit encoding keeps non-ASCII text from depending on the
    # platform's default locale encoding.
    with open(input_obj, encoding="utf-8") as text_file:
        return text_file.read().splitlines()
def run(model_name, input_obj, dest_file, from_ckpt, device="cpu"):
    """Run a Detoxify model over ``input_obj`` and report the predictions.

    The input is resolved with ``load_input_text``. The model is built from
    ``model_name`` when that is not None, otherwise from the checkpoint given
    by ``from_ckpt``. The predictions are printed as a pandas DataFrame
    (rounded to 5 decimals, indexed by the input text) and, when ``dest_file``
    is not None, written to that path in CSV format with the index column
    named ``input_text``.

    Returns the value returned by ``model.predict`` (not the DataFrame).
    """
    text = load_input_text(input_obj)

    # A model name, when given, takes precedence over the checkpoint.
    if model_name is None:
        model = Detoxify(checkpoint=from_ckpt, device=device)
    else:
        model = Detoxify(model_name, device=device)

    res = model.predict(text)

    # A single string becomes a one-row frame; a list yields one row per item.
    if isinstance(text, str):
        row_labels = [text]
    else:
        row_labels = text
    res_df = pd.DataFrame(res, index=row_labels).round(5)
    print(res_df)

    if dest_file is None:
        return res

    res_df.index.name = "input_text"
    res_df.to_csv(dest_file)
    return res
if __name__ == "__main__":
# args
parser = argparse.ArgumentParser()
parser.add_argument(
"--input",
type=str,
help="text, list of strings, or txt file",
)
parser.add_argument(
"--model_name",
default="unbiased",
type=str,
help="Name of the torch.hub model (default: unbiased)",
)
parser.add_argument(
"--device",
default="cpu",
type=str,
help="device to load the model on",
)
parser.add_argument(
"--from_ckpt_path",
default=None,
type=str,
help="Option to load from the checkpoint path (default: False)",
)
parser.add_argument(
"--save_to",
default=None,
type=str,
help="destination path to output model results to (default: None)",
)
args = parser.parse_args()
assert args.from_ckpt_path is not None or args.model_name is not None
if args.model_name is not None:
assert args.model_name in [
"original",
"unbiased",
"multilingual",
]
if args.from_ckpt_path is not None and args.model_name is not None:
raise ValueError(
"Please specify only one model source, can either load model from checkpoint path or from model_name."
)
if args.from_ckpt_path is not None:
assert os.path.isfile(args.from_ckpt_path)
run(
args.model_name,
args.input,
args.save_to,
args.from_ckpt_path,
device=args.device,
)