ddehueck / jax-skip-gram-negative-sampling

A JAX implementation of word2vec's skip-gram model with negative sampling, as described in Mikolov et al., 2013.
MIT License

Is there a runnable example with the given data? #1

Open ceteri opened 3 years ago

ceteri commented 3 years ago

This code looks really interesting. However, is there a runnable example with the given data?

When running the current code using the example shown:

python train.py --embedding-len 64 --batch-size 2048 --epochs 500

The hard-coded data/ location does not appear to be included in this repo, and the dataset names it expects differ from what's in datasets/:

$ python train.py --embedding-len 64 --batch-size 2048 --epochs 500
Traceback (most recent call last):
  File "train.py", line 57, in <module>
    trainer = Trainer(args)
  File "/Users/paco/src/jax-skip-gram-negative-sampling/trainer.py", line 18, in __init__
    self.dataset = PyPILangDataset(args)#, examples_path='data/pypi_examples.pth', dict_path='data/pypi_dict.pth')
  File "/Users/paco/src/jax-skip-gram-negative-sampling/datasets/pypi_lang.py", line 18, in __init__
    self.files = self.tokenize_files()
  File "/Users/paco/src/jax-skip-gram-negative-sampling/datasets/pypi_lang.py", line 29, in tokenize_files
    node_lang_df = pd.read_csv(self.args.dataset_dir, na_filter=False)
  File "/opt/anaconda3/lib/python3.7/site-packages/pandas/io/parsers.py", line 688, in read_csv
    return _read(filepath_or_buffer, kwds)
  File "/opt/anaconda3/lib/python3.7/site-packages/pandas/io/parsers.py", line 454, in _read
    parser = TextFileReader(fp_or_buf, **kwds)
  File "/opt/anaconda3/lib/python3.7/site-packages/pandas/io/parsers.py", line 948, in __init__
    self._make_engine(self.engine)
  File "/opt/anaconda3/lib/python3.7/site-packages/pandas/io/parsers.py", line 1180, in _make_engine
    self._engine = CParserWrapper(self.f, **self.options)
  File "/opt/anaconda3/lib/python3.7/site-packages/pandas/io/parsers.py", line 2010, in __init__
    self._reader = parsers.TextReader(src, **kwds)
  File "pandas/_libs/parsers.pyx", line 382, in pandas._libs.parsers.TextReader.__cinit__
  File "pandas/_libs/parsers.pyx", line 674, in pandas._libs.parsers.TextReader._setup_parser_source
FileNotFoundError: [Errno 2] No such file or directory: 'data/'

My apologies, I may have missed something in the setup?

ddehueck commented 3 years ago

Hello! I believe the --dataset-dir <path_to_data> argument needs to be specified. But you're right, I need to update this so there's no longer a hard-coded default path that isn't included in the repo.
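
So the full command would look something like (with <path_to_data> replaced by wherever your data actually lives):

python train.py --dataset-dir <path_to_data> --embedding-len 64 --batch-size 2048 --epochs 500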

I think you're one of the first people to try out this repo, so please keep asking questions, and I'll find some time to clean things up and add more documentation!

ddehueck commented 3 years ago

OK, I'm making some updates now. Some changes:

  1. The dataset class needs to be loaded in trainer.py.
  2. There's a bug in the loss function; I had a fix for it in a local directory on my machine that never got committed.
  3. The data/ directory should be added. I'll commit the world order example as the default.

ceteri commented 3 years ago

Awesome, thanks @ddehueck! I'll retry after those commits.

ddehueck commented 3 years ago

@ceteri OK, it should be good to go now! You should see reasonable embeddings after around 100 epochs using the world order dataset.

ddehueck commented 3 years ago

Also, I believe there may be a way to speed up the loss computation, as I did here: https://github.com/ddehueck/CrossWalk/blob/master/domains/sgns_loss.py

If you want to compare to an ML framework without a JIT compiler, I have an SGNS implementation in PyTorch here: https://github.com/ddehueck/skip-gram-negative-sampling
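
Roughly, the speed-up comes from computing all of a batch's positive and negative dot products at once instead of looping. A minimal sketch of that pattern in JAX (hypothetical names and shapes, not the exact code in either repo):

import jax.numpy as jnp
from jax.nn import log_sigmoid

def sgns_batch_loss(center_vecs, context_vecs, neg_vecs):
    # center_vecs, context_vecs: (batch, dim); neg_vecs: (batch, k, dim)
    pos_logits = jnp.sum(center_vecs * context_vecs, axis=-1)      # (batch,)
    neg_logits = jnp.einsum('bd,bkd->bk', center_vecs, neg_vecs)   # (batch, k)
    # maximize log sigma(pos) and log sigma(-neg) by minimizing their negated sum
    return -(jnp.mean(log_sigmoid(pos_logits)) + jnp.mean(log_sigmoid(-neg_logits)))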

ceteri commented 3 years ago

Many thanks @ddehueck! The comparison between the JAX and PyTorch implementations is really helpful.

That ran fine, although the embeddings I'm seeing reported are:

Learned embeddings:
word: nuclear neighbors: ['michael', 'oxford', 'm', 'ibid', 'p', 'trans', 'quoted', 'n', 'networks', 'data']
word: mankind neighbors: ['michael', 'oxford', 'm', 'ibid', 'p', 'trans', 'quoted', 'n', 'networks', 'data']
word: khomeini neighbors: ['michael', 'oxford', 'm', 'ibid', 'p', 'trans', 'quoted', 'n', 'networks', 'data']
word: ronald neighbors: ['michael', 'oxford', 'm', 'ibid', 'p', 'trans', 'quoted', 'n', 'networks', 'data']
Beginning epoch: 500/500

Looking at https://github.com/ddehueck/jax-skip-gram-negative-sampling/blob/0696e86ade1d326503355c5a98473932db681b9e/trainer.py#L50, shouldn't the set of nearest neighbors change for each word through that iteration?
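
(For reference, I'd expect the neighbors to come from something like a cosine-similarity lookup over the embedding matrix, along the lines of this rough sketch; the names are mine, not the repo's:)

import jax.numpy as jnp

def nearest_neighbors(embeddings, word_idx, k=10):
    # embeddings: (vocab_size, dim); cosine similarity of one word against all words
    normed = embeddings / jnp.linalg.norm(embeddings, axis=-1, keepdims=True)
    sims = normed @ normed[word_idx]
    return jnp.argsort(-sims)[1:k + 1]  # drop the word itself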

ddehueck commented 3 years ago

The line right above the one you linked to is for word in self.dataset.queries:, which lets each dataset class define its own queries.

I think what is happening is that you are using different data in the WorldOrderDataset class without changing self.dataset.queries. So those words may not exist in your dataset, and you may be looking at a token that indicates 'not in dataset'.

So try changing the dataset queries.
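
Roughly like this, just to show where queries lives (a hypothetical sketch, not the actual dataset class):

class MyCustomDataset(WorldOrderDataset):
    def __init__(self, args):
        super().__init__(args)
        # use query words that actually appear in *your* corpus's vocabulary
        self.queries = ['some', 'words', 'from', 'your', 'data']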

ceteri commented 3 years ago

Thank you -

Checking through the code, I may be encountering a problem in Trainer.update(), since the params and g values are all NaN?

params [[nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 ...
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]]
EPOCH: 47 | GRAD MAGNITUDE: nan

This warning gets printed 4x through each epoch:

[W ParallelNative.cpp:206] Warning: Cannot set number of intraop threads after parallel work has started or after set_num_threads call when using native parallel backend (function set_num_threads)

This is with a simple command line:

python train.py --embedding-len 64 --batch-size 2048 --epochs 100

And the dataset is hard-coded to WorldOrderDataset?

ddehueck commented 3 years ago

Hmm, not sure how much help I can be without seeing your setup. How is your data pipeline set up?
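
One thing that might help narrow it down is checking whether the NaNs first show up in the loss or only in the gradients, roughly like this (a sketch with made-up names, not code from the repo):

import jax
import jax.numpy as jnp

# loss_fn, params, batch are placeholders for whatever your training step uses;
# this assumes params is a single array (as in the printout above), not a pytree
loss, grads = jax.value_and_grad(loss_fn)(params, batch)
print('loss finite:', bool(jnp.isfinite(loss)))
print('grads finite:', bool(jnp.all(jnp.isfinite(grads))))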

arjun-mani commented 3 years ago

I seem to be facing the same issue. I just cloned the repo and ran python3 train.py, and it seems my gradient is NaN and I get the exact same embeddings that @ceteri is getting. Were either of you able to find a fix for this?

ddehueck commented 3 years ago

@arjun-mani Are you running the default sample that is in the codebase or have you added a custom dataset?

arjun-mani commented 3 years ago

@ddehueck Default sample. Seems like the problem was with the loss; maybe the fix never made it to remote? I'm getting -inf values for the loss.
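
For context (not certain this is exactly the bug here), the classic way a BCE loss ends up at -inf is taking the log of a sigmoid that has underflowed to zero, e.g.:

import jax.numpy as jnp
from jax import nn

x = jnp.float32(-100.0)
print(jnp.log(nn.sigmoid(x)))  # -inf: sigmoid underflows to 0 in float32
print(nn.log_sigmoid(x))       # ~ -100.0: numerically stable equivalent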

ddehueck commented 3 years ago

@arjun-mani Damn ok will take a closer look

arjun-mani commented 3 years ago

@ddehueck I just fixed the loss locally and am getting embeddings much closer to what you posted. So pretty sure that's the problem :)

ddehueck commented 3 years ago

@arjun-mani Awesome! Would you mind creating a PR?

arjun-mani commented 3 years ago

Not at all, I'll do it today.

ddehueck commented 3 years ago

@arjun-mani Appreciate it

ddehueck commented 3 years ago

@arjun-mani Any chance you can create that PR? Sorry to bother you.

arjun-mani commented 3 years ago

Hey @ddehueck - I'm so sorry for the delay, it's been a crazy few weeks. I've made a lot of changes to my own copy of the codebase, and this is a really small change, so hopefully it's helpful if I just attach the modified bce_loss function (in sgns_loss.py):

def bce_loss_w_logits(x, y):
    # Numerically stable binary cross-entropy with logits:
    # max(x, 0) - x * y + log(1 + exp(-|x|))
    max_val = np.clip(x, 0, None)
    loss = max_val - x * y + np.log(1 + np.exp(-np.abs(x)))
    # loss = x - x * y + max_val + np.log(np.exp(-max_val) + np.exp((-x - max_val)))
    return loss.mean()

(The old line is left commented out.) Hope this helps; let me know if you'd like me to clarify anything.
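
If it's useful, a quick sanity check (assuming np here is jax.numpy, as imported in sgns_loss.py, and that the function above is in scope):

import jax.numpy as np

# large-magnitude logits that would blow up a naive log(sigmoid(x)) formulation
x = np.array([-100.0, 100.0, 0.0])
y = np.array([1.0, 0.0, 1.0])
print(bce_loss_w_logits(x, y))  # finite (about 66.9), no -inf or NaN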

ceteri commented 3 years ago

Thank you @arjun-mani, that works well.