explosion / spaCy

💫 Industrial-strength Natural Language Processing (NLP) in Python
https://spacy.io
MIT License

Models are not deterministic / reproducible on GPU #6490

Open echatzikyriakidis opened 3 years ago

echatzikyriakidis commented 3 years ago

How to reproduce the behaviour

I cannot get reproducible results when training an NER model on the GPU in Google Colab. When I run the same code on the CPU, the runs seem to be reproducible. However, when I enable the GPU with prefer_gpu(), the results differ from run to run.

Example code

```python
import spacy
from sklearn.utils import shuffle  # assumed: the shuffle() used below is sklearn's
from spacy.util import compounding, fix_random_seed, minibatch
from tqdm import tqdm


def train_blank_ner_model(language_id, train_X, entity_types, epochs, random_state, dropout, minibatch_size, losses_display_frequency_in_epochs):
    # Seed the random number generators so that runs should be repeatable.
    fix_random_seed(random_state)

    nlp = spacy.blank(language_id)

    assert len(nlp.pipe_names) == 0, f"Pipeline of blank model '{language_id}' is not empty."

    # spaCy v2 API: create the NER component, then add it to the pipeline.
    ner = nlp.create_pipe('ner')

    nlp.add_pipe(ner)

    for entity_type in entity_types:
        ner.add_label(entity_type)

    optimizer = nlp.begin_training()

    for epoch in tqdm(range(1, epochs + 1)):
        # Same random_state every epoch, so the shuffle itself is deterministic.
        train_X = shuffle(train_X, random_state=random_state)

        losses = {}

        # compounding(start, stop, compound) yields a growing batch size.
        batches = minibatch(train_X, size=compounding(*minibatch_size))

        for batch in tqdm(batches, leave=False):
            texts, annotations = zip(*batch)

            nlp.update(texts, annotations, sgd=optimizer, drop=dropout, losses=losses)

        if epoch % losses_display_frequency_in_epochs == 0:
            print(f"Epoch {epoch}, Loss: {losses['ner']}")

    print(f"Training completed with loss: {losses['ner']}")

    return nlp


print(f"GPU Initialization: {spacy.prefer_gpu()}")

# X_train and ner_entity_types are defined elsewhere in the notebook.
nlp = train_blank_ner_model(language_id='de',
        train_X=X_train,
        entity_types=ner_entity_types,
        epochs=3,
        random_state=42,
        dropout=0.4,
        minibatch_size=(0.4, 0.4, 1.0),
        losses_display_frequency_in_epochs=5)
```

Your Environment

svlandeg commented 3 years ago

Thanks for the report! I just double-checked with the latest code from master and can confirm that there seems to be a reproducibility issue on GPU when training the NER model.

We'll look into this!

echatzikyriakidis commented 3 years ago

Thank you @svlandeg !

We can continue our experimentation phase even without determinism, since the losses from runs with different random seeds are more or less the same; there are no big fluctuations.

However, it would be great to have a new release with the fix soon.

Please note that the same thing happens when using a pretrained model, e.g. en_core_web_lg.

echatzikyriakidis commented 3 years ago

Hi @svlandeg !

Do we have any update on this?

polm commented 3 years ago

I managed to track down the source of this problem. The backprop in HashEmbed uses cupyx.scatter_add, which is non-deterministic, so this affects anything that uses a tok2vec layer.
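Why would a non-deterministic scatter-add change the trained model? Floating-point addition is not associative, so when several gradient rows are added into the same embedding row (for example, when the same token appears several times in a batch), the final value depends on the order of the additions. On a GPU that order is decided by thread scheduling. The pure-Python sketch below simulates this by passing the visiting order explicitly; `scatter_add` here is a hypothetical helper for illustration, not the cupyx function.

```python
def scatter_add(out, indices, updates, order):
    """Accumulate updates[i] into out[indices[i]], visiting i in `order`.

    Hypothetical helper: on a GPU the visiting order comes from thread
    scheduling; here it is passed in explicitly to simulate that.
    """
    result = list(out)
    for i in order:
        result[indices[i]] += updates[i]
    return result


# Three gradient contributions that all land in embedding row 0.
indices = [0, 0, 0]
updates = [0.1, 0.2, 0.3]
out = [0.0]

a = scatter_add(out, indices, updates, order=[0, 1, 2])
b = scatter_add(out, indices, updates, order=[2, 1, 0])
print(a, b)  # [0.6000000000000001] [0.6]
```

The difference is tiny per step, but it compounds over many updates, which is enough to make two training runs diverge.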

Unfortunately, there is no simple substitute for this without consequences. We could unroll the addition to control the order of operations, but it would be too slow. This is also a known issue in PyTorch (which doesn't use CuPy but a similar implementation), but because the actual change in values is small it's not generally considered a problem (see https://github.com/pytorch/pytorch/issues/50469).
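One common way to control the order of operations without a fully serial loop is to sort the updates by destination row and sum each row's updates in a fixed order. Different rows never touch the same memory, so they can still be processed in parallel in any order without changing the result. The sketch below is a pure-Python illustration of that idea with a hypothetical `deterministic_scatter_add` helper; it is not necessarily the approach the spaCy team adopted. The extra sort is the speed penalty such a design pays compared with atomic adds.

```python
def deterministic_scatter_add(out, indices, updates, segment_order=None):
    """Scatter-add whose result does not depend on how rows are scheduled.

    Hypothetical helper for illustration. A stable sort groups positions by
    destination row; ties keep their original relative order.
    """
    result = list(out)
    order = sorted(range(len(indices)), key=lambda i: indices[i])

    # Split the sorted positions into one segment per destination row.
    segments = []
    for i in order:
        if segments and indices[segments[-1][0]] == indices[i]:
            segments[-1].append(i)
        else:
            segments.append([i])

    # Segments write to disjoint rows, so they could run in parallel;
    # `segment_order` simulates an arbitrary schedule.
    if segment_order is None:
        segment_order = range(len(segments))
    for s in segment_order:
        seg = segments[s]
        total = 0.0
        for i in seg:  # always left to right within a row
            total += updates[i]
        result[indices[seg[0]]] += total
    return result


indices = [0, 1, 0, 1, 0]
updates = [0.1, 1.0, 0.2, 2.0, 0.3]
out = [0.0, 0.0]

a = deterministic_scatter_add(out, indices, updates, segment_order=[0, 1])
b = deterministic_scatter_add(out, indices, updates, segment_order=[1, 0])
print(a, b)  # [0.6000000000000001, 3.0] [0.6000000000000001, 3.0]
```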

That said, we think we can design a deterministic equivalent with a more acceptable speed penalty, and we will be taking a look at it. In the meantime, this is something to be aware of. This will be the main issue for it, so just subscribe here if you'd like updates.