piskvorky / gensim

Topic Modelling for Humans
https://radimrehurek.com/gensim
GNU Lesser General Public License v2.1

Word2vec: total loss suspiciously drops with worker count, probably thread-unsafe tallying #2743

Open tsaastam opened 4 years ago

tsaastam commented 4 years ago

Problem description

The word2vec implementation requires a workaround, detailed in #2735, to report the total loss per epoch correctly. Even with that workaround in place, the reported total loss seems to vary with the number of workers.

Steps/code/corpus to reproduce

This is my code:

```python
import time

import pandas as pd
from numpy import linalg
from gensim.models import Word2Vec
from gensim.models.callbacks import CallbackAny2Vec


class MyLossCalculatorII(CallbackAny2Vec):
    def __init__(self):
        self.epoch = 1
        self.losses = []
        self.cumu_loss = 0.0
        self.previous_epoch_time = time.time()

    def on_epoch_end(self, model):
        loss = model.get_latest_training_loss()
        norms = [linalg.norm(v) for v in model.wv.vectors]
        now = time.time()
        epoch_seconds = now - self.previous_epoch_time
        self.previous_epoch_time = now
        self.cumu_loss += float(loss)
        print(f"Loss after epoch {self.epoch}: {loss} (cumulative loss so far: {self.cumu_loss}) "
              f"-> epoch took {round(epoch_seconds, 2)} s - vector norms min/avg/max: "
              f"{round(float(min(norms)), 2)}, {round(float(sum(norms) / len(norms)), 2)}, "
              f"{round(float(max(norms)), 2)}")
        self.epoch += 1
        self.losses.append(float(loss))
        # Per-epoch workaround: reset the running tally so the next epoch reports only its own loss.
        model.running_training_loss = 0.0


def train_and_check(my_sentences, my_epochs, my_workers=8, my_loss_calc_class=MyLossCalculatorII):
    print("Building vocab...")
    my_model: Word2Vec = Word2Vec(sg=1, compute_loss=True, workers=my_workers)
    my_model.build_vocab(my_sentences)
    print(f"Vocab done. Training model for {my_epochs} epochs, with {my_workers} workers...")
    loss_calc = my_loss_calc_class()
    trained_word_count, raw_word_count = my_model.train(
        my_sentences, total_examples=my_model.corpus_count, compute_loss=True,
        epochs=my_epochs, callbacks=[loss_calc])
    loss = loss_calc.losses[-1]
    print(trained_word_count, raw_word_count, loss)
    loss_df = pd.DataFrame({"training loss": loss_calc.losses})
    loss_df.plot(color="blue")  # requires matplotlib
    # print("Calculating accuracy...")
    # acc, details = my_model.wv.evaluate_word_analogies(questions_file, case_insensitive=True)
    # print(acc)
    return loss_calc, my_model
```

My data is an in-memory list of sentences of Finnish text, each sentence being a list of strings:

```
In [18]: sentences[0]
Out[18]: ['hän', 'tietää', 'minkälainen', 'tilanne', 'tulla']
```

I'm running the following code:

```python
lc4, model4 = train_and_check(sentences, my_epochs=20, my_workers=4)
lc8, model8 = train_and_check(sentences, my_epochs=20, my_workers=8)
lc16, model16 = train_and_check(sentences, my_epochs=20, my_workers=16)
lc32, model32 = train_and_check(sentences, my_epochs=20, my_workers=32)
```

And the outputs are (last few lines + plot only):

```
# lc4
Loss after epoch 20: 40341580.0 (cumulative loss so far: 830458060.0) -> epoch took 58.15 s - vector norms min/avg/max: 0.02, 3.79, 12.27
589841037 669998240 40341580.0
Wall time: 20min 14s
```

[Plot: training loss per epoch (lc4)]

```
# lc8
Loss after epoch 20: 25501282.0 (cumulative loss so far: 521681620.0) -> epoch took 36.6 s - vector norms min/avg/max: 0.02, 3.79, 12.24
589845960 669998240 25501282.0
Wall time: 12min 46s
```

[Plot: training loss per epoch (lc8)]

```
# lc16
Loss after epoch 20: 14466763.0 (cumulative loss so far: 295212011.0) -> epoch took 26.25 s - vector norms min/avg/max: 0.02, 3.79, 12.55
589839763 669998240 14466763.0
Wall time: 9min 35s
```

[Plot: training loss per epoch (lc16)]

```
# lc32
Loss after epoch 20: 7991086.5 (cumulative loss so far: 161415654.5) -> epoch took 27.5 s - vector norms min/avg/max: 0.02, 3.79, 12.33
589843184 669998240 7991086.5
Wall time: 9min 37s
```

[Plot: training loss per epoch (lc32)]

What is going on here? The loss (whether total loss, final-epoch loss, or average loss per epoch) varies, although the data and the number of epochs are the same. I would imagine that "1 epoch" means "each data point is considered exactly once", in which case the number of workers should only affect how quickly training finishes, not the loss. (The loss would still vary a little depending on the order in which data points are processed, etc., but that should be minor.) Here, though, the loss seems to be roughly proportional to 1/n, where n is the number of workers.
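The 1/n impression can be checked against the final-epoch losses reported above. This plain-Python sketch (the numbers are copied from the outputs; nothing else is assumed) multiplies each loss by its worker count, which would stay constant under exact 1/n scaling, and computes the loss ratio for each doubling of workers, which would be 0.5 under exact 1/n scaling.

```python
# Final-epoch losses from the runs above, keyed by worker count.
final_epoch_loss = {4: 40341580.0, 8: 25501282.0, 16: 14466763.0, 32: 7991086.5}

# Under exact 1/n scaling, loss * n would be the same for every run.
for n, loss in final_epoch_loss.items():
    print(f"{n:>2} workers: loss * n = {loss * n:,.0f}")

# Under exact 1/n scaling, each doubling of workers would halve the loss (ratio 0.5).
workers = sorted(final_epoch_loss)
for a, b in zip(workers, workers[1:]):
    ratio = final_epoch_loss[b] / final_epoch_loss[a]
    print(f"{a:>2} -> {b:>2} workers: loss ratio {ratio:.2f}")
```

The doubling ratios come out at 0.63, 0.57 and 0.55, so the drop is steep and fairly regular, though somewhat short of exact 1/n scaling.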

I'm guessing, based on the similar shape of the loss progressions and the very similar vector magnitudes, that training is actually fine in all four cases, so hopefully this is just another display bug similar to #2735.

Versions

The output of

```python
import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import gensim; print("gensim", gensim.__version__)
from gensim.models import word2vec; print("FAST_VERSION", word2vec.FAST_VERSION)
```

is

```
Windows-10-10.0.18362-SP0
Python 3.7.3 | packaged by conda-forge | (default, Jul  1 2019, 22:01:29) [MSC v.1900 64 bit (AMD64)]
NumPy 1.17.3
SciPy 1.3.1
gensim 3.8.1
FAST_VERSION 1
```

gojomo commented 4 years ago

As soon as you use more than one thread, the order of training examples will vary with scheduling jitter from the OS, and so will the progression of random choices used by the algorithm. So you wouldn't necessarily expect the tallied loss values, at the end of any epoch or of all training, to be identical or closely correlated.
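To make the ordering point concrete, here is a minimal sketch: a hypothetical toy, not gensim's actual job-queue code, in which a pool of threads drains a shared queue of items (stand-ins for batches of sentences). Every item is still handled exactly once per "epoch", which matches the reporter's expectation, but which worker handles it, and in what order, is left to the OS scheduler. The function name `run_toy_epoch` is invented for illustration.

```python
import queue
import threading


def run_toy_epoch(num_items, num_workers):
    """Hypothetical toy: hand out each item exactly once to a pool of worker threads.

    Returns (worker_id, item) pairs in the order the items were handled.
    """
    jobs = queue.Queue()
    for item in range(num_items):
        jobs.put(item)

    handled = []
    lock = threading.Lock()  # protects `handled`

    def worker(worker_id):
        while True:
            try:
                item = jobs.get_nowait()
            except queue.Empty:
                return  # no work left in this epoch
            with lock:
                handled.append((worker_id, item))

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(num_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return handled
```

Across repeated runs, `sorted(item for _, item in run_toy_epoch(1000, 8))` is always `list(range(1000))`, while the worker assignment and handling order can differ from run to run. Any per-example randomness drawn along the way is consumed in a different order too.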

Further, some have observed that stochastic gradient descent in which multiple parallel sessions sometimes clobber each other's results may, surprisingly, work a bit better than purely synchronous SGD. See for example https://cxwangyi.wordpress.com/2013/04/09/why-asynchronous-sgd-works-better-than-its-synchronous-counterpart/. That might explain a somewhat "faster" improvement in loss in multithreaded situations.
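For readers unfamiliar with asynchronous SGD, here is a minimal sketch of lock-free parallel SGD (the style popularized as Hogwild!) on a toy, noise-free linear-regression problem. Everything in it (the function name, the data, the learning rate) is a hypothetical illustration, not gensim's training code. Because of CPython's GIL, the threads interleave rather than run truly in parallel, so this shows the unsynchronized-update pattern, not a speedup.

```python
import threading

import numpy as np


def lock_free_sgd(num_threads=4, n_samples=8000, dim=3, lr=0.01, epochs=2, seed=0):
    """Toy asynchronous SGD: threads share one weight vector and update it with no lock.

    Hypothetical illustration only; unrelated to gensim's internal implementation.
    """
    rng = np.random.default_rng(seed)
    true_w = rng.normal(size=dim)
    X = rng.normal(size=(n_samples, dim))
    y = X @ true_w          # noise-free targets
    w = np.zeros(dim)       # shared parameters, read and written by every thread

    def worker(rows):
        for _ in range(epochs):
            for i in rows:
                err = X[i] @ w - y[i]       # may read a w another thread is about to change
                w[:] -= lr * err * X[i]     # in-place update, deliberately unsynchronized

    # Each thread gets its own shard of the training examples.
    shards = np.array_split(np.arange(n_samples), num_threads)
    threads = [threading.Thread(target=worker, args=(shard,)) for shard in shards]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return w, true_w
```

With the defaults, `w, true_w = lock_free_sgd()` should satisfy `np.allclose(w, true_w, atol=1e-3)`: occasional stale reads perturb individual steps, but on this well-conditioned problem the updates still converge.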

However, it's also quite likely that the loss-calculation code, which was bolted on later and never really fully tested, implemented for all related classes (FastText, Doc2Vec), or verified as being what users needed, isn't doing the right thing in multithreaded situations, with some tallies being lost when multiple threads update the same value. In particular, the way the Cython code copies the running value into a C-optimized structure, tallies it there, then copies it back to the shared location could very well lead to many updates being lost. The whole feature needs a competent revisit; see #2617.
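The lost-update scenario described above can be made concrete with a deterministic replay. This is a simplified, hypothetical model of the "copy in, tally privately, copy back" pattern, not the actual Cython code: `ToyModel`, `replay_lost_update` and `tally_safely` are invented for illustration, and only the attribute name `running_training_loss` comes from the callback in this issue. It is followed by one conventional fix, in which each worker sums privately and merges into the shared total under a lock.

```python
import threading


class ToyModel:
    """Hypothetical stand-in holding a shared loss tally (not gensim's class)."""

    def __init__(self):
        self.running_training_loss = 0.0


def replay_lost_update(model, batch_loss_a, batch_loss_b):
    """Replays one unlucky interleaving of two workers, step by step."""
    # Both workers copy the shared tally into their private (C-struct-like) storage.
    private_a = model.running_training_loss
    private_b = model.running_training_loss
    # Each adds its own batch loss to its private copy.
    private_a += batch_loss_a
    private_b += batch_loss_b
    # Each copies its private value back; worker B's write clobbers worker A's.
    model.running_training_loss = private_a
    model.running_training_loss = private_b
    return model.running_training_loss


def tally_safely(batches_per_worker):
    """One safe alternative: private per-worker sums, merged into the total under a lock."""
    total = 0.0
    lock = threading.Lock()

    def worker(batch_losses):
        nonlocal total
        private = sum(batch_losses)  # no shared state touched here
        with lock:                   # the read-modify-write on `total` is now atomic
            total += private

    threads = [threading.Thread(target=worker, args=(b,)) for b in batches_per_worker]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return total
```

`replay_lost_update(ToyModel(), 3.0, 5.0)` returns 5.0 instead of 8.0, because worker A's tally is overwritten; the more workers write back concurrently, the more tallies can be lost this way. `tally_safely([[1.0, 2.0], [3.0], [4.0, 5.0]])` returns 15.0 however the threads are scheduled.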