ceteri opened this issue 3 years ago
Hello! I believe the argument --dataset-dir <path_to_data>
needs to be specified. You're right, though: I need to update this so that there is no longer a hard-coded default path that isn't actually included in the repo.
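For example, something like this once you point it at a local copy of the data (the placeholder path is just illustrative, and the other flags mirror the ones used later in this thread; double-check the exact names against the argparse setup in train.py):

python train.py --dataset-dir path/to/your/data --embedding-len 64 --batch-size 2048 --epochs 100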
I think you're one of the first people to try out this repo so please keep asking questions and I'll find some time to clean up the repo and add more documentation!
Ok I'm making some updates now. Some changes:
Awesome, thanks @ddehueck! I'll retry after those commits.
@ceteri Ok, should be good to go now! You should see embeddings like the following after around 100 epochs using the world order dataset:
Also, I believe there may be a way to speed up the loss computation, as I did here: https://github.com/ddehueck/CrossWalk/blob/master/domains/sgns_loss.py
If you want to compare against an ML framework without a JIT compiler, I have an SGNS implementation in PyTorch here: https://github.com/ddehueck/skip-gram-negative-sampling
Many thanks @ddehueck! The comparison between the JAX and PyTorch implementations is really helpful.
That ran fine, although the embeddings that I'm seeing reported are:
Learned embeddings:
word: nuclear neighbors: ['michael', 'oxford', 'm', 'ibid', 'p', 'trans', 'quoted', 'n', 'networks', 'data']
word: mankind neighbors: ['michael', 'oxford', 'm', 'ibid', 'p', 'trans', 'quoted', 'n', 'networks', 'data']
word: khomeini neighbors: ['michael', 'oxford', 'm', 'ibid', 'p', 'trans', 'quoted', 'n', 'networks', 'data']
word: ronald neighbors: ['michael', 'oxford', 'm', 'ibid', 'p', 'trans', 'quoted', 'n', 'networks', 'data']
Beginning epoch: 500/500
Looking at https://github.com/ddehueck/jax-skip-gram-negative-sampling/blob/0696e86ade1d326503355c5a98473932db681b9e/trainer.py#L50, should the set of nearest neighbors change for each instance of word through that iteration?
The line right above the one you linked to, for word in self.dataset.queries:, lets each dataset class define its own queries. I think what is happening is that you are using different data in the WorldOrderDataset class without changing self.dataset.queries, so those query words may not exist in your dataset and you may be looking at the token used for words that aren't in the dataset. Try changing the dataset queries.
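Roughly what I mean, as a minimal sketch (everything here except the queries attribute is hypothetical and just for illustration):

class MyCorpusDataset:
    def __init__(self, examples, vocab):
        self.examples = examples
        self.vocab = vocab
        # Query words used for the nearest-neighbor printout in trainer.py.
        # They need to actually occur in your corpus; otherwise each lookup
        # hits the not-in-dataset token and every word prints the same neighbors.
        self.queries = ['nuclear', 'treaty', 'war', 'europe']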
Thank you -
Checking through the code, I may be encountering a problem in Trainer.update(), since the params and g values are all NaN:
params [[nan nan nan ... nan nan nan]
[nan nan nan ... nan nan nan]
[nan nan nan ... nan nan nan]
...
[nan nan nan ... nan nan nan]
[nan nan nan ... nan nan nan]
[nan nan nan ... nan nan nan]]
EPOCH: 47 | GRAD MAGNITUDE: nan
This warning gets printed 4x through each epoch:
[W ParallelNative.cpp:206] Warning: Cannot set number of intraop threads after parallel work has started or after set_num_threads call when using native parallel backend (function set_num_threads)
This is with a simple command line:
python train.py --embedding-len 64 --batch-size 2048 --epochs 100
And the dataset is hard-coded to WorldOrderDataset?
Hmm, not sure how much help I can be without seeing your setup. How is your data pipeline set up?
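One thing that might help narrow down where the NaNs first appear: JAX has a debug flag that raises an error at the first op that produces a NaN. A quick sketch (the flag itself is standard JAX; where you put it is up to you, e.g. at the top of train.py):

from jax import config
config.update("jax_debug_nans", True)  # raise FloatingPointError at the first NaN-producing op
# ...then run training as usual, e.g.:
#   python train.py --embedding-len 64 --batch-size 2048 --epochs 100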
I seem to be facing the same issue. I just cloned the repo and ran python3 train.py, and it seems my gradient is NaN and I get the exact same embeddings that @ceteri is getting. Were either of you able to find a fix for this?
@arjun-mani Are you running the default sample that is in the codebase or have you added a custom dataset?
@ddehueck Default sample. Seems like the problem was with the loss - maybe you never fixed the bug on remote? I'm getting -inf values for the loss.
@arjun-mani Damn ok will take a closer look
@ddehueck I just fixed the loss locally and am getting embeddings much closer to what you posted. So pretty sure that's the problem :)
@arjun-mani Awesome! Would you mind creating a PR?
Not at all, I'll do it today.
@arjun-mani Appreciate it
@arjun-mani Any chance you can create that PR? Sorry to bother you.
Hey @ddehueck - I'm so so sorry for the delay, it's been a crazy few weeks. I've made a lot of changes to the codebase on my end, and this is a really small change, so hopefully it's helpful if I just attach the modified bce_loss function (in sgns_loss.py):
def bce_loss_w_logits(x, y):
    max_val = np.clip(x, 0, None)
    loss = max_val - x * y + np.log(1 + np.exp(-np.abs(x)))
    # loss = x - x * y + max_val + np.log(np.exp(-max_val) + np.exp((-x - max_val)))
    return loss.mean()
(The old line is left commented out.) Hope this helps; let me know if you'd like me to clarify anything.
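For context on why the old line misbehaves: I believe that for large-magnitude logits both exp() terms in it underflow to zero, the log returns -inf, and that then poisons the gradients. A quick standalone check of the failure mode versus the stable form (plain numpy here just so it runs on its own; jax.numpy behaves the same, and the naive version below is only for illustration, not the repo's old code):

import numpy as np

def naive_bce(x, y):
    # direct sigmoid + log formulation: log(0) appears once sigmoid saturates
    p = 1.0 / (1.0 + np.exp(-x))
    return -(y * np.log(p) + (1 - y) * np.log(1 - p)).mean()

def bce_loss_w_logits(x, y):
    # stable formulation from the patch above
    max_val = np.clip(x, 0, None)
    return (max_val - x * y + np.log(1 + np.exp(-np.abs(x)))).mean()

x = np.array([80.0, -80.0, 0.5])  # deliberately extreme logits
y = np.array([0.0, 1.0, 1.0])
print(naive_bce(x, y))            # inf
print(bce_loss_w_logits(x, y))    # finite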
Thank you @arjun-mani that works well.
This code looks really interesting. However, is there a runnable example with the given data?
When running the current code using the example shown, the hard-coded data/ location does not appear to be included within this repo, and the named datasets are different from what's in datasets/.
My apologies if I've missed something in the setup?