forked from facebookresearch/SentEval
-
Notifications
You must be signed in to change notification settings - Fork 1
/
tfhub.py
76 lines (59 loc) · 2.1 KB
/
tfhub.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from __future__ import absolute_import, division, unicode_literals
import sys
import os
import logging
import tensorflow as tf
import tensorflow_hub as hub
import json
# Set TensorFlow's own logging verbosity to INFO.
tf.logging.set_verbosity(tf.logging.INFO)
# Paths: PATH_SENTEVAL is prepended to sys.path before senteval is imported,
# PATH_TO_DATA is handed to SentEval as its 'task_path'.
PATH_SENTEVAL = ''
PATH_TO_DATA = './data'

# TF-Hub module URLs; the main script evaluates each of them in turn.
MODULES = [
    'https://tfhub.dev/google/universal-sentence-encoder-large/3',
    'https://tfhub.dev/google/Wiki-words-500/1',
    'https://tfhub.dev/google/nnlm-en-dim128/1',
    'https://tfhub.dev/google/elmo/2',
]
# Import senteval.
# PATH_SENTEVAL is put at the front of sys.path first, so the import below
# searches that location before any other entry. Keep this order.
sys.path.insert(0, PATH_SENTEVAL)
import senteval
def prepare(params, samples):
    """No-op preparation callback handed to ``senteval.engine.SE``.

    Neither ``params`` nor ``samples`` is read or modified; the function
    always returns ``None``.
    """
    return None
def batcher(params, batch):
    """Embed a batch of tokenized sentences with the configured module.

    Args:
        params: Mapping-style parameters; ``params['module']`` must be a
            callable that accepts a list of strings.
        batch: Iterable of sentences, each one a sequence of string tokens.

    Returns:
        Whatever ``params['module']`` returns for the list of joined
        sentences (one string per input sentence, tokens separated by a
        single space).
    """
    # An empty sentence is replaced by '.' so the module never receives an
    # empty string. Truthiness is used instead of ``!= []`` so that any empty
    # sequence (e.g. an empty tuple) gets the placeholder, not only a list.
    sentences = [' '.join(tokens) if tokens else '.' for tokens in batch]
    return params['module'](sentences)
def make_embed_fn(module):
    """Build a callable that runs a TF-Hub module on a batch of sentences.

    A fresh ``tf.Graph`` is created containing a string placeholder, the
    ``hub.Module`` built from ``module`` and the tensor obtained by applying
    that module to the placeholder. A ``tf.train.MonitoredSession`` is opened
    while that graph is the default one.

    Args:
        module: Module handle passed unchanged to ``hub.Module``.

    Returns:
        A one-argument function; its argument is fed to the placeholder and
        the result of ``session.run`` on the embedding tensor is returned.
    """
    graph = tf.Graph()
    with graph.as_default():
        text_input = tf.placeholder(tf.string)
        encoder = hub.Module(module)
        encoded = encoder(text_input)
        session = tf.train.MonitoredSession()

    # NOTE(review): the session is captured by the closure and nothing in
    # this function ever closes it.
    def embed_fn(x):
        return session.run(encoded, {text_input: x})

    return embed_fn
"""
Evaluation of trained model on Transfer Tasks (SentEval)
"""
# define senteval params
params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5}
params_senteval['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128,
'tenacity': 3, 'epoch_size': 2}
# Set up logger
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.INFO)
if __name__ == "__main__":
    # Tasks evaluated for every module; constant, so built once up front.
    transfer_tasks = ['STS15', 'STS16', 'MR', 'CR', 'SUBJ']
    banner = "*-----------------------------------------------------------------*"

    # Evaluation results of every module, keyed by the module URL.
    total_results = dict()
    for mdl in MODULES:
        print(banner)
        print("Evaluating module: " + mdl)
        print(banner)

        # Convert once; used both as the module handle and as the result key.
        module_name = tf.compat.as_str(mdl)

        # batcher() fetches the embedding function from params['module'].
        params_senteval['module'] = make_embed_fn(module_name)
        se = senteval.engine.SE(params_senteval, batcher, prepare)
        results = se.eval(transfer_tasks)
        total_results[module_name] = results
        print(results)

        # Rewrite the output file after every module so that results already
        # computed survive a failure while evaluating a later module. After
        # the last module the file holds the results of all modules.
        with open("/tmp/output.json", "w") as f:
            json.dump(total_results, f)