facebookresearch / metaseq

Repo for external large-scale work
MIT License

No longer able to load provided OPT checkpoint after recent changes #383

Open EIFY opened 1 year ago

EIFY commented 1 year ago

🐛 Bug

No longer able to load provided OPT checkpoint after recent changes

To Reproduce

Edit metaseq/service/constants.py as before, in my case:

MAX_SEQ_LEN = 2048
BATCH_SIZE = 2048  # silly high bc we dynamically batch by MAX_BATCH_TOKENS
MAX_BATCH_TOKENS = 3072
DEFAULT_PORT = 6010
MODEL_PARALLEL = 1
TOTAL_WORLD_SIZE = 1
MAX_BEAM = 16

try:
    # internal logic denoting where checkpoints are in meta infrastructure
    from metaseq_internal.constants import CHECKPOINT_FOLDER
except ImportError:
    CHECKPOINT_FOLDER = "/home/jason_chou/redspot_home/350m/"
(...)

where

$ pwd
/home/jason_chou/redspot_home
$ ls 350m/
dict.txt  gpt2-merges.txt  gpt2-vocab.json  reshard.pt

Then run metaseq-api-local, which no longer works:

$ metaseq-api-local
2022-10-05 22:19:25 | INFO | metaseq.hub_utils | loading model(s) from /home/jason_chou/redspot_home/350m/reshard.pt
2022-10-05 22:19:26 | INFO | metaseq.checkpoint_utils | Done reading from disk
Traceback (most recent call last):
  File "/home/jason_chou/.conda/envs/user/bin/metaseq-api-local", line 8, in <module>
    sys.exit(cli_main())
  File "/home/default_user/metaseq/metaseq_cli/interactive_hosted.py", line 370, in cli_main
    distributed_utils.call_main(cfg, worker_main, namespace_args=args)
  File "/home/default_user/metaseq/metaseq/distributed/utils.py", line 279, in call_main
    return main(cfg, **kwargs)
  File "/home/default_user/metaseq/metaseq_cli/interactive_hosted.py", line 176, in worker_main
    models = generator.load_model()  # noqa: F841
  File "/home/default_user/metaseq/metaseq/hub_utils.py", line 565, in load_model
    models, _model_args, _task = _load_checkpoint()
  File "/home/default_user/metaseq/metaseq/hub_utils.py", line 548, in _load_checkpoint
    return checkpoint_utils.load_model_ensemble_and_task(
  File "/home/default_user/metaseq/metaseq/checkpoint_utils.py", line 482, in load_model_ensemble_and_task
    model = build_model_hook(cfg, task)
  File "/home/default_user/metaseq/metaseq/hub_utils.py", line 538, in _build_model
    setattr(cfg["model"], "inference", True)
  File "/home/default_user/.conda/envs/user/lib/python3.10/site-packages/omegaconf/dictconfig.py", line 337, in __setattr__
    raise e
  File "/home/default_user/.conda/envs/user/lib/python3.10/site-packages/omegaconf/dictconfig.py", line 334, in __setattr__
    self.__set_impl(key, value)
  File "/home/default_user/.conda/envs/user/lib/python3.10/site-packages/omegaconf/dictconfig.py", line 318, in __set_impl
    self._set_item_impl(key, value)
  File "/home/default_user/.conda/envs/user/lib/python3.10/site-packages/omegaconf/basecontainer.py", line 511, in _set_item_impl
    self._validate_set(key, value)
  File "/home/default_user/.conda/envs/user/lib/python3.10/site-packages/omegaconf/dictconfig.py", line 180, in _validate_set
    target = self._get_node(key) if key is not None else self
  File "/home/default_user/.conda/envs/user/lib/python3.10/site-packages/omegaconf/dictconfig.py", line 465, in _get_node
    self._validate_get(key)
  File "/home/default_user/.conda/envs/user/lib/python3.10/site-packages/omegaconf/dictconfig.py", line 166, in _validate_get
    self._format_and_raise(
  File "/home/default_user/.conda/envs/user/lib/python3.10/site-packages/omegaconf/base.py", line 190, in _format_and_raise
    format_and_raise(
  File "/home/default_user/.conda/envs/user/lib/python3.10/site-packages/omegaconf/_utils.py", line 821, in format_and_raise
    _raise(ex, cause)
  File "/home/default_user/.conda/envs/user/lib/python3.10/site-packages/omegaconf/_utils.py", line 719, in _raise
    raise ex.with_traceback(sys.exc_info()[2])  # set end OC_CAUSE=1 for full backtrace
omegaconf.errors.ConfigAttributeError: Key 'inference' is not in struct
    full_key: model.inference
    object_type=dict

Apparently this can be traced back to when setattr(cfg["model"], "inference", True) was added (https://github.com/facebookresearch/metaseq/pull/356). However, another issue surfaced even with that line commented out:

$ metaseq-api-local
2022-10-05 22:23:31 | INFO | metaseq.hub_utils | loading model(s) from /home/jason_chou/redspot_home/350m/reshard.pt
2022-10-05 22:23:31 | INFO | metaseq.checkpoint_utils | Done reading from disk
Traceback (most recent call last):
  File "/home/jason_chou/.conda/envs/user/bin/metaseq-api-local", line 8, in <module>
    sys.exit(cli_main())
  File "/home/default_user/metaseq/metaseq_cli/interactive_hosted.py", line 370, in cli_main
    distributed_utils.call_main(cfg, worker_main, namespace_args=args)
  File "/home/default_user/metaseq/metaseq/distributed/utils.py", line 279, in call_main
    return main(cfg, **kwargs)
  File "/home/default_user/metaseq/metaseq_cli/interactive_hosted.py", line 176, in worker_main
    models = generator.load_model()  # noqa: F841
  File "/home/default_user/metaseq/metaseq/hub_utils.py", line 565, in load_model
    models, _model_args, _task = _load_checkpoint()
  File "/home/default_user/metaseq/metaseq/hub_utils.py", line 548, in _load_checkpoint
    return checkpoint_utils.load_model_ensemble_and_task(
  File "/home/default_user/metaseq/metaseq/checkpoint_utils.py", line 487, in load_model_ensemble_and_task
    model.load_state_dict(state["model"], strict=strict)
  File "/home/default_user/.conda/envs/user/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1604, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for TransformerLanguageModel:
        Missing key(s) in state_dict: "decoder.layer_norm.weight", "decoder.layer_norm.bias". 

which seems to be due to recent cleanup PRs (https://github.com/facebookresearch/metaseq/pull/366, https://github.com/facebookresearch/metaseq/pull/380, https://github.com/facebookresearch/metaseq/pull/381).
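
The second error comes from load_state_dict(..., strict=True), which refuses to load when the keys the model expects differ from the keys in the checkpoint. The following is a minimal pure-Python sketch of that comparison. It mirrors the idea, not PyTorch's actual implementation, and every key name other than decoder.layer_norm.* is illustrative. It shows how a model that builds a final layer norm, loaded from a checkpoint that never had one, produces exactly the two missing keys reported above:

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, as a strict load would report them."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)      # model expects, checkpoint lacks
    unexpected = sorted(checkpoint_keys - model_keys)   # checkpoint has, model lacks
    return missing, unexpected


# Illustrative key sets: the checkpoint was trained without a final decoder layer norm,
# while the current model code builds one.
checkpoint_keys = ["decoder.embed_tokens.weight", "decoder.layers.0.fc1.weight"]
model_keys = checkpoint_keys + ["decoder.layer_norm.weight", "decoder.layer_norm.bias"]

missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
print("Missing key(s):", missing)
print("Unexpected key(s):", unexpected)
```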

Expected behavior

metaseq-api-local up & running

Environment

stephenroller commented 1 year ago

@suchenzang

suchenzang commented 1 year ago

Hm, none of the cleanup PRs should have touched state dict logic, much less layer norms. The last time state dicts were touched was in https://github.com/facebookresearch/metaseq/pull/229, I think.

@EIFY do you see this same error in the 125m model? 350m was the only one trained without model parallelism, which has caused some issues in the past with integration.

EIFY commented 1 year ago

Hmm, but https://github.com/facebookresearch/metaseq/pull/229 was merged on Jul 16. I can try git bisect tomorrow, but I am certain that the 350m model worked for me in Sep.

I haven't been able to run non-model-parallel models due to another issue (https://github.com/facebookresearch/metaseq/issues/378) 🙃

ruanslv commented 1 year ago

I did a bisect; this is the commit that started causing the error: https://github.com/facebookresearch/metaseq/commit/493e6017c18f7c2d3cd697693e6f9e33592f3612

cc @lilisierrayu

ruanslv commented 1 year ago

After commenting out the suggested line, the second error is caused by this commit in particular: https://github.com/facebookresearch/metaseq/commit/c4b33ba6e2cd9b33539bbb5a35d831096bde3282

ruanslv commented 1 year ago

OK, I did a bit of digging with @suchenzang; here is the summary:

tangbinh commented 1 year ago

I think the first issue can be fixed by a one-line change (see this OmegaConf documentation):

with omegaconf.open_dict(cfg):
    setattr(cfg["model"], "inference", True)

andchir commented 1 year ago

Missing key(s) in state_dict: "decoder.layer_norm.weight", "decoder.layer_norm.bias".

Is there a solution?

ruanslv commented 1 year ago

@andchir, we haven't retrained the 350M model yet, but if you locally set self.layer_norm = None in metaseq/models/transformer_decoder.py, it should work.
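
The workaround makes the error go away because a strict load only expects keys for submodules that actually exist. A layer norm attribute set to None contributes no weight or bias entries. The toy module below is a hypothetical stand-in, not metaseq's decoder, and assumes PyTorch is installed. It sketches that behavior:

```python
import torch
from torch import nn


class ToyDecoder(nn.Module):
    """Hypothetical stand-in for a decoder with an optional final layer norm."""

    def __init__(self, dim, final_layer_norm):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        # With None, no layer_norm.weight / layer_norm.bias appear in state_dict().
        self.layer_norm = nn.LayerNorm(dim) if final_layer_norm else None

    def forward(self, x):
        x = self.proj(x)
        if self.layer_norm is not None:
            x = self.layer_norm(x)
        return x


# A "checkpoint" trained without the final layer norm.
checkpoint = ToyDecoder(8, final_layer_norm=False).state_dict()

try:
    ToyDecoder(8, final_layer_norm=True).load_state_dict(checkpoint, strict=True)
except RuntimeError as err:
    print("strict load failed:", "layer_norm.weight" in str(err))

model = ToyDecoder(8, final_layer_norm=False)
model.load_state_dict(checkpoint, strict=True)  # loads cleanly
print(model(torch.zeros(2, 8)).shape)  # torch.Size([2, 8])
```

A clean load only means the parameter names match. It says nothing about whether the forward pass matches the one the checkpoint was trained with.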

andchir commented 1 year ago

@ruanslv Thanks for the answer. It helped; the error no longer occurs. However, I am getting strange text generation results. Example:

The technology world is reeling after Facebook ($FB) announced today are have have have have are are have have have have have have are have have have are have are have have have have are have are are are are are have have have have are have have have are are have have are have have are are have have have are are have have have are have have have have are have have have have are have have have have have have have have have have have are have are are have have have are have have have have have have are are have have have have are have ...

I think I should use a different model. Can you help me set up the constants? I don't understand what to specify for MODEL_FILE when the model comes only in parts:

MODEL_FILE = os.path.join(CHECKPOINT_FOLDER, "reshard.pt")  # I don't have such a file, I only have "reshard-model_part-0.pt", "..._part-1.pt"

I am trying to use OPT-1.3B.
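
The file names above suggest a scheme in which each model-parallel rank has its own file, named by inserting "-model_part-{rank}" before ".pt". The hypothetical helper part_files below only spells out that assumed scheme. It is not metaseq's loading code, and /path/to/1.3b is a placeholder. If the scheme holds, MODEL_FILE would keep the base reshard.pt name and MODEL_PARALLEL would presumably need to equal the number of parts. Both points are assumptions to verify against the metaseq documentation.

```python
import os


def part_files(model_file, model_parallel):
    """Hypothetical: list the per-rank files implied by a base reshard.pt path."""
    root, ext = os.path.splitext(model_file)  # e.g. ("/dir/reshard", ".pt")
    return [f"{root}-model_part-{rank}{ext}" for rank in range(model_parallel)]


# Placeholder folder; with two parts, this lists the two file names quoted above.
print(part_files(os.path.join("/path/to/1.3b", "reshard.pt"), 2))
```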

EIFY commented 1 year ago

Just curious: before the breaking change https://github.com/facebookresearch/metaseq/commit/c4b33ba6e2cd9b33539bbb5a35d831096bde3282, we had https://github.com/facebookresearch/metaseq/blob/50dbe6077bbb977cdd2a7b02ce778ffcf29e829e/metaseq/model_parallel/models/transformer_lm.py#L111-L112 where I believe args.decoder_normalize_before does two things:

  1. Switching from post-norm to pre-norm transformer
  2. Creating the final layer norm (that the 350M model accidentally left out): https://github.com/facebookresearch/metaseq/blob/50dbe6077bbb977cdd2a7b02ce778ffcf29e829e/metaseq/models/transformer_decoder.py#L178-L183

Was the stability issue fixed by 1 & 2 together, or 1 alone? If 1 alone was sufficient, what is the rationale for the final layer norm? Evidently, the 350M model training was stable without it 😅
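
The two effects in items 1 and 2 can be separated in a toy, pure-Python sketch. This is not metaseq code. The sublayer function is an arbitrary stand-in for attention or the FFN, and the layer norm has no learned scale or shift. With post-norm, every layer's output is already normalized. With pre-norm, the residual stream itself is never normalized, so only a final layer norm brings the output back to zero mean and unit variance. Setting normalize_before=True and final_layer_norm=True corresponds to the behavior described for args.decoder_normalize_before. Setting normalize_before=True and final_layer_norm=False corresponds to the 350M model as described in item 2.

```python
import math


def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned scale or shift)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def sublayer(x):
    """Arbitrary stand-in for attention or the FFN; any function works here."""
    return [2.0 * v + 1.0 for v in x]


def decoder(x, num_layers, normalize_before, final_layer_norm):
    for _ in range(num_layers):
        if normalize_before:
            # Pre-norm: x + f(LN(x)); the residual stream itself is never normalized.
            x = [a + b for a, b in zip(x, sublayer(layer_norm(x)))]
        else:
            # Post-norm: LN(x + f(x)); each layer's output is normalized.
            x = layer_norm([a + b for a, b in zip(x, sublayer(x))])
    if final_layer_norm:
        x = layer_norm(x)
    return x


def mean(x):
    return sum(x) / len(x)


x0 = [1.0, 2.0, 3.0, 4.0]
print(mean(decoder(x0, 2, normalize_before=False, final_layer_norm=False)))  # close to 0
print(mean(decoder(x0, 2, normalize_before=True, final_layer_norm=False)))   # about 4.5
print(mean(decoder(x0, 2, normalize_before=True, final_layer_norm=True)))    # close to 0
```

In this toy, the pre-norm residual stream drifts by the mean of each sublayer output. Whether that drift matters for training stability in a real model is exactly the open question above.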

I also noticed that, compared to RobertaLMHead, self.dense, self.activation_fn, and self.bias for the final projection back to the vocabulary size have been eliminated. I don't know whether there is any history, rationale, or experimental evidence behind these decisions.
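
For scale, here is a back-of-the-envelope count of the parameters that the removed pieces would add on top of a plain output projection. self.dense is a d-by-d linear layer with a bias, self.activation_fn has no parameters, and self.bias has one entry per vocabulary item. The sizes below are placeholders, not OPT's actual configuration. RobertaLMHead's own layer norm is not counted because it is not among the pieces listed above.

```python
def extra_lm_head_params(embed_dim, vocab_size):
    """Parameters in dense (weight + bias) plus the output bias; activations have none."""
    dense = embed_dim * embed_dim + embed_dim
    output_bias = vocab_size
    return dense + output_bias


# Placeholder sizes for illustration only.
print(extra_lm_head_params(embed_dim=1024, vocab_size=50000))  # 1099600
```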