facebookresearch / metaseq

Repo for external large-scale work

Deprecate convert_to_singleton #691

Open andrewPoulton opened 1 year ago

andrewPoulton commented 1 year ago

As noted in #689, convert_to_singleton doesn't produce state dicts with compatible keys (for reasons that are still unclear).

Since reshard_mp can do the same job, without the GPU node requirement of convert_to_singleton, we should deprecate convert_to_singleton.

TODO: Work out dependencies on convert_to_singleton, and identify any special cases it can handle that reshard_mp can't (such as separating out qkv weights, as noted by @tangbinh).
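
For illustration only, a rough sketch of the kind of qkv splitting involved. The fused parameter name qkv_proj appears in the state dict traces later in this thread; the ordering of the three projections inside the fused tensor is an assumption here and would need to be verified against metaseq's attention module:

import torch

def split_qkv(state_dict, embed_dim):
    """Split fused qkv_proj weights/biases into separate q/k/v entries.

    Assumes the fused tensor stacks the three projections along dim 0;
    the q/k/v ordering below is illustrative only.
    """
    out = {}
    for key, value in state_dict.items():
        if "qkv_proj" in key:
            q, k, v = torch.split(value, embed_dim, dim=0)
            out[key.replace("qkv_proj", "q_proj")] = q
            out[key.replace("qkv_proj", "k_proj")] = k
            out[key.replace("qkv_proj", "v_proj")] = v
        else:
            out[key] = value
    return out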

larekrow commented 1 year ago

reshard_mp.py --num-output-parts 1 currently does not work with the OPT weights. Please see #695.

ayeeyecorp commented 1 year ago

Since reshard_mp can do the same job, without the GPU node requirement of convert_to_singleton, we should deprecate convert_to_singleton.

@andrewPoulton Is this true though? I was unable to convert 8 shards successfully to restored.pt using:

python -m metaseq.scripts.reshard_mp \
--input "opt/shards/reshard-model_part-*.pt" \
--output "opt/pt/reshard_no_os_mp8/reshard-model_part-{i}.pt" \
--num-output-parts 1

tangbinh commented 1 year ago

I was unable to convert 8 shards successfully to restored.pt using:

@ayeeyecorp Can you share the stack trace? I suspect it might be related to the fact that the checkpoints available on the OPT page are flattened, which are not compatible with reshard_mp.

andrewPoulton commented 1 year ago

@tangbinh let's add a flat param check to reshard_*, and raise an error unless user specifically wants to unflatten. I'll create an issue to track in a bit. Happy to own as well.
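
A rough sketch of what such a check could look like (the flat_param key pattern matches the FSDP-wrapped keys visible in the trace later in this thread; the flag name here is hypothetical):

def check_flat_params(state_dict, unflatten_requested=False):
    """Raise unless the caller explicitly opted in to handling flattened weights.

    The "flat_param" pattern matches keys such as
    decoder.layers.0._fsdp_wrapped_module.flat_param_0;
    unflatten_requested is a hypothetical flag name.
    """
    flat_keys = [k for k in state_dict if "flat_param" in k]
    if flat_keys and not unflatten_requested:
        raise ValueError(
            f"Checkpoint contains {len(flat_keys)} flattened FSDP parameters "
            f"(e.g. {flat_keys[0]}). Unflatten them first (reshard_fsdp with "
            f"--unflatten-weights True) or opt in to unflattening explicitly."
        )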

tangbinh commented 1 year ago

@andrewPoulton I was adding an option to split the KVQ weights in reshard_mp, but I think this is probably not needed for 2 reasons:

  1. This weight splitting has already been included in the script convert_opt_original_pytorch_checkpoint_to_pytorch.py. Previously, there was a bug that basically turned this off, but it has been fixed (see https://github.com/huggingface/transformers/pull/22526).
  2. Previously, we supported both transformer_lm and transformer_lm_megatron models, but the former has been removed in #633. Therefore, there's no need to split KVQ weights within Metaseq.

Once #625 is fixed, I think we can safely remove convert_to_singleton.py, as users will be able to load OPT checkpoints using reshard_mp.py, convert_opt_original_pytorch_checkpoint_to_pytorch.py, and Hugging Face Transformers.
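
As a usage sketch of the end of that pipeline (the checkpoint folder path is a placeholder), the directory written by convert_opt_original_pytorch_checkpoint_to_pytorch.py can then be loaded directly with Transformers:

import torch
from transformers import AutoModelForCausalLM

# Placeholder path: folder produced by convert_opt_original_pytorch_checkpoint_to_pytorch.py
model = AutoModelForCausalLM.from_pretrained(
    "/path/to/opt_hf_checkpoint", torch_dtype=torch.float16
)
print(sum(p.numel() for p in model.parameters()))  # parameter count sanity check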

ayeeyecorp commented 1 year ago

I was unable to convert 8 shards successfully to restored.pt using:

@ayeeyecorp Can you share the stack trace? I suspect it might be related to the fact that the checkpoints available on the OPT page are flattened, which are not compatible with reshard_mp.

@andrewPoulton - I did not save the stack trace from that particular test (I can redo it). However, here is the tail end of the stack trace from running ./metaseq/metaseq/scripts/convert_to_singleton on the OPT-175B checkpoints:

raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for FlattenParamsWrapper: Missing key(s) in state_dict: "_fpw_module.decoder.layers.0._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.1._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.2._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.3._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.4._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.5._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.6._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.7._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.8._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.9._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.10._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.11._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.12._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.13._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.14._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.15._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.16._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.17._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.18._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.19._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.20._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.21._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.22._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.23._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.24._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.25._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.26._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.27._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.28._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.29._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.30._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.31._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.32._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.33._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.34._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.35._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.36._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.37._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.38._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.39._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.40._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.41._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.42._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.43._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.44._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.45._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.46._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.47._fsdp_wrapped_module.flat_param_0". 
Unexpected key(s) in state_dict: "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.qkv_proj.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.qkv_proj.bias", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.out_proj.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.out_proj.bias", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.bias", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.fc1.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.fc1.bias", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.fc2.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.fc2.bias", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.final_layer_norm.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.final_layer_norm.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.qkv_proj.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.qkv_proj.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.out_proj.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.out_proj.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.fc1.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.fc1.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.fc2.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.fc2.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.final_layer_norm.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.final_layer_norm.bias", "_fpw_module.decoder.layers.2._fsdp_wrapped

The 992 shards were first converted to reshard-model_part-$j.pt using:

for j in {0..7}; do
    python3 -m metaseq.scripts.reshard_fsdp \
    --input-glob-pattern "./checkpoint_last-model_part-$j-shard*.pt" \
    --output-shard-name "./reshard/reshard-model_part-$j.pt" \
    --num-output-shards 1 --skip-optimizer-state True --unflatten-weights True
done

Should I have set --unflatten-weights False in order for metaseq.scripts.convert_to_singleton and metaseq.scripts.reshard_mp to work correctly?

andrewPoulton commented 1 year ago

@ayeeyecorp Just so I'm clear - you first ran reshard_fsdp on the shards (with unflatten-weights=true), then tried running convert_to_singleton on the consolidated shards? If that's so, then can you try running reshard_mp on the consolidated shards instead?

ayeeyecorp commented 1 year ago

@andrewPoulton

you first ran reshard_fsdp on the shards (with unflatten-weights=true), then tried running convert_to_singleton on the consolidated shards?

Correct, this resulted in the state_dict error.

If that's so, then can you try running reshard_mp on the consolidated shards instead?

Will do that again shortly and post stack trace results.

tangbinh commented 1 year ago

@ayeeyecorp May I ask why you want to convert the 8 MP parts of OPT 175B into a singleton? I don't think you would be able to load the singleton into any GPU considering its size, which is about 350GB.

Should I have set --unflatten-weights False in order for metaseq.scripts.convert_to_singleton and metaseq.scripts.reshard_mp to work correctly?

convert_to_singleton expects flattened weights; that's probably why you got Missing key(s) in state_dict. However, reshard_mp expects unflattened weights. As suggested by @andrewPoulton, please try to use reshard_mp instead, as we're deprecating convert_to_singleton.

ayeeyecorp commented 1 year ago

@andrewPoulton

I started over earlier today from the 992 shards (resetting my environment per the instructions here, using Python 3.8) and verified that the 8 consolidated FSDP shards had the correct md5sum. Upon confirmation, I used the reshard_mp.py script to consolidate them into a single part (eliminating the use of MP), with no issues this time:

python -m metaseq.scripts.reshard_mp \
    --input "/path/to/resharded/checkpoints/reshard-model_part-*.pt" \
    --output "/path/to/mp/resharded/checkpoints/reshard-model_part-{i}.pt" \
    --num-output-parts 1

Not sure what the original problem was. The md5sum of the single checkpoint (325.2 GB) was: 06e7e7ed424db3834ccd1a776d82ff14
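
That size is consistent with 175B fp16 parameters (175e9 × 2 bytes ≈ 326 GiB). A quick sanity check along these lines is possible; the top-level "model" key is an assumption about the checkpoint layout:

import torch

ckpt = torch.load("reshard-model_part-0.pt", map_location="cpu")
state_dict = ckpt.get("model", ckpt)  # assumption: weights live under a "model" key
n_params = sum(v.numel() for v in state_dict.values() if torch.is_tensor(v))
print(f"{n_params / 1e9:.1f}B params, ~{n_params * 2 / 2**30:.1f} GiB in fp16")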

The subsequent step to convert to Hugging Face format using:

python3 -m transformers.src.transformers.models.opt.convert_opt_original_pytorch_checkpoint_to_pytorch \
    --pytorch_dump_folder_path ~/opt_meta/hugging/ \
    --hf_config config.json \
    --fairseq_path ~/opt_meta/single_shard/reshard-model_part-0.pt

failed after 1+ hour with the following stack trace:

       size mismatch for decoder.layers.8.final_layer_norm.weight: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.8.final_layer_norm.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.9.self_attn.k_proj.weight: copying a param with shape torch.Size([12288, 12288]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for decoder.layers.9.self_attn.k_proj.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.9.self_attn.v_proj.weight: copying a param with shape torch.Size([12288, 12288]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for decoder.layers.9.self_attn.v_proj.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.9.self_attn.q_proj.weight: copying a param with shape torch.Size([12288, 12288]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for decoder.layers.9.self_attn.q_proj.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.9.self_attn.out_proj.weight: copying a param with shape torch.Size([12288, 12288]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for decoder.layers.9.self_attn.out_proj.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.9.self_attn_layer_norm.weight: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.9.self_attn_layer_norm.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.9.fc1.weight: copying a param with shape torch.Size([49152, 12288]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
        size mismatch for decoder.layers.9.fc1.bias: copying a param with shape torch.Size([49152]) from checkpoint, the shape in current model is torch.Size([3072]).
        size mismatch for decoder.layers.9.fc2.weight: copying a param with shape torch.Size([12288, 49152]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
        size mismatch for decoder.layers.9.fc2.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.9.final_layer_norm.weight: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.9.final_layer_norm.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.10.self_attn.k_proj.weight: copying a param with shape torch.Size([12288, 12288]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for decoder.layers.10.self_attn.k_proj.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.10.self_attn.v_proj.weight: copying a param with shape torch.Size([12288, 12288]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for decoder.layers.10.self_attn.v_proj.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.10.self_attn.q_proj.weight: copying a param with shape torch.Size([12288, 12288]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for decoder.layers.10.self_attn.q_proj.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.10.self_attn.out_proj.weight: copying a param with shape torch.Size([12288, 12288]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for decoder.layers.10.self_attn.out_proj.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.10.self_attn_layer_norm.weight: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.10.self_attn_layer_norm.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.10.fc1.weight: copying a param with shape torch.Size([49152, 12288]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
        size mismatch for decoder.layers.10.fc1.bias: copying a param with shape torch.Size([49152]) from checkpoint, the shape in current model is torch.Size([3072]).
        size mismatch for decoder.layers.10.fc2.weight: copying a param with shape torch.Size([12288, 49152]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
        size mismatch for decoder.layers.10.fc2.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.10.final_layer_norm.weight: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.10.final_layer_norm.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.11.self_attn.k_proj.weight: copying a param with shape torch.Size([12288, 12288]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for decoder.layers.11.self_attn.k_proj.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.11.self_attn.v_proj.weight: copying a param with shape torch.Size([12288, 12288]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for decoder.layers.11.self_attn.v_proj.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.11.self_attn.q_proj.weight: copying a param with shape torch.Size([12288, 12288]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for decoder.layers.11.self_attn.q_proj.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.11.self_attn.out_proj.weight: copying a param with shape torch.Size([12288, 12288]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for decoder.layers.11.self_attn.out_proj.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.11.self_attn_layer_norm.weight: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.11.self_attn_layer_norm.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.11.fc1.weight: copying a param with shape torch.Size([49152, 12288]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
        size mismatch for decoder.layers.11.fc1.bias: copying a param with shape torch.Size([49152]) from checkpoint, the shape in current model is torch.Size([3072]).
        size mismatch for decoder.layers.11.fc2.weight: copying a param with shape torch.Size([12288, 49152]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
        size mismatch for decoder.layers.11.fc2.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.11.final_layer_norm.weight: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.11.final_layer_norm.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
(venv) [ec2-user@ip-172-31-28-254 opt]$

I followed @patrickvonplaten's conversion instructions found here and generated a config.json with the following:

from transformers import OPTConfig
num_layers = 12
num_heads = 12
d_model = 768
config = OPTConfig(hidden_size=d_model, num_attention_heads=num_heads, num_hidden_layers=num_layers, ffn_dim=4*d_model)
config.save_pretrained("./")  # <- this will create a `config.json` in your current folder

Thoughts on what could be going wrong with the HF conversion? I will re-run the operation overnight and log the full failure stack trace.

@tangbinh - thank you for the clarification. I am converting the 8 MP parts of OPT 175B into a singleton to run quantization experiments against it.

tangbinh commented 1 year ago

@ayeeyecorp For OPT 175B, we should have num_layers = 96, num_heads = 96, and d_model = 12288.
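
Applied to the snippet above, the config for OPT-175B would look something like this (keeping ffn_dim = 4 * d_model = 49152, which matches the fc1 dimensions in the stack trace above):

from transformers import OPTConfig

# OPT-175B dimensions per the comment above; ffn_dim = 4 * d_model = 49152
num_layers = 96
num_heads = 96
d_model = 12288
config = OPTConfig(
    hidden_size=d_model,
    num_attention_heads=num_heads,
    num_hidden_layers=num_layers,
    ffn_dim=4 * d_model,
)
config.save_pretrained("./")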

ayeeyecorp commented 1 year ago

@ayeeyecorp For OPT 175B, we should have num_layers = 96, num_heads = 96, and d_model = 12288.

@tangbinh that was quick! brilliant, will give that a go now. I blindly used values from HF... thank you

ayeeyecorp commented 1 year ago

After updating the instance to 1TB+ of RAM... I successfully generated a .bin file using src.transformers.models.opt.convert_opt_original_pytorch_checkpoint_to_pytorch!

Thanks for the support.