microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

[BUG] Non-deterministic model output if `replace_with_kernel_inject=True` for GPT-neo-1.3B/Bloom 1b7 and others #3404

Open trianxy opened 1 year ago

trianxy commented 1 year ago

Describe the bug

For the models GPT-neo-1.3B, Bloom 1b7, Pythia 1.4b, and GPT2-xl, I get non-deterministic model outputs when using context length 1 and engine = deepspeed.init_inference(model, dtype=torch.float16, replace_with_kernel_inject=True).

The context length of 1 may make this sound like a low-priority bug, but it may not be: when using transformers' model.generate, the context length of the ids may be cut down to 1 (because past_key_values are used to speed up model inference). In particular, this bug blocks me from rewriting transformers' model.generate for my needs.
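For context, here is a minimal sketch (my own illustration, not from the issue) of why the context length drops to 1: a typical greedy-decoding loop with use_cache=True feeds only the newest token, plus the cached past_key_values, back into the model after the first step.

# Illustrative sketch of incremental decoding with a key/value cache (assumes gpt2 for brevity).
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tok = AutoTokenizer.from_pretrained("gpt2")
mod = AutoModelForCausalLM.from_pretrained("gpt2").eval()

ids = tok.encode(" 1 2 3 4 5 6 7", return_tensors="pt")
past = None
with torch.inference_mode():
    for _ in range(5):
        out = mod(input_ids=ids, past_key_values=past, use_cache=True)
        past = out.past_key_values
        next_id = out.logits[:, -1].argmax(dim=-1, keepdim=True)
        ids = next_id  # after the first step, the model only ever sees a context of length 1
        print(tok.decode(next_id[0]))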

To Reproduce

# !pip install --upgrade torch==1.13.1
# !pip install --upgrade transformers==4.28.1
# !pip install --upgrade deepspeed==0.9.1

SEED = 42

from typing import Any

import random
random.seed(SEED)

import numpy as np
np.random.seed(SEED)

import torch
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

import deepspeed
from transformers import AutoTokenizer, AutoModelForCausalLM

ARCHITECTURE = "EleutherAI/gpt-neo-1.3B"  # error -> non-deterministic
# ARCHITECTURE = "EleutherAI/pythia-1.4b"  # error -> non-deterministic
# ARCHITECTURE = "bigscience/bloom-1b7"  # error -> non-deterministic
# ARCHITECTURE = "gpt2-xl"  # error -> slightly non-deterministic

DEVICE = "cuda"

model = AutoModelForCausalLM.from_pretrained(ARCHITECTURE).to(DEVICE).eval()
tokenizer = AutoTokenizer.from_pretrained(ARCHITECTURE, use_fast=True)

def test_if_model_is_deterministic(mod: Any, tok: Any, device: str) -> None:
    mod = mod.eval()
    with torch.inference_mode():
        # Warm-up: one forward pass with a longer, unrelated prompt.
        throw_away_ids = tok.encode(" 1 2 3 4 5 6 7", return_tensors="pt").to(device)
        throw_away_output = mod(throw_away_ids)

        # The same single-token input should always produce the same top token and logit.
        for _ in range(10):
            ids = tok.encode(" 4", return_tensors="pt").to(device)
            output = mod(ids)
            token_id = torch.argmax(output.logits[0][-1]).item()
            logit = output.logits[0][-1][token_id].item()
            token = tok.decode(token_id)
            print(f"{token=}, {token_id=}, {logit=}")

test_if_model_is_deterministic(mod=model, tok=tokenizer, device=DEVICE) 
# prints:
# token='.', token_id=13, logit=-0.9108065366744995
# token='.', token_id=13, logit=-0.9108065366744995
# token='.', token_id=13, logit=-0.9108065366744995
# token='.', token_id=13, logit=-0.9108065366744995
# token='.', token_id=13, logit=-0.9108065366744995
# token='.', token_id=13, logit=-0.9108065366744995
# token='.', token_id=13, logit=-0.9108065366744995
# token='.', token_id=13, logit=-0.9108065366744995
# token='.', token_id=13, logit=-0.9108065366744995
# token='.', token_id=13, logit=-0.9108065366744995

engine = deepspeed.init_inference(model, dtype=torch.float16, replace_with_kernel_inject=True)

test_if_model_is_deterministic(engine.module, tokenizer, device=DEVICE) 
# prints:
# token=' 5', token_id=642, logit=-2.376953125
# token=' 4', token_id=604, logit=-1.349609375
# token=' 4', token_id=604, logit=-0.7822265625
# token=' 4', token_id=604, logit=-0.662109375
# token=' 4', token_id=604, logit=-0.59765625
# token=' 4', token_id=604, logit=0.01448822021484375
# token='\n', token_id=198, logit=-3.369140625
# token='t', token_id=83, logit=-5.5
# token='t', token_id=83, logit=-7.88671875
# token='l', token_id=75, logit=-11.1484375

Expected behavior

The 10 printouts after test_if_model_is_deterministic(engine.module, tokenizer, device=DEVICE) should not change: each iteration feeds the same single-token input, so all 10 lines should be identical (and ideally predict the same token as the unmodified model above), but they are not.
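As a possible tightening of the repro (my addition, not in the original report), one could collect the predictions and assert that they are identical; the assertion holds for the plain model but fails after kernel injection:

# Hypothetical follow-up to the repro above: gather the 10 single-token predictions
# and assert that they are all identical.
def collect_predictions(mod: Any, tok: Any, device: str, n: int = 10) -> list:
    results = []
    with torch.inference_mode():
        for _ in range(n):
            ids = tok.encode(" 4", return_tensors="pt").to(device)
            output = mod(ids)
            token_id = torch.argmax(output.logits[0][-1]).item()
            logit = output.logits[0][-1][token_id].item()
            results.append((token_id, logit))
    return results

preds = collect_predictions(engine.module, tokenizer, DEVICE)
assert len(set(preds)) == 1, f"non-deterministic predictions: {preds}"  # fails with kernel injection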

ds_report output

(pytorch_p39) sh-4.2$ ds_report
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-devel package with yum
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  please install triton==1.0.0 if you want to use sparse attention
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/ec2-user/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/torch']
torch version .................... 1.13.1
deepspeed install path ........... ['/home/ec2-user/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/deepspeed']
deepspeed info ................... 0.9.1, unknown, unknown
torch cuda version ............... 11.7
torch hip version ................ None
nvcc version ..................... 11.7
deepspeed wheel compiled w. ...... torch 1.13, cuda 11.7


trianxy commented 1 year ago

While debugging the above, I stumbled upon this and this code line, which seem to be related. There, if the context length is NOT 1, self.layer_past is not reset and, possibly, some state from an earlier and unrelated model output will be used!

I am happy to debug this further but would need some pointers on why those lines were added. Maybe you, @awan-10, can help me debug this further? (I see that you may have worked around the above code lines in the past.)

 

Also, I am probably missing some insights:

Watching layer_past travel across several methods, I see that it ends up in DeepSpeedAttention.compute_attention(...), but it is not used inside that function. Perhaps the latter is overridden, but I don't know how or with what.

trianxy commented 1 year ago

Based on this issue, the non-deterministic behavior above happens because DeepSpeed assumes that, when you input a single token id, you are looking for a continuation of what was input before, and it uses its internal past-key/value cache for that.

Here is a more vivid example of this bug. Note how the model predicts the token 5, apparently assuming that you want the next token after the prompt 5 10 5 10, even though we never input that prompt or its past key values:

from typing import Optional, Any
import torch
import deepspeed
from transformers import AutoTokenizer, AutoModelForCausalLM
ARCHITECTURE = "gpt2"
model = AutoModelForCausalLM.from_pretrained(ARCHITECTURE).to("cuda").eval()
tokenizer = AutoTokenizer.from_pretrained(ARCHITECTURE, use_fast=True)
engine = deepspeed.init_inference(model, dtype=torch.float16, replace_with_kernel_inject=True)
model = engine.module

def print_next_token_after_prompt(prompt: str, pkv: Optional[tuple]) -> None:
    # Greedy next-token prediction for `prompt`, optionally passing past_key_values (pkv).
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    output = model(input_ids=inputs["input_ids"], past_key_values=pkv)
    token_id = torch.argmax(output.logits[0][-1])
    token = tokenizer.decode(token_id)
    print(f"{prompt=} -> {token=}")

print_next_token_after_prompt(prompt=" 5 10 5", pkv=None)  # prompt=' 5 10 5' -> token=' 10'
print_next_token_after_prompt(prompt=" 10", pkv=None)  # prompt=' 10' -> token=','

# Dummy "empty" past key values built from empty tensors, just so that pkv is not None.
empty_past_key_values = model.config.n_layer * ((torch.Tensor([[[]]]), torch.Tensor([[[]]])))
print_next_token_after_prompt(prompt=" 5 10 5", pkv=None)  # prompt=' 5 10 5' -> token=' 10'
print_next_token_after_prompt(prompt=" 10", pkv=empty_past_key_values)  # prompt=' 10' -> token=' 5'
awan-10 commented 12 months ago

@trianxy - Thanks for tagging me. I think the only person who can explain this will be @RezaYazdaniAminabadi :)