Skip to content

Commit

Permalink
update the rnn example
Browse files Browse the repository at this point in the history
  • Loading branch information
nudles committed Apr 14, 2020
1 parent 536f7e4 commit 0748d78
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 112 deletions.
107 changes: 0 additions & 107 deletions examples/rnn/sample.py

This file was deleted.

10 changes: 5 additions & 5 deletions examples/rnn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(self, fpath, batch_size=32, seq_length=100, train_ratio=0.8):
data = [self.char_to_idx[c] for c in self.raw_data]
# seq_length + 1 for the data + label
nsamples = len(data) // (1 + seq_length)
data = data[0:300 * (1 + seq_length)]
data = data[0: nsamples * (1 + seq_length)]
data = np.asarray(data, dtype=np.int32)
data = np.reshape(data, (-1, seq_length + 1))
# shuffle all sequences
Expand Down Expand Up @@ -172,13 +172,13 @@ def sample(model, data, dev, nsamples=100, use_max=False):
y = tensor.softmax(outputs[-1])


def evaluate(model, data, batch_size, seq_length, dev):
def evaluate(model, data, batch_size, seq_length, dev, inputs, labels):
model.eval()
val_loss = 0.0
for b in range(data.num_test_batch):
batch = data.val_dat[b * batch_size:(b + 1) * batch_size]
inputs, labels = convert(batch, batch_size, seq_length, data.vocab_size,
dev)
dev, inputs, labels)
model.reset_states(dev)
y = model(inputs)
loss = model.loss(y, labels)[0]
Expand Down Expand Up @@ -217,8 +217,8 @@ def train(data,
print('\nEpoch %d, train loss is %f' %
(epoch, train_loss / data.num_train_batch / seq_length))

# evaluate(model, data, batch_size, seq_length, cuda, inputs, labels)
# sample(model, data, cuda)
evaluate(model, data, batch_size, seq_length, cuda, inputs, labels)
sample(model, data, cuda)


if __name__ == '__main__':
Expand Down

0 comments on commit 0748d78

Please sign in to comment.