TransformerLensOrg / TransformerLens

A library for mechanistic interpretability of GPT-style language models
https://transformerlensorg.github.io/TransformerLens/
MIT License

[Bug Report] Load model to multiple devices #439

Open liuxin99 opened 11 months ago

liuxin99 commented 11 months ago

When I use the following code to load Llama 2 and generate:

model = HookedTransformer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", 
                                              hf_model=hf_model, 
                                              device="cuda", 
                                              n_devices=4, 
                                              move_to_device=True,
                                              fold_ln=False, 
                                              center_writing_weights=False, 
                                              center_unembed=False, 
                                              tokenizer=tokenizer)
model.generate("The capital of Germany is", max_new_tokens=20, temperature=0)

I got an error:

    mask_rotary_cos = self.rotary_cos[offset_position_ids, None, :]
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:1)

I found that when I load the model onto multiple devices, the attention mask is always on cuda:0, which raises the error above. So I made the following change in the forward function of the attention module:

   def forward(
        self,
        query_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
        key_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
        value_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
        past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None,
        additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]] = None,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        shortformer_pos_embed is only used if self.cfg.positional_embedding_type == "shortformer", else defaults to None and is irrelevant. See HookedTransformerConfig for more details
        past_kv_cache_entry is an optional entry of past keys and values for this layer, only relevant if generating text. Defaults to None
        additive_attention_mask is an optional mask to add to the attention weights. Defaults to None.
        attention_mask is the attention mask for padded tokens. Defaults to None.
        """
        # move the attention mask to the device that the attention block is on 
        if additive_attention_mask is not None and additive_attention_mask.device != self.rotary_sin.device:
            additive_attention_mask = additive_attention_mask.to(self.rotary_sin.device)
        if attention_mask is not None and attention_mask.device != self.rotary_sin.device:
            attention_mask = attention_mask.to(self.rotary_sin.device)
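        # ... rest of the original forward body unchanged ...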

and it works well.
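
For anyone hitting the same error, here is a minimal standalone sketch of the failure mode and of why that .to() call fixes it (not TransformerLens code; it assumes at least two visible GPUs):

    import torch

    # Stand-ins for self.rotary_cos and offset_position_ids: the lookup table lives on a
    # later block's device (cuda:1) while the position indices were built on cuda:0.
    rotary_cos = torch.randn(2048, 64, device="cuda:1")
    position_ids = torch.arange(16, device="cuda:0")

    try:
        rotary_cos[position_ids, None, :]
    except RuntimeError as err:
        # "indices should be either on cpu or on the same device as the indexed tensor"
        print(err)

    # The fix above amounts to moving the indices to the indexed tensor's device first:
    out = rotary_cos[position_ids.to(rotary_cos.device), None, :]
    print(out.shape)  # torch.Size([16, 1, 64])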

alan-cooney commented 11 months ago

Thanks - feel free to submit a PR for this!

gegallego commented 11 months ago

Hi, I'm also using n_devices > 1 and I've found this bug in many other parts of the code.

Some examples:

https://github.com/neelnanda-io/TransformerLens/blob/174209ea708fe3838ccf08b70f2f4f28e7397cb4/transformer_lens/components.py#L676

https://github.com/neelnanda-io/TransformerLens/blob/174209ea708fe3838ccf08b70f2f4f28e7397cb4/transformer_lens/components.py#L754-L755

https://github.com/neelnanda-io/TransformerLens/blob/174209ea708fe3838ccf08b70f2f4f28e7397cb4/transformer_lens/utils.py#L696-L698

https://github.com/neelnanda-io/TransformerLens/blob/174209ea708fe3838ccf08b70f2f4f28e7397cb4/transformer_lens/ActivationCache.py#L453

In general, it breaks every time it has to do an operation between two tensors that are stored on different GPUs.

It seems like a more structural issue than just fixing those cases with a .to(X.device). Or am I missing something?

Thanks!

neelnanda-io commented 11 months ago

It's a bit messy. In my opinion the crucial thing is that the model runs, so fixing bugs 1 and 2 seems important.

I'm in general kinda fine with some utilities in the library assuming things all live on one device, which is what happened with bugs 3 and 4. Bug 3 is in utils.test_prompt and should be easy to fix by e.g. moving the values to the CPU. Bug 4 is a messier problem, because it's trying to stack activations across devices. It would be easy to fix by adding something to move the activations to the same device, but that might give you out-of-memory errors. One option is to have a method on the cache that moves all activations to the same device (or the CPU, which should probably be the default).

Generally, multi-device is a thing I want the library's core to support, but I'm fine with some features breaking on it, since it seems costly to need to write everything to be robust to multi-device stuff. But if anything specific annoys you and is easy to fix, please send a PR!
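
For illustration, a rough sketch of what such a cache method could look like, assuming the cache exposes its underlying dict of activation tensors (a hypothetical helper, not an existing API here):

    import torch
    from typing import Dict, Union

    def move_activations_to(
        cache_dict: Dict[str, torch.Tensor],
        device: Union[str, torch.device] = "cpu",
    ) -> Dict[str, torch.Tensor]:
        # Consolidate every cached activation on one device (CPU by default), so that
        # cross-layer operations like stacking no longer mix devices.
        return {name: act.to(device) for name, act in cache_dict.items()}

    # Hypothetical usage: cache_dict = move_activations_to(cache.cache_dict), after which
    # torch.stack over entries is safe (at the cost of holding everything on one device).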


jiahai-feng commented 9 months ago

I've encountered the same issue. Instead of changing things in various places in the attention block, I added these four lines to HookedTransformer.forward (roughly L546):

                if attention_mask is not None:
                    attention_mask = attention_mask.to(
                        devices.get_device_for_block_index(i, self.cfg)
                    )

I personally think this is slightly more parsimonious.
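
For reference, a generic sketch of the pattern behind that change, with hypothetical names; in TransformerLens the per-block device would come from devices.get_device_for_block_index(i, cfg):

    import torch
    from typing import Callable, List, Optional

    def run_blocks_across_devices(
        blocks: List[torch.nn.Module],
        residual: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        device_for_block: Callable[[int], torch.device],
    ) -> torch.Tensor:
        # Before each block runs, move both the residual stream and the attention mask
        # to that block's device, so no cross-device op happens inside the block.
        for i, block in enumerate(blocks):
            device = device_for_block(i)
            residual = residual.to(device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)
            residual = block(residual, attention_mask=attention_mask)
        return residual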

msakarvadia commented 8 months ago

I wanted to check and see if any progress/PRs have been made on this issue? I am running into this error as well.

alan-cooney commented 8 months ago

Not yet I'm afraid.

There's a task involved here to remove most of the manual device setting throughout the codebase (e.g. .to(device=...) and tensor([], device=...)), as torch handles most of this by default with e.g. model.to(). Then getting multi-GPU support to work everywhere should be very easy.

I'll peel this off if no-one else does over the next few weeks - it's also a bottleneck for distributed SAE training.
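
As an illustration of that direction, a toy module (not TransformerLens code) where the constant tensors are registered as buffers rather than created with an explicit device= argument, so a single .to(...) moves everything together:

    import torch
    import torch.nn as nn

    class RotaryTable(nn.Module):
        def __init__(self, n_ctx: int = 2048, rotary_dim: int = 64):
            super().__init__()
            angles = torch.outer(torch.arange(n_ctx).float(), torch.ones(rotary_dim))
            # Buffers follow the module: no hard-coded device= anywhere.
            self.register_buffer("rotary_cos", torch.cos(angles))
            self.register_buffer("rotary_sin", torch.sin(angles))

        def forward(self, position_ids: torch.Tensor) -> torch.Tensor:
            # Indices are moved to wherever this module (and its buffers) currently live.
            return self.rotary_cos[position_ids.to(self.rotary_cos.device)]

    # layer = RotaryTable().to("cuda:1")  # buffers land on cuda:1 with no manual bookkeeping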

coolvision commented 7 months ago

There's a task involved here to remove most of the manual device setting throughout the codebase (e.g. .to(device=...) and tensor([], device=...)), as torch handles most of this by default with e.g. model.to(). Then getting multi-GPU support to work everywhere should be very easy.

In this case, as I understand it, it would not support distributing a model across several GPUs?

I had similar problems when I tried to load Llama-13B on two GTX 1080s. Loading worked with quantization (#486), but then I got similar issues to those described above: tensors on different devices.

I made some fixes to move the tensors to the same device: https://github.com/coolvision/TransformerLens/commit/5b250b30abcbf1c4f1d482759082d897e2ef2843

With these fixes, inference does work, with the model distributed across 2 GPUs:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer

model_name = "meta-llama/Llama-2-13b-chat-hf"
inference_dtype = torch.float16  # assumption: the original snippet doesn't show how this was set

model = AutoModelForCausalLM.from_pretrained(model_name,
                                             torch_dtype=inference_dtype,
                                             device_map="auto",
                                             load_in_4bit=True)

tokenizer = AutoTokenizer.from_pretrained(model_name)

model_tl = HookedTransformer.from_pretrained(model_name,
                                             hf_model=model,
                                             dtype=inference_dtype,
                                             fold_ln=False,
                                             n_devices=2,
                                             fold_value_biases=False,
                                             center_writing_weights=False,
                                             center_unembed=False,
                                             tokenizer=tokenizer)

model_tl.generate("The capital of Germany is", max_new_tokens=2, temperature=0)

>'The capital of Germany is Berlin.'

|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0 Off |                  N/A |
| 29%   44C    P8    10W / 250W |   6136MiB / 11264MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:02:00.0 Off |                  N/A |
| 25%   41C    P8     9W / 250W |   5312MiB / 11264MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

I can add these fixes to my quantization PR #486, or make a separate PR.