sherjilozair / char-rnn-tensorflow

Multi-layer Recurrent Neural Networks (LSTM, RNN) for character-level language models in Python using Tensorflow
MIT License

weighted_pick() can return invalid index #18

Open wohnjayne opened 8 years ago

wohnjayne commented 8 years ago

weighted_pick(weights) in model.py can return an index larger than len(chars)-1.

This happens if sum(weights)<1 and, at the same time, np.random.rand(1)>sum(weights).

Then int(np.searchsorted(t, np.random.rand(1)*s))==len(t), because the searched value is larger than every entry of t, and using that result as an index raises an IndexError.

import numpy as np
p = np.array([0.1, 0.2, 0.699], dtype=np.float32)
t = np.cumsum(p)  # cumulative sums; the last entry is about 0.999
s = np.sum(p)
randval = 0.9999  # a draw larger than t[-1]
print(int(np.searchsorted(t, randval)))  # gives 3, which is too large, as len(t)==3

So numpy.random.choice() is probably the better choice, despite being slower.
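
For illustration, here is a minimal sketch of two possible fixes. It is not the repository's code. The function names weighted_pick_choice and weighted_pick_clamped are hypothetical. The first renormalises the weights in float64 before passing them to numpy.random.choice(), since choice() rejects probabilities that do not sum to 1. The second keeps the searchsorted approach but clamps the result to the last valid index.

```python
import numpy as np

def weighted_pick_choice(weights):
    # Hypothetical replacement: renormalise in float64 so the
    # probabilities sum to 1, then let choice() draw the index.
    p = np.asarray(weights, dtype=np.float64)
    p = p / p.sum()
    return int(np.random.choice(len(p), p=p))

def weighted_pick_clamped(weights):
    # Hypothetical variant: keep the searchsorted lookup, but scale the
    # draw by the last cumulative sum and clamp to the last valid index.
    t = np.cumsum(weights)
    idx = int(np.searchsorted(t, np.random.rand() * t[-1]))
    return min(idx, len(t) - 1)

p = np.array([0.1, 0.2, 0.699], dtype=np.float32)
print(weighted_pick_choice(p), weighted_pick_clamped(p))
```

Both variants always return an index in range(len(weights)). The clamped version avoids the extra cost of choice(), at the price of silently assigning any out-of-range draw to the last character.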