-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
27 lines (22 loc) · 1.17 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import argparse
import sys

from rnn import main as rnn_main
from ffnn import main as ffnn_main
def _str_to_bool(value):
    """Convert a command-line token into a bool.

    ``type=bool`` must not be used with argparse: it calls ``bool()`` on the
    raw string, and every non-empty string (including ``"False"`` and ``"0"``)
    is truthy. This converter accepts common spellings explicitly and rejects
    anything else with a proper argparse usage error.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser(description='Model and Parameter Selection')
parser.add_argument('name', type=str, help='Name of the model for saving after training')
parser.add_argument('model', type=str, help='Model to run, either "RNN" or "FFNN"')
parser.add_argument('--embedding', type=int, default=64, help='Embedding dimension size')
parser.add_argument('--hidden', type=int, default=32, help='Hidden dimension size')
parser.add_argument('--layers', type=int, default=1, help='Number of hidden layers')
parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs')
# nargs='?' + const=True: a bare "--RNNcore" means True, while the existing
# "--RNNcore <value>" form keeps working and is now parsed correctly.
parser.add_argument('--RNNcore', type=_str_to_bool, nargs='?', const=True, default=False,
                    help='Whether to use RNN as core or LSTM')
# Parsed at import time: importing this module consumes sys.argv, and main()
# reads this module-level namespace.
args = parser.parse_args()
def main(cli_args=None):
    """Dispatch to the selected model's entry point.

    Args:
        cli_args: Parsed command-line namespace. When omitted, the
            module-level ``args`` (parsed at import time) is used; passing
            one explicitly lets callers drive ``main`` without ``sys.argv``.

    Returns:
        0 after invoking the RNN or FFNN entry point, 1 when ``model`` is
        neither ``"RNN"`` nor ``"FFNN"`` (matched case-sensitively).
    """
    opts = args if cli_args is None else cli_args
    if opts.model == 'RNN':
        # Passed positionally; TODO confirm this order matches rnn.main's signature.
        rnn_main(opts.name, opts.embedding, opts.hidden, opts.layers, opts.epochs, opts.RNNcore)
    elif opts.model == 'FFNN':
        # --embedding and --RNNcore are not forwarded to the FFNN entry point.
        ffnn_main(opts.name, hidden_dim=opts.hidden, number_of_epochs=opts.epochs, n_layers=opts.layers)
    else:
        # Report on stderr and signal failure so shells/CI can detect a bad invocation.
        print('Incompatible model declaration', file=sys.stderr)
        return 1
    return 0


if __name__ == '__main__':
    # Forward main()'s status as the process exit code.
    sys.exit(main())