BlackSamorez / tensor_parallel

Automatically split your PyTorch models on multiple GPUs for training & inference
MIT License

Error in README.md, hence not able to load model with limited memory. #77

Closed vishakudupa closed 1 year ago

vishakudupa commented 1 year ago

Hello,

I'm trying to load MPT-7B with a limited amount of memory, following the instructions in the README, but it throws an error because the snippet passes a state_dict to tp.tensor_parallel, which expects an nn.Module instead. Please provide an example that works, or suggest what I need to change for it to work.

# Load partial state_dict for MyModel
state_dict = torch.load("my_model_part_1_of_5.bin")

# Convert it into a tensor_parallel state_dict
tensor_parallel_state_dict = tp.tensor_parallel(
    state_dict,
    tensor_parallel_config=model.tensor_parallel_config,
    world_size=len(model.devices),
)

-- Thank You

BlackSamorez commented 1 year ago

@vishakudupa Hi! Yeah, the README is plain wrong. I've put together a minimal example of how to dispatch on Colab, where there's enough VRAM but not enough RAM. An extract from it:

from huggingface_hub import hf_hub_download
import accelerate
import torch

import tensor_parallel as tp

# `model` is assumed to be a tensor_parallel model whose shards still hold
# empty (meta) weights; see the setup sketch at the end of this comment
device_map = tp.infer_sharded_device_map(model)

for i in range(1, 34):
    # Load a shard
    state_dict = torch.load(hf_hub_download(repo_id="decapoda-research/llama-7b-hf", filename=f"pytorch_model-{i:05d}-of-00033.bin"))

    # Convert a shard
    state_dict = tp.convert_state_dict(
        state_dict,
        tensor_parallel_config=model.tensor_parallel_config,
        world_size=len(model.devices),
        for_pretrained=True,
    )

    # Load a converted shard into model
    #  load_state_dict doesn't properly work with meta: https://discuss.pytorch.org/t/discrepancy-between-loading-models-with-meta-tensors-and-normal-load-from-state-dict/168295
    #  This particular code is borrowed from 'accelerate': https://github.com/huggingface/accelerate/blob/0226f750257b3bf2cadc4f189f9eef0c764a0467/src/accelerate/utils/modeling.py#LL1002C14-L1002C14
    for param_name, param in state_dict.items():
        module_name = param_name

        while len(module_name) > 0 and module_name not in device_map:
            module_name = ".".join(module_name.split(".")[:-1])
        param_device = device_map[module_name]

        accelerate.utils.set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch.float16)

I'll be putting something similar into the README soon.
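For completeness: the snippet above assumes `model` was already created with empty (meta) weights and wrapped with tensor_parallel before inferring the device map, roughly like this (the checkpoint name and device list are placeholders; the dtype is applied per-tensor by set_module_tensor_to_device in the loop above):

import accelerate
import transformers
import tensor_parallel as tp

model_name = "decapoda-research/llama-7b-hf"  # placeholder checkpoint

# Build the model skeleton without allocating real weights (meta device),
# so constructing it never needs the full checkpoint in RAM
config = transformers.AutoConfig.from_pretrained(model_name)
with accelerate.init_empty_weights():
    model = transformers.AutoModelForCausalLM.from_config(config)

# Shard the (still empty) model across the available GPUs
model = tp.tensor_parallel(model, ["cuda:0", "cuda:1"])

Once the conversion loop has populated every shard, the model can be used as usual, e.g. model.generate(...).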

vishakudupa commented 1 year ago

@BlackSamorez Thank You, will try it out.

vishakudupa commented 1 year ago

Hi, I'm trying to run MPT-7B-Instruct using tensor_parallel, but I'm facing the issue below. It looks like there is a mismatch between the embedding size of the model and the last input_dim of the MLP. Can you please take a look at it?

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_23/2317293603.py in <module>
---> 10 model.generate(**tok, min_new_tokens=1024, max_new_tokens=1024)

/opt/conda/lib/python3.7/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     25         def decorate_context(*args, **kwargs):
     26             with self.clone():
---> 27                 return func(*args, **kwargs)
     28         return cast(F, decorate_context)
     29 

/opt/conda/lib/python3.7/site-packages/transformers/generation/utils.py in generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, **kwargs)
   1526                 synced_gpus=synced_gpus,
   1527                 streamer=streamer,
-> 1528                 **model_kwargs,
   1529             )
   1530 

/opt/conda/lib/python3.7/site-packages/transformers/generation/utils.py in greedy_search(self, input_ids, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, streamer, **model_kwargs)
   2337                 return_dict=True,
   2338                 output_attentions=output_attentions,
-> 2339                 output_hidden_states=output_hidden_states,
   2340             )
   2341 

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1188         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190             return forward_call(*input, **kwargs)
   1191         # Do not call functions when jit is used
   1192         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.7/site-packages/tensor_parallel/pretrained_model.py in forward(self, *args, **kwargs)
     76 
     77     def forward(self, *args, **kwargs):
---> 78         return self.wrapped_model(*args, **kwargs)
     79 
     80     def state_dict(self, *args, **kwargs):

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1188         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190             return forward_call(*input, **kwargs)
   1191         # Do not call functions when jit is used
   1192         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.7/site-packages/tensor_parallel/tensor_parallel.py in forward(self, *args, **kwargs)
    128         inputs, kwargs_tup = self.prepare_args_kwargs_for_forward(*args, **kwargs)
    129         if self.all_cuda and not TENSOR_PARALLEL_USE_NATIVE:
--> 130             return parallel_apply(self.module_shards, inputs, kwargs_tup, self.devices)[self.output_device_index]
    131         else:
    132             return parallel_apply_simple(self.module_shards, inputs, kwargs_tup, self.devices)[self.output_device_index]

/opt/conda/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py in parallel_apply(modules, inputs, kwargs_tup, devices)
     87         output = results[i]
     88         if isinstance(output, ExceptionWrapper):
---> 89             output.reraise()
     90         outputs.append(output)
     91     return outputs

/opt/conda/lib/python3.7/site-packages/torch/_utils.py in reraise(self)
    541             # instantiate since we don't know how to
    542             raise RuntimeError(msg) from None
--> 543         raise exception
    544 
    545 

RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b-instruct/665b2900b1ceabbf2723580f03f659f70fcba26b/modeling_mpt.py", line 238, in forward
    logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (480x4096 and 2048x50432)

BlackSamorez commented 1 year ago

@vishakudupa Yeah, I can reproduce this issue. The problem is that the MPT modeling code computes its output logits with plain torch.nn.functional.linear on the tied embedding weight, which can't be caught by the tensor_parallel autoconfig since it only handles proper torch.nn.Linear modules. A custom config is needed to run this model. You can see how it's done in slicing_configs.py. If you're willing to try and create a config for this model, I'm eager to help!
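For illustration, here's a rough plain-PyTorch reconstruction of what each shard ends up computing (a sketch, not tensor_parallel internals; the shapes are taken from the traceback above and assume two GPUs, d_model=4096 and vocab_size=50432):

import torch
import torch.nn.functional as F

d_model, vocab_size, world_size = 4096, 50432, 2

# Full tied embedding weight as MPT defines it: (vocab_size, d_model)
wte_weight = torch.randn(vocab_size, d_model)

# Judging by the traceback, the wte weight on each replica only has
# d_model // world_size columns, i.e. it was split along the embedding dim
wte_shard = wte_weight.chunk(world_size, dim=1)[0]  # (50432, 2048)

# The hidden states reaching the output projection are still full-width
hidden_states = torch.randn(480, d_model)  # (480, 4096)

# MPT then does: logits = F.linear(hidden_states, wte_shard),
# i.e. hidden_states @ wte_shard.T = (480, 4096) @ (2048, 50432)
try:
    F.linear(hidden_states, wte_shard)
except RuntimeError as e:
    print(e)  # mat1 and mat2 shapes cannot be multiplied (480x4096 and 2048x50432)

A custom config would have to add rules for the wte weight and the model output so that this multiplication lines up again, which is exactly the kind of per-model rule that slicing_configs.py encodes.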

vishakudupa commented 1 year ago

@BlackSamorez Thanks for the confirmation. I'll try to create a config for the MPT model. I have one question: will I be able to train using https://github.com/huggingface/peft after I get it to work? My end goal is to fine-tune it for a downstream application.

-- Edit: I can see add_lora_rules, so it should work with PEFT as well. Can you please tell me how I can start with the config? Where should I look to write the rules?