ndif-team / nnsight

The nnsight package enables interpreting and manipulating the internals of deep learned models.
https://nnsight.net/
MIT License

Dispatch Error When Using Quantisation #106

Open Xmaster6y opened 5 months ago

Xmaster6y commented 5 months ago

Description

I am seeing a dispatch error when using a 4-bit quantised model. First, note that this happens when instantiating a LanguageModel from a transformers model that has already been loaded in 4-bit. Also note that 4-bit weights can only live on the GPU; they cannot be moved to the CPU.
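Because 4-bit weights must stay on the GPU, it can help to check where a model's modules were actually placed. The sketch below uses a hypothetical helper, `offloaded_modules`, which is not part of nnsight or transformers. It scans a Hugging Face-style device map and lists the modules placed on `"cpu"` or `"disk"`. For a model loaded with `device_map`, transformers exposes this map as `model.hf_device_map`. The example map is made up for illustration, and treating only `"cpu"` and `"disk"` as non-GPU placements is an assumption.

```python
from typing import Mapping, Union

# A device-map entry: an int (GPU index) or a string such as "cuda:0", "cpu", "disk".
Device = Union[int, str]

# Placements assumed to be invalid for 4-bit weights (GPU-only constraint).
OFFLOAD_TARGETS = {"cpu", "disk"}


def offloaded_modules(device_map: Mapping[str, Device]) -> list:
    """Hypothetical helper: return module names mapped to a non-GPU placement.

    In practice you might pass `model.hf_device_map` from a model loaded
    with device_map="auto".
    """
    flagged = []
    for name, device in device_map.items():
        # Integer entries are CUDA device indices, i.e. GPU placements.
        if isinstance(device, int):
            continue
        # Strip an optional index suffix, e.g. "cuda:0" -> "cuda".
        if str(device).split(":")[0] in OFFLOAD_TARGETS:
            flagged.append(name)
    return flagged


# Made-up device map for illustration only.
example_map = {
    "transformer.wte": 0,
    "transformer.h.0": 0,
    "transformer.h.1": "cpu",
    "lm_head": "cuda:0",
}
print(offloaded_modules(example_map))  # ['transformer.h.1']
```

An empty result means every listed module sits on a GPU. Any name it returns points to a module that would violate the GPU-only constraint.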

Working Example

from nnsight import LanguageModel

# Let nnsight load GPT-2 itself, quantised to 4-bit
nnsight_model = LanguageModel("gpt2", device_map="auto", load_in_4bit=True)
with nnsight_model.trace('The Eiffel Tower is in the city of') as tracer:
    hidden_states = nnsight_model.transformer.h[0].mlp.act.output[0].clone().save()

Failing Example

from nnsight import LanguageModel
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto", load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
# Wrap the already-quantised model instead of letting nnsight load it
nnsight_model = LanguageModel(model, tokenizer=tokenizer)
with nnsight_model.trace('The Eiffel Tower is in the city of') as tracer:
    hidden_states = nnsight_model.transformer.h[0].mlp.act.output[0].clone().save()

Info

The Error

The full error is reproduced in this illustrative notebook: https://colab.research.google.com/drive/1n9A7MF8JE2lf26e9gOXRi2HaDjl4DjgX?usp=sharing

Xmaster6y commented 5 months ago

[Edit]

In fact, the first method only works for the first call. For example, the following code fails:

from nnsight import LanguageModel

nnsight_model = LanguageModel("gpt2", device_map="auto", load_in_4bit=True)
with nnsight_model.trace('The Eiffel Tower is in the city of') as tracer:
    hidden_states = nnsight_model.transformer.h[0].mlp.act.output[0].clone().save()
# Second trace on the same model: this is the call that fails
with nnsight_model.trace('The Eiffel Tower is in the city of') as tracer:
    hidden_states = nnsight_model.transformer.h[0].mlp.act.output[0].clone().save()
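When narrowing down a failure that only appears after the first call, a small harness that runs the same step several times and reports the first call that raises can make the behaviour explicit. This is a hypothetical sketch and not part of nnsight. The `flaky_trace` stub only simulates the behaviour described above (the first call succeeds, later ones raise). To test a real model, you would replace it with a closure that runs `with nnsight_model.trace(...)`.

```python
from typing import Callable, Optional


def first_failing_call(fn: Callable[[], object], attempts: int = 3) -> Optional[int]:
    """Call `fn` up to `attempts` times.

    Return the 0-based index of the first call that raises, or None if all succeed.
    """
    for i in range(attempts):
        try:
            fn()
        except Exception as exc:  # catch any error, e.g. a dispatch error
            print(f"call {i} failed: {type(exc).__name__}: {exc}")
            return i
    return None


# Stub simulating the reported behaviour: the first call works, later calls raise.
calls = {"n": 0}


def flaky_trace() -> None:
    calls["n"] += 1
    if calls["n"] > 1:
        raise RuntimeError("simulated dispatch error")


print(first_failing_call(flaky_trace))
```

Running it prints `call 1 failed: RuntimeError: simulated dispatch error` followed by `1`. This confirms that the second call, not the first, is the one that breaks.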