facebookresearch / metaseq

Repo for external large-scale work
MIT License
6.51k stars 725 forks

Error when run reshard_fsdp on opt-IML 30b for inference #567

Open zhanghaoie opened 1 year ago

zhanghaoie commented 1 year ago

Traceback (most recent call last):
  File "/opt/conda/envs/alpa/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/envs/alpa/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/build/metaseq/metaseq/scripts/reshard_fsdp.py", line 245, in <module>
    fire.Fire(reshard_fsdp_checkpoints)
  File "/opt/conda/envs/alpa/lib/python3.8/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/opt/conda/envs/alpa/lib/python3.8/site-packages/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/opt/conda/envs/alpa/lib/python3.8/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/build/metaseq/metaseq/scripts/reshard_fsdp.py", line 53, in reshard_fsdp_checkpoints
    resharded_state_dicts = reshard_fsdp_state_dicts(
  File "/build/metaseq/metaseq/scripts/reshard_fsdp.py", line 79, in reshard_fsdp_state_dicts
    shard_metadata=[s["shard_metadata"] for s in shard_state_dicts],
  File "/build/metaseq/metaseq/scripts/reshard_fsdp.py", line 79, in <listcomp>
    shard_metadata=[s["shard_metadata"] for s in shard_state_dicts],
KeyError: 'shard_metadata'

Charles-ux-bit commented 1 year ago

Can you show me the code please? I'm also wondering how to run the OPT-IML 30B code for inference.

sriniiyer commented 1 year ago

Could you share the motivation and the command? Thanks!

Charles-ux-bit commented 1 year ago

Thanks for your reply! I want to use the OPT-IML 30B model in my Python 3.7 environment, but I only see the model checkpoint. Without the inference code, I don't know how to use the checkpoint. Could you share a code example with me, preferably using the huggingface library? Thanks!

Charles-ux-bit commented 1 year ago

By the way, I successfully ran the OPT 30B inference code from huggingface.
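For reference, a minimal Hugging Face inference run of the kind described might look like the sketch below. It uses the small facebook/opt-125m checkpoint as a stand-in; swapping in "facebook/opt-30b" works the same way given enough memory (the prompt and generation settings are illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Stand-in checkpoint; replace with "facebook/opt-30b" on a machine
# with enough RAM/GPU memory.
model_name = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Greedy decoding of a short continuation.
inputs = tokenizer("The capital of France is", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=10, do_sample=False)
text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(text)  # the decoded text starts with the prompt
```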

sriniiyer commented 1 year ago

@Charles-ux-bit We are working on uploading it to huggingface and it will be available soon, hopefully in a few days. An alternative that works right now is to use the metaseq repo for inference - https://github.com/facebookresearch/metaseq

Charles-ux-bit commented 1 year ago

But I didn't see an explicit way to run inference in metaseq. Which way should I choose: Alpa or Colossal-AI? Is there an example? Thanks!

sahajgg commented 1 year ago

@Charles-ux-bit Assuming the 30B-IML checkpoint has the same architecture as the 30B checkpoint, a hack that works is to copy "shard_metadata" from the latter and then use the reshard_fsdp script
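That metadata-copy workaround can be sketched as follows. Plain dicts and pickle stand in for the real torch checkpoints here; with the actual `.pt` files you would use `torch.load(..., map_location="cpu")` and `torch.save`, and the paths and dict contents below are illustrative only:

```python
import os
import pickle
import tempfile

# Stand-ins for the two shards: the OPT-30B shard carries the
# "shard_metadata" key that reshard_fsdp expects, while the
# OPT-IML-30B shard lacks it.
opt_shard = {"model": {"decoder.embed_tokens.weight": "..."},
             "shard_metadata": {"param_metadata": [], "buffer_names": []}}
iml_shard = {"model": {"decoder.embed_tokens.weight": "..."}}

assert "shard_metadata" not in iml_shard  # the cause of the KeyError

# Graft the metadata from the architecturally identical OPT-30B shard.
iml_shard["shard_metadata"] = opt_shard["shard_metadata"]

# Round-trip through disk the way you would with torch.save/torch.load.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "patched-model_part-0.pt")
    with open(path, "wb") as f:
        pickle.dump(iml_shard, f)
    with open(path, "rb") as f:
        patched = pickle.load(f)

print("shard_metadata" in patched)  # True; reshard_fsdp can now proceed
```

The patched file can then be fed to the reshard_fsdp script in place of the original shard.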

xiangjjj commented 1 year ago

@Charles-ux-bit Assuming the 30B-IML checkpoint has the same architecture as the 30B checkpoint, a hack that works is to copy "shard_metadata" from the latter and then use the reshard_fsdp script

The OPT-30B release here includes only the resharded checkpoints, not the original ones. Do you think there is an easy way to carry out your hack?

xiangjjj commented 1 year ago

@sriniiyer, could you fix the issue in the checkpoint or upload a resharded version to unblock us from running inference on OPT-IML?

sriniiyer commented 1 year ago

@xiangjjj I need the exact command you folks are trying to run - and what you are trying to achieve - in order to unblock you.

xiangjjj commented 1 year ago

@sriniiyer My goal is to set up metaseq on my system to run inference with OPT-IML-30B and OPT-IML-175B. I'm able to run inference with OPT-30B using the existing resharded checkpoints here. Because OPT-IML-30B does not have the resharded files, I'm trying to reshard it with the following script:

CHECKPOINT_PATH=/home/ubuntu/llm/OPT-IML/OPT-IML-30B
for j in {0..1}; do
    python -m metaseq.scripts.reshard_fsdp \
    --input-glob-pattern "$CHECKPOINT_PATH/checkpoint_1_4000.pt-model_part-$j.pt" \
    --output-shard-name "$CHECKPOINT_PATH/reshard-model_part-$j.pt" \
    --num-output-shards 1 --skip-optimizer-state True --unflatten-weights True
done

I got the following error message with KeyError: 'shard_metadata'.

2023-01-23 13:09:33,645 | metaseq.scripts.reshard_fsdp | Found 1 sharded checkpoints (/home/ubuntu/llm/OPT-IML/OPT-IML-30B/checkpoint_1_4000.pt-model_part-0.pt to /home/ubuntu/llm/OPT-IML/OPT-IML-30B/checkpoint_1_4000.pt-model_part-0.pt)
2023-01-23 13:09:33,645 | metaseq.scripts.reshard_fsdp | Loading all sharded checkpoints to CPU
2023-01-23 13:09:50,877 | metaseq.scripts.reshard_fsdp | Resharding state dicts into 1 shard(s)
Traceback (most recent call last):
  File "/opt/conda/envs/py3.8/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/envs/py3.8/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/ubuntu/metaseq/metaseq/scripts/reshard_fsdp.py", line 242, in <module>
    fire.Fire(reshard_fsdp_checkpoints)
  File "/opt/conda/envs/py3.8/lib/python3.8/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/opt/conda/envs/py3.8/lib/python3.8/site-packages/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/opt/conda/envs/py3.8/lib/python3.8/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/home/ubuntu/metaseq/metaseq/scripts/reshard_fsdp.py", line 53, in reshard_fsdp_checkpoints
    resharded_state_dicts = reshard_fsdp_state_dicts(
  File "/home/ubuntu/metaseq/metaseq/scripts/reshard_fsdp.py", line 76, in reshard_fsdp_state_dicts
    shard_metadata=[s["shard_metadata"] for s in shard_state_dicts],
  File "/home/ubuntu/metaseq/metaseq/scripts/reshard_fsdp.py", line 76, in <listcomp>
    shard_metadata=[s["shard_metadata"] for s in shard_state_dicts],
KeyError: 'shard_metadata'
2023-01-23 13:09:54,481 | metaseq.scripts.reshard_fsdp | Found 1 sharded checkpoints (/home/ubuntu/llm/OPT-IML/OPT-IML-30B/checkpoint_1_4000.pt-model_part-1.pt to /home/ubuntu/llm/OPT-IML/OPT-IML-30B/checkpoint_1_4000.pt-model_part-1.pt)
2023-01-23 13:09:54,481 | metaseq.scripts.reshard_fsdp | Loading all sharded checkpoints to CPU
2023-01-23 13:10:11,729 | metaseq.scripts.reshard_fsdp | Resharding state dicts into 1 shard(s)
Traceback (most recent call last):
  File "/opt/conda/envs/py3.8/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/envs/py3.8/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/ubuntu/metaseq/metaseq/scripts/reshard_fsdp.py", line 242, in <module>
    fire.Fire(reshard_fsdp_checkpoints)
  File "/opt/conda/envs/py3.8/lib/python3.8/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/opt/conda/envs/py3.8/lib/python3.8/site-packages/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/opt/conda/envs/py3.8/lib/python3.8/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/home/ubuntu/metaseq/metaseq/scripts/reshard_fsdp.py", line 53, in reshard_fsdp_checkpoints
    resharded_state_dicts = reshard_fsdp_state_dicts(
  File "/home/ubuntu/metaseq/metaseq/scripts/reshard_fsdp.py", line 76, in reshard_fsdp_state_dicts
    shard_metadata=[s["shard_metadata"] for s in shard_state_dicts],
  File "/home/ubuntu/metaseq/metaseq/scripts/reshard_fsdp.py", line 76, in <listcomp>
    shard_metadata=[s["shard_metadata"] for s in shard_state_dicts],
KeyError: 'shard_metadata'

It appears the OPT-IML checkpoints are incomplete.
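A quick way to check this is to list the top-level keys of a shard before resharding. The sketch below simulates that with a pickled dict standing in for the `.pt` file (the filename and keys are illustrative; with a real metaseq checkpoint you would load it via `torch.load("checkpoint_1_4000.pt-model_part-0.pt", map_location="cpu")`):

```python
import os
import pickle
import tempfile

# Fabricate a shard that, like the OPT-IML files, has no "shard_metadata".
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_part-0.pt")
    with open(path, "wb") as f:
        pickle.dump({"model": {}, "cfg": {}}, f)
    with open(path, "rb") as f:
        ckpt = pickle.load(f)

print(sorted(ckpt.keys()))       # ['cfg', 'model']
print("shard_metadata" in ckpt)  # False: reshard_fsdp will raise KeyError
```

If "shard_metadata" is missing from the printed keys, the shard cannot be resharded as-is.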

Please let me know if my understanding is correct and how to resolve this issue.

Thank you so much!

tangbinh commented 1 year ago

@zhanghaoie @xiangjjj The OPT-IML checkpoints have already been consolidated, so there's no need to run reshard_fsdp. You can directly load these checkpoints, for example, using some interactive scripts.