piskvorky / gensim

Topic Modelling for Humans
https://radimrehurek.com/gensim
GNU Lesser General Public License v2.1

Some Doc2Vec vectors remain untrained, with either giant #s of docvecs or small contrived corpuses #2679

Open gojomo opened 5 years ago

gojomo commented 5 years ago

Investigating the user report at https://groups.google.com/d/msg/gensim/XbH5Sr6RBcI/w5-AIwpSAwAJ, I ran a much smaller version of the synthetic corpus from there and reproduced similarly inexplicable results: doc-vectors that should have received at least some training adjustment showed no change after an epoch of training.

To demonstrate:

import logging
# basicConfig attaches a handler, so INFO records (including gensim's) are actually printed
logging.basicConfig(level=logging.INFO)
from gensim.models.doc2vec import TaggedDocument
from gensim.models import Doc2Vec
import numpy as np

class DummyTaggedDocuments(object):
    """Synthetic corpus: doc i may contain 'shared' (in every doc), 'doc_i' (unique to it), and/or its digits."""

    def __init__(self, count=1001, shared_word=True, doc_word=True, digit_words=False):
        self.count = count
        self.shared_word = shared_word
        self.doc_word = doc_word
        self.digit_words = digit_words

    def __iter__(self):
        for i in range(self.count):
            words = []
            if self.shared_word:
                words += ['shared']          # one word common to all documents
            if self.doc_word:
                words += ['doc_' + str(i)]   # one word unique to this document
            if self.digit_words:
                words += list(str(i))        # each digit of i as its own word
            tags = [i]                       # plain int tag: doc i trains doc-vector row i
            if i == self.count - 1:
                logging.info("yielding last DummyTaggedDocument %i", i)
            yield TaggedDocument(words=words, tags=tags)

def test_d2v_docvecs_trained(doc_args=None, d2v_args=None):
    """Demo bug in Doc2Vec (& more?) as of gensim 3.8.1 with sample>0"""
    docs = DummyTaggedDocuments(**(doc_args or {}))
    d2v_model = Doc2Vec(**(d2v_args or {}))

    d2v_model.build_vocab(docs, progress_per=100000)
    # snapshot the freshly initialized doc-vectors (gensim 3.x API: model.docvecs.vectors_docs)
    starting_vecs = d2v_model.docvecs.vectors_docs.copy()

    d2v_model.train(docs, total_examples=d2v_model.corpus_count, epochs=d2v_model.epochs, report_delay=5)

    # a row counts as 'unchanged' only if every dimension is identical to its initial value
    unchanged = np.all(starting_vecs == d2v_model.docvecs.vectors_docs, axis=1)
    unchanged_indexes = np.argwhere(unchanged)

    return (len(unchanged_indexes), list(unchanged_indexes))

test_d2v_docvecs_trained(d2v_args=dict(min_count=0, sample=0.01, vector_size=4, epochs=1, workers=1))

For the given parameters, this test method should return a count of 0 unchanged vectors and an empty list of unchanged vector indexes. But in a typical run I'm getting instead:

(12,
 [array([0]),
  array([1]),
  array([4]),
  array([6]),
  array([9]),
  array([11]),
  array([13]),
  array([14]),
  array([15]),
  array([30]),
  array([33]),
  array([80])])

Though this is an artificial dataset of peculiar 2-word documents, which will often shrink to 1-word documents once frequent-word downsampling drops the term shared by all documents, every document should still be at least 1 word long, and thus get at least some training in any single epoch. The logging output regarding sampling accurately reports what the effects should be (at this sampling level):

INFO:gensim.models.doc2vec:collected 1002 word types and 1001 unique tags from a corpus of 1001 examples and 2002 words
INFO:gensim.models.word2vec:Loading a fresh vocabulary
INFO:gensim.models.word2vec:effective_min_count=0 retains 1002 unique words (100% of original 1002, drops 0)
INFO:gensim.models.word2vec:effective_min_count=0 leaves 2002 word corpus (100% of original 2002, drops 0)
INFO:gensim.models.word2vec:deleting the raw counts dictionary of 1002 items
INFO:gensim.models.word2vec:sample=0.01 downsamples 1 most-common words
INFO:gensim.models.word2vec:downsampling leaves estimated 1162 word corpus (58.1% of prior 2002)

I get similar evidence of unchanged vectors in dm=0 (PV-DBOW) mode. However, running more epochs usually drives the number of unchanged vectors to 0.

Turning off downsampling with sample=0 makes all vectors in this test show some update, implying that the downsampling is somehow involved, and strongly suggesting something is going wrong in the code at or around:

https://github.com/RaRe-Technologies/gensim/blob/3d6596112f8f1fc0e839a32c5a00ef3d7365c264/gensim/models/doc2vec_inner.pyx#L344

But a manual check of the precalculated sample_int values for all but the most-frequent word suggests they're where they should be: a value the random int never exceeds, so those words should never be downsampled.
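
To make the expected sample_int values concrete, here is a small sketch of a per-word keep threshold scaled to 2**32, modelled on my reading of gensim 3.x's vocabulary scaling (the formula and the skip rule "skip when sample_int < random draw" are assumptions, not quoted code; the function names are hypothetical). It assumes 0 < sample < 1.

```python
from math import sqrt


def sample_int(count, total_words, sample):
    """Hypothetical keep threshold for a word, scaled to the 32-bit draw range (assumes 0 < sample < 1)."""
    threshold_count = sample * total_words
    keep_probability = (sqrt(count / threshold_count) + 1) * (threshold_count / count)
    keep_probability = min(keep_probability, 1.0)  # rare words are always kept
    return int(round(keep_probability * 2**32))


def is_skipped(word_sample_int, random32):
    """Assumed skip rule: a word is dropped when its threshold is below the 32-bit random draw."""
    return word_sample_int < random32


TOTAL_WORDS, SAMPLE = 2002, 0.01  # corpus size and sample value from the test above
shared = sample_int(1001, TOTAL_WORDS, SAMPLE)  # 'shared', present in every document
unique = sample_int(1, TOTAL_WORDS, SAMPLE)     # a 'doc_N' word, seen once

print(round(shared / 2**32, 3))       # keep probability of 'shared', about 0.161
print(unique == 2**32)                # True: a unique word's threshold is exactly 2**32
print(is_skipped(unique, 2**32 - 1))  # False: even the largest 32-bit draw can't exceed it
```

Under these assumptions only 'shared' can ever be dropped, so every document should keep its unique word.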

I may not have time to dig deeper anytime soon, so placing this recipe-to-reproduce & key observations so far here.

Notably, Word2Vec and FastText may use similar sampling logic, so even though there's no sighting there yet, similar issues may exist.

(Separately, I somewhat doubt this sample-related anomaly, whatever its cause, is necessarily related to the actual original problem of the user in the referenced thread, which seemed present only in extremely large corpuses, likely with real many-word texts, over a normal number of repeated epochs, and perhaps only in the "very tail-end" doc-vectors.)

gojomo commented 5 years ago

To double-check the randomization happening in Cython as far as possible, I ran the following cells in a notebook. They seem to indicate that random_int32(&c.next_random) is not malfunctioning by returning a value greater than a sample_int of 2**32 (the lowest value any of the never-downsampled, less-frequent words appear to have).

%load_ext Cython
%%cython
# Can gensim's random_int32() ever return a value >= 2**32?
from gensim.models.word2vec_inner cimport random_int32
from numpy import random
RS = random.RandomState(1)
# seed the LCG state as (2**24) * a + b with a, b < 2**24, so it starts below 2**48
cdef unsigned long long next_random = (2**24) * RS.randint(0, 2**24) + RS.randint(0, 2**24)
cpdef unsigned long long cyrand():
    # one draw; random_int32() also advances next_random in place
    return random_int32(&next_random)
cpdef unsigned long long show_nr():
    # peek at the current LCG state
    return next_random
for _ in range(1000000000):
    if (cyrand() >= (2**32)):
        print("whoops")  # not seen
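
The cell's result is also what the arithmetic predicts, assuming random_int32() is the 48-bit linear congruential generator I believe word2vec_inner.pyx uses (multiplier 25214903917, increment 11, state masked to 48 bits, draw = state >> 16). A pure-Python sketch under that assumption, using the stdlib random module instead of numpy for seeding:

```python
import random

MASK48 = (1 << 48) - 1


def random_int32(state):
    """Assumed emulation of one LCG step: returns (32-bit draw, next 48-bit state)."""
    draw = state >> 16                                # top 32 bits of a 48-bit state
    next_state = (state * 25214903917 + 11) & MASK48  # state is re-masked to 48 bits
    return draw, next_state


rng = random.Random(1)
# seeded like the cell above: (2**24) * a + b with a, b < 2**24, so state < 2**48
state = (2**24) * rng.randrange(2**24) + rng.randrange(2**24)
largest = 0
for _ in range(200_000):
    draw, state = random_int32(state)
    largest = max(largest, draw)
print(largest < 2**32)  # True
```

If that model of the generator is right, a draw can never exceed 2**32 - 1, so it can never exceed a sample_int of 2**32 either.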

This somewhat decreases my confidence that the sample value is inherently involved: maybe downsampling just shrinks the corpus enough for some other stochastic error to become detectable. (For example, maybe the real bug is that the wrong vector/memory is often being updated; more epochs, or simply the ~2/3-bigger corpus, mean the misfires eventually end up changing every vector anyway.)

And while I'd earlier thought all my tests with sample=0 passed, there are anomalies there, too. For example, consider this test where the dummy docs each have a single, unique word, which should always be available for training with no downsampling:

>>> test_d2v_docvecs_trained(doc_args=dict(shared_word=False, count=1001), d2v_args=dict(min_count=0, sample=0, vector_size=4, epochs=1, workers=1, dm=0, seed=42))
(25,
 [array([0]),
  array([1]),
  array([2]),
  array([3]),
  array([4]),
  array([6]),
  array([7]),
  array([8]),
  array([10]),
  array([12]),
  array([13]),
  array([14]),
  array([15]),
  array([18]),
  array([19]),
  array([20]),
  array([21]),
  array([22]),
  array([24]),
  array([27]),
  array([28]),
  array([39]),
  array([41]),
  array([53]),
  array([61])])

This result is especially odd in that all the unadjusted vectors are so early among the 1001 (every index is below 62).

Another anomaly, when training with a degenerate corpus of documents that are all just the same single word (with no downsampling):

>>> test_d2v_docvecs_trained(doc_args=dict(doc_word=False, count=1001), d2v_args=dict(min_count=0, sample=0, vector_size=4, epochs=1, workers=1, dm=0, seed=42))
(1, [array([0])])

In such an odd corpus there are never any true 'negative' examples, because any randomly chosen word will be the same as the target word. But the positive training should still happen, and it seems to happen for the last 1000 documents, but not for document #0.
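
One hypothesis worth checking (my assumption, not established above): in word2vec-style negative sampling the doc-vector's change is proportional to the target word's output row (syn1neg), and those rows are commonly zero-initialized. If so, the very first positive update against a word leaves the doc-vector untouched and only seeds the output row. The sketch below is a simplified, pure-Python single-word PV-DBOW step (no negatives, fixed alpha; `dbow_positive_step` is a hypothetical name, not gensim's code):

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def dbow_positive_step(docvec, out_row, alpha=0.025):
    """One label=1 update; the doc-vector's delta uses the output row *before* it is updated."""
    f = sigmoid(sum(d * o for d, o in zip(docvec, out_row)))
    g = (1.0 - f) * alpha
    work = [g * o for o in out_row]  # doc-vector gradient: zero if out_row is all zeros
    for i, d in enumerate(docvec):
        out_row[i] += g * d          # output row (syn1neg-like) learns from the doc-vector
    for i, w in enumerate(work):
        docvec[i] += w               # doc-vector update applied last


rng = random.Random(42)
size = 4
out_row = [0.0] * size  # assumption: output weights start at all zeros
docs = [[(rng.random() - 0.5) / size for _ in range(size)] for _ in range(3)]
before = [list(d) for d in docs]
for d in docs:  # one epoch, documents in order, all using the same single word
    dbow_positive_step(d, out_row)
changed = [d != b for d, b in zip(docs, before)]
print(changed)  # [False, True, True]
```

If this is what happens in gensim, it would leave document #0 unchanged in the single-word corpus, and could plausibly make early documents in the unique-word corpus more likely to stay untouched, since their words' output rows (and most negatively sampled rows) are still zero.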

So lots of oddities here needing deeper investigation.

gojomo commented 5 years ago

Per the original reporter's follow-up at https://groups.google.com/d/msg/gensim/XbH5Sr6RBcI/fX1yp4LFBwAJ, in that large-number-of-docvecs test the untrained vectors all lie just past the positions in model.docvecs.vectors_docs that can be reached at a 2^32 offset from its origin.

To excerpt that here in case the Google Drive shares go away:

def test_d2v_docvecs_trained(doc_args={}, d2v_args={}):
    """Demo bug in Doc2Vec (& more?) as of gensim 3.8.1 with sample>0"""
    docs = DummyTaggedDocuments(**doc_args)
    d2v_model = Doc2Vec(**d2v_args)

    d2v_model.build_vocab(docs, progress_per=100000)
    starting_vecs = d2v_model.docvecs.vectors_docs.copy()

    d2v_model.train(docs, total_examples=d2v_model.corpus_count, epochs=d2v_model.epochs, report_delay=5)

    unchanged = np.all(starting_vecs == d2v_model.docvecs.vectors_docs, axis=1)
    unchanged_indexes = np.argwhere(unchanged)

    return (len(unchanged_indexes), list(unchanged_indexes))

if __name__ == "__main__":
    len_unchanged, list_unchanged = test_d2v_docvecs_trained(doc_args=dict(doc_word=False, count=8590000), d2v_args=dict(min_count=0, sample=0, vector_size=500, epochs=5, workers=4))
    logging.info("unchanged %i", len_unchanged)
    logging.info("unchanged %s", list_unchanged)

...& log output...

INFO:root:yielding last DummyTaggedDocument 8589999
INFO:gensim.models.doc2vec:collected 1 word types and 8590000 unique tags from a corpus of 8590000 examples and 8590000 words
INFO:gensim.models.word2vec:Loading a fresh vocabulary
INFO:gensim.models.word2vec:min_count=0 retains 1 unique words (100% of original 1, drops 0)
INFO:gensim.models.word2vec:min_count=0 leaves 8590000 word corpus (100% of original 8590000, drops 0)
INFO:gensim.models.word2vec:deleting the raw counts dictionary of 1 items
INFO:gensim.models.word2vec:sample=0 downsamples 0 most-common words
INFO:gensim.models.word2vec:downsampling leaves estimated 8590000 word corpus (100.0% of prior 8590000)
INFO:gensim.models.base_any2vec:estimated required memory for 1 words and 500 dimensions: 17180004500 bytes
INFO:gensim.models.word2vec:resetting layer weights
INFO:gensim.models.base_any2vec:training model with 4 workers on 1 vocabulary and 500 features, using sg=0 hs=0 sample=0 negative=5 window=5
INFO:gensim.models.base_any2vec:EPOCH 1 - PROGRESS: at 0.12% examples, 9683 words/s, in_qsize 7, out_qsize 0
...
...
INFO:gensim.models.base_any2vec:EPOCH 5 - PROGRESS: at 99.53% examples, 31993 words/s, in_qsize 4, out_qsize 0
INFO:gensim.models.base_any2vec:worker thread finished; awaiting finish of 3 more threads
INFO:gensim.models.base_any2vec:worker thread finished; awaiting finish of 2 more threads
INFO:gensim.models.base_any2vec:worker thread finished; awaiting finish of 1 more threads
INFO:gensim.models.base_any2vec:worker thread finished; awaiting finish of 0 more threads
INFO:gensim.models.base_any2vec:EPOCH - 5 : training on 8590000 raw words (17180000 effective words) took 535.3s, 32095 effective words/s
INFO:gensim.models.base_any2vec:training on a 42950000 raw words (85900000 effective words) took 2582.9s, 33257 effective words/s
INFO:root:unchanged 65
INFO:root:unchanged [array([8589935]), array([8589936]), array([8589937]), array([8589938]), array([8589939]), array([8589940]), array([8589941]), array([8589942]), array([8589943]), array([8589944]), array([8589945]), array([8589946]), array([8589947]), array([8589948]), array([8589949]), array([8589950]), array([8589951]), array([8589952]), array([8589953]), array([8589954]), array([8589955]), array([8589956]), array([8589957]), array([8589958]), array([8589959]), array([8589960]), array([8589961]), array([8589962]), array([8589963]), array([8589964]), array([8589965]), array([8589966]), array([8589967]), array([8589968]), array([8589969]), array([8589970]), array([8589971]), array([8589972]), array([8589973]), array([8589974]), array([8589975]), array([8589976]), array([8589977]), array([8589978]), array([8589979]), array([8589980]), array([8589981]), array([8589982]), array([8589983]), array([8589984]), array([8589985]), array([8589986]), array([8589987]), array([8589988]), array([8589989]), array([8589990]), array([8589991]), array([8589992]), array([8589993]), array([8589994]), array([8589995]), array([8589996]), array([8589997]), array([8589998]), array([8589999])]
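
As a sanity check on the 2^32 observation, the boundary can be computed directly. This sketch assumes the limit is counted in float32 elements (row * vector_size), not bytes; under that assumption it reproduces both the first untrained index and the count of 65 in the log.

```python
N_DOCS = 8_590_000  # number of docs/tags in the reporter's test
VECTOR_SIZE = 500   # vector_size in the reporter's test
LIMIT = 2**32       # hypothesised overflow point, counted in float32 elements

# First row whose starting element offset (row * VECTOR_SIZE) is >= 2**32,
# i.e. ceil(LIMIT / VECTOR_SIZE) computed with integer arithmetic.
first_unreachable = -(-LIMIT // VECTOR_SIZE)
expected_unchanged = N_DOCS - first_unreachable
print(first_unreachable, expected_unchanged)  # 8589935 65
```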

This strongly suggests a problem related to 32-bit indexing in the Cython code, though at a quick glance, 64-bit values appear to be used consistently.

Because such overflows may 'wrap around', it's very possible that not only do some vectors remain untrained, but other vectors (early in the range) also receive misdirected updates.
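
To see where such misdirected updates would land, here is a sketch of the wraparound under a hypothetical model: a row's element offset computed as a 32-bit unsigned product of tag index and vector_size (i.e. modulo 2**32). This is an assumption for illustration, not something confirmed from gensim's source.

```python
LIMIT = 2**32
VECTOR_SIZE = 500


def wrapped_offset(row, vector_size=VECTOR_SIZE):
    """Hypothetical: element offset of `row` if row * vector_size wrapped modulo 2**32."""
    return (row * vector_size) % LIMIT


for row in (8_589_934, 8_589_935, 8_589_999):
    offset = wrapped_offset(row)
    print(row, offset, offset // VECTOR_SIZE)
# 8589934 4294967000 8589934   (still below 2**32: correct row)
# 8589935 204 0                (wraps into row 0)
# 8589999 32204 64             (wraps into row 64)
```

Under that hypothesis, the 65 unreachable rows would instead start 204 elements into rows 0 to 64, so each misdirected update would straddle row k and row k+1.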

Since the demonstrations above use tiny contrived datasets far from the 2^32 threshold, that's in all likelihood an unrelated bug, so there are probably 2 very separate problems here. However, as they were discovered together and will require the same careful stepwise analysis of much of the same code, I'll leave them both covered by this one issue unless and until someone needs to split them apart.

With the 'large' problem, I suppose there's also some small chance it's introduced by the wheel build, or by the OS/libraries on the user's system, even if the Cython code is correct when compiled in the appropriate environment.

gojomo commented 5 years ago

It'd be interesting to try reproducing this in a Word2Vec model, potentially with a similarly synthetic corpus (with 8700000+ unique words of 500d each), as a quick check of how general the problem might be. (I believe that'd require a machine with much more than 32GB RAM.)
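
A rough memory estimate supports that. This sketch assumes the per-word accounting that gensim 3.x's memory estimate appears to use (500 bytes of vocabulary overhead per word, plus 4 bytes per dimension for each of the word-vector and syn1neg arrays, plus 4 bytes per dimension per doc-vector). I infer that accounting from the 'estimated required memory' log line above, which it reproduces exactly; `estimate_bytes` is a hypothetical helper.

```python
BYTES_PER_FLOAT = 4   # float32
VOCAB_OVERHEAD = 500  # assumed per-word bookkeeping bytes


def estimate_bytes(n_words, vector_size, n_docvecs=0):
    """Rough model-memory estimate under the assumptions stated above."""
    vocab = n_words * VOCAB_OVERHEAD
    word_vectors = n_words * vector_size * BYTES_PER_FLOAT
    syn1neg = n_words * vector_size * BYTES_PER_FLOAT  # negative-sampling output weights
    docvecs = n_docvecs * vector_size * BYTES_PER_FLOAT
    return vocab + word_vectors + syn1neg + docvecs


print(estimate_bytes(1, 500, n_docvecs=8_590_000))  # 17180004500, as in the Doc2Vec log
print(estimate_bytes(8_700_000, 500))               # 39150000000, about 39 GB for Word2Vec
```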

I'm not likely to have much time to dig deep into this in the near future, so if urgent to fix, someone else will need to take the lead.