@vishakudupa Hi! Yeah, the readme is plain wrong. Here I've compiled a minimal example of how to dispatch on Colab when there's enough VRAM but not enough RAM. An extract from it:
```python
import accelerate
import torch
import tensor_parallel as tp
from huggingface_hub import hf_hub_download

# `model` is assumed to already be a tp.tensor_parallel model built on meta tensors
# (see the setup sketch below).
device_map = tp.infer_sharded_device_map(model)

for i in range(1, 34):
    # Load a shard
    state_dict = torch.load(
        hf_hub_download(
            repo_id="decapoda-research/llama-7b-hf",
            filename=f"pytorch_model-{i:05d}-of-00033.bin",
        )
    )

    # Convert a shard
    state_dict = tp.convert_state_dict(
        state_dict,
        tensor_parallel_config=model.tensor_parallel_config,
        world_size=len(model.devices),
        for_pretrained=True,
    )

    # Load a converted shard into the model.
    # load_state_dict doesn't properly work with meta tensors:
    # https://discuss.pytorch.org/t/discrepancy-between-loading-models-with-meta-tensors-and-normal-load-from-state-dict/168295
    # This particular code is borrowed from 'accelerate':
    # https://github.com/huggingface/accelerate/blob/0226f750257b3bf2cadc4f189f9eef0c764a0467/src/accelerate/utils/modeling.py#LL1002C14-L1002C14
    for param_name, param in state_dict.items():
        module_name = param_name
        while len(module_name) > 0 and module_name not in device_map:
            module_name = ".".join(module_name.split(".")[:-1])
        param_device = device_map[module_name]
        accelerate.utils.set_module_tensor_to_device(
            model, param_name, param_device, value=param, dtype=torch.float16
        )
```
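For completeness, the extract above assumes `model` has already been created and sharded before any real weights exist. The following is my reconstruction of that setup, not part of the original example; the model class, the device list, and the use of `accelerate.init_empty_weights` are assumptions you may need to adapt:

```python
import accelerate
import tensor_parallel as tp
import transformers

# Build the model skeleton on meta tensors so the full checkpoint
# never has to fit into CPU RAM at once.
with accelerate.init_empty_weights():
    model = transformers.AutoModelForCausalLM.from_config(
        transformers.AutoConfig.from_pretrained("decapoda-research/llama-7b-hf")
    ).half()

# Shard the (still empty) model across the available GPUs; the real weights
# are then streamed in shard by shard by the loop above.
model = tp.tensor_parallel(model, ["cuda:0", "cuda:1"])
```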
I'll be putting something similar into the readme soon.
@BlackSamorez Thank You, will try it out.
Hi, I'm trying to run MPT-7B-Instruct using tensor_parallel, but I'm facing the issue below. It looks like there is a mismatch between the embedding size of the model and the last input dim of the MLP. Can you please take a look at it?
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
/tmp/ipykernel_23/2317293603.py in <module>
---> 10 model.generate(**tok, min_new_tokens=1024, max_new_tokens=1024)
/opt/conda/lib/python3.7/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
25 def decorate_context(*args, **kwargs):
26 with self.clone():
---> 27 return func(*args, **kwargs)
28 return cast(F, decorate_context)
29
/opt/conda/lib/python3.7/site-packages/transformers/generation/utils.py in generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, **kwargs)
1526 synced_gpus=synced_gpus,
1527 streamer=streamer,
-> 1528 **model_kwargs,
1529 )
1530
/opt/conda/lib/python3.7/site-packages/transformers/generation/utils.py in greedy_search(self, input_ids, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, streamer, **model_kwargs)
2337 return_dict=True,
2338 output_attentions=output_attentions,
-> 2339 output_hidden_states=output_hidden_states,
2340 )
2341
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
/opt/conda/lib/python3.7/site-packages/tensor_parallel/pretrained_model.py in forward(self, *args, **kwargs)
76
77 def forward(self, *args, **kwargs):
---> 78 return self.wrapped_model(*args, **kwargs)
79
80 def state_dict(self, *args, **kwargs):
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
/opt/conda/lib/python3.7/site-packages/tensor_parallel/tensor_parallel.py in forward(self, *args, **kwargs)
128 inputs, kwargs_tup = self.prepare_args_kwargs_for_forward(*args, **kwargs)
129 if self.all_cuda and not TENSOR_PARALLEL_USE_NATIVE:
--> 130 return parallel_apply(self.module_shards, inputs, kwargs_tup, self.devices)[self.output_device_index]
131 else:
132 return parallel_apply_simple(self.module_shards, inputs, kwargs_tup, self.devices)[self.output_device_index]
/opt/conda/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py in parallel_apply(modules, inputs, kwargs_tup, devices)
87 output = results[i]
88 if isinstance(output, ExceptionWrapper):
---> 89 output.reraise()
90 outputs.append(output)
91 return outputs
/opt/conda/lib/python3.7/site-packages/torch/_utils.py in reraise(self)
541 # instantiate since we don't know how to
542 raise RuntimeError(msg) from None
--> 543 raise exception
544
545
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
File "/opt/conda/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
output = module(*input, **kwargs)
File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b-instruct/665b2900b1ceabbf2723580f03f659f70fcba26b/modeling_mpt.py", line 238, in forward
logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (480x4096 and 2048x50432)
@vishakudupa Yeah, I can reproduce this issue.
The problem is that MPT's modeling code computes its logits with the functional torch.nn.functional.linear (reusing the wte embedding weights directly, as seen in the last frame of the traceback), which can't be caught by the tensor_parallel autoconfig: it only handles proper torch.nn.Linear modules in that regard.
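To illustrate why the autoconfig misses it (a minimal sketch of my own, not code from either library): anything that walks the module tree can find and shard an nn.Linear submodule, but a functional call buried inside forward() is invisible to that traversal. MPT computes its logits the way the first class below does.

```python
import torch.nn as nn
import torch.nn.functional as F


class TiedHead(nn.Module):
    """Mimics MPT's pattern: logits via F.linear against the embedding weight."""

    def __init__(self, vocab_size=512, d_model=64):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, d_model)

    def forward(self, hidden_states):
        # Functional call: there is no nn.Linear module here for an autoconfig to discover.
        return F.linear(hidden_states, self.wte.weight)


class ModuleHead(nn.Module):
    """Same math, expressed as a proper nn.Linear module."""

    def __init__(self, vocab_size=512, d_model=64):
        super().__init__()
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, hidden_states):
        return self.lm_head(hidden_states)


# Walking named_modules() is roughly what automatic sharding can rely on:
print([n for n, m in TiedHead().named_modules() if isinstance(m, nn.Linear)])    # []
print([n for n, m in ModuleHead().named_modules() if isinstance(m, nn.Linear)])  # ['lm_head']
```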
A custom config is needed to run this. You can see how it's done in slicing_configs.py. If you're willing to try and create a config for this model, I'm eager to help!
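For anyone who wants to pick this up, here is very roughly the shape such a config takes, judging by the existing entries in slicing_configs.py. Treat it strictly as a sketch: the regex patterns, the MPT parameter names, the action strings, and the keyword for passing a custom config are all my assumptions and need to be checked against the real slicing_configs.py and against modeling_mpt.py.

```python
import tensor_parallel as tp

# NOT a working MPT config -- a sketch of the structure only; all names/actions are assumptions.
mpt_config = tp.Config(
    state_rules={
        # Attention input projection: split output features across shards (assumed name).
        r".*attn\.Wqkv\.weight$": "split 0",
        # Attention output projection: split input features (assumed name).
        r".*attn\.out_proj\.weight$": "split 1",
        # MLP: split the hidden dimension (assumed names).
        r".*ffn\.up_proj\.weight$": "split 0",
        r".*ffn\.down_proj\.weight$": "split 1",
        # Embedding: split the embedding dimension; the tricky part is that
        # F.linear reuses this same split weight for the logits.
        r".*wte\.weight$": "split 1",
    },
    input_rules={},
    output_rules={
        # Shards holding split input features produce partial sums that must be reduced,
        r".*attn$": {0: "sum"},
        r".*ffn$": {0: "sum"},
        # and split embedding / logit outputs must be gathered back along the last dim.
        # Handling the F.linear logits is the part that needs a hand-written rule.
        r".*wte$": {0: "gather -1"},
        r".*transformer$": {0: "gather -1"},
    },
    attr_rules={},
)

# Keyword name is an assumption; check tp.tensor_parallel's signature.
model = tp.tensor_parallel(model, ["cuda:0", "cuda:1"], tensor_parallel_config=mpt_config)
```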
@BlackSamorez Thanks for the confirmation. I'll try to create a config for the MPT model. I have one question: will I be able to train using https://github.com/huggingface/peft after I get it to work? Because my end goal is to fine-tune for a downstream application.
-- Edit: I can see add_lora_rules, so it should work with PEFT as well.
Can you please tell me how I can get started with the config? Where should I look to write the rules?
Hello,
I'm trying to load MPT-7B with a limited amount of memory, following the instructions in the README, but it throws an error because the state_dict is being loaded instead of the nn.Module that tp.tensor_parallel expects. Please provide an example that works, or suggest what I need to change to make it work.
-- Thank You