facebookresearch / metaseq

Repo for external large-scale work
MIT License

layer_norm is fp32, can't be wrapped inside half precision layers. #66

Open BlackSamorez opened 2 years ago

BlackSamorez commented 2 years ago

🐛 Bug

I encountered this when running the 6.7B model with MODEL_PARALLEL = 2 and TOTAL_WORLD_SIZE = 2, using single_node_init.

When running interactive_cli, I encountered a problem in fairscale's parameter flattening:

I inserted some prints and found that the layer_norm weights and biases were in full precision (fp32), so they could not be flattened together with the other fp16 parameters.

Indeed, in metaseq's layer_norm definition those parameters are fp32.

By inserting an explicit cast to fp16, I managed to start the model. Is this intended behavior, or am I missing something?

Error

Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/workspace/metaseq/metaseq_cli/interactive_cli.py", line 115, in <module>
    cli_main()
  File "/workspace/metaseq/metaseq_cli/interactive_cli.py", line 111, in cli_main
    dist_utils.call_main(cfg, worker_main, namespace_args=args)
  File "/workspace/metaseq/metaseq/distributed/utils.py", line 256, in call_main
    return _spawn_helper(main, cfg, kwargs)
  File "/workspace/metaseq/metaseq/distributed/utils.py", line 234, in _spawn_helper
    retval = distributed_main(-1, main, cfg, kwargs)
  File "/workspace/metaseq/metaseq/distributed/utils.py", line 203, in distributed_main
    main(cfg, **kwargs)
  File "/workspace/metaseq/metaseq_cli/interactive_cli.py", line 66, in worker_main
    models = generator.load_model()  # noqa: F841
  File "/workspace/metaseq/metaseq/hub_utils.py", line 485, in load_model
    models, _model_args, _task = checkpoint_utils.load_model_ensemble_and_task(
  File "/workspace/metaseq/metaseq/checkpoint_utils.py", line 503, in load_model_ensemble_and_task
    model = build_model_hook(cfg, task)
  File "/workspace/metaseq/metaseq/hub_utils.py", line 474, in _build_model
    model = task.build_model(cfg.model).half().cuda()
  File "/workspace/metaseq/metaseq/tasks/language_modeling.py", line 164, in build_model
    model = super().build_model(args)
  File "/workspace/metaseq/metaseq/tasks/base_task.py", line 560, in build_model
    model = models.build_model(args, self)
  File "/workspace/metaseq/metaseq/models/__init__.py", line 89, in build_model
    return model.build_model(cfg, task)
  File "/workspace/metaseq/metaseq/model_parallel/models/transformer_lm.py", line 58, in build_model
    decoder = ModelParallelTransformerDecoder(
  File "/workspace/metaseq/metaseq/models/transformer.py", line 409, in __init__
    self.build_decoder_layer(
  File "/workspace/metaseq/metaseq/models/transformer.py", line 552, in build_decoder_layer
    layer = fsdp_wrap(
  File "/workspace/metaseq/metaseq/distributed/fully_sharded_data_parallel.py", line 141, in fsdp_wrap
    return wrap(module, **kwargs)
  File "/workspace/fairscale/fairscale/nn/wrap/auto_wrap.py", line 170, in wrap
    return ConfigAutoWrap.wrapper_cls(module, **wrap_overrides)
  File "/workspace/metaseq/metaseq/distributed/fully_sharded_data_parallel.py", line 48, in __init__
    super().__init__(*args, **kwargs)
  File "/workspace/fairscale/fairscale/nn/data_parallel/fully_sharded_data_parallel.py", line 342, in __init__
    self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(module, param_list=to_be_flatten_params)
  File "/workspace/fairscale/fairscale/nn/misc/flatten_params_wrapper.py", line 211, in __init__
    params, param_infos, shared_param_infos = self._init_flatten_params(new_p_set)
  File "/workspace/fairscale/fairscale/nn/misc/flatten_params_wrapper.py", line 278, in _init_flatten_params
    assert len(set(p.dtype for p in params)) == 1, "expects all parameters to have same dtype"
AssertionError: expects all parameters to have same dtype
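The assertion above is fairscale refusing to flatten a parameter list that mixes dtypes. The prints described earlier can be turned into a diagnostic that names the offending parameters instead of only failing. This is a minimal sketch, not part of metaseq or fairscale: group_params_by_dtype and assert_single_dtype are hypothetical helpers, and the demo uses stand-in objects with made-up parameter names. It only assumes each parameter exposes a .dtype attribute, so with PyTorch you would pass the decoder layer's named_parameters() just before it is handed to fsdp_wrap.

```python
from collections import defaultdict
from types import SimpleNamespace
from typing import Any, Dict, Iterable, List, Tuple


def group_params_by_dtype(named_params: Iterable[Tuple[str, Any]]) -> Dict[str, List[str]]:
    """Map each dtype (as a string) to the names of the parameters that use it.

    named_params is expected to look like PyTorch's module.named_parameters():
    an iterable of (name, tensor) pairs where each tensor exposes .dtype.
    """
    groups: Dict[str, List[str]] = defaultdict(list)
    for name, param in named_params:
        groups[str(param.dtype)].append(name)
    return dict(groups)


def assert_single_dtype(named_params: Iterable[Tuple[str, Any]]) -> None:
    """Same condition fairscale asserts before flattening, but the message lists the culprits."""
    groups = group_params_by_dtype(named_params)
    if len(groups) > 1:
        details = "; ".join(f"{dtype}: {names}" for dtype, names in sorted(groups.items()))
        raise AssertionError(f"expects all parameters to have same dtype -> {details}")


if __name__ == "__main__":
    # Stand-ins for real tensors with hypothetical parameter names;
    # with PyTorch you would pass layer.named_parameters() instead.
    fake_layer = [
        ("self_attn.qkv_proj.weight", SimpleNamespace(dtype="torch.float16")),
        ("self_attn_layer_norm.weight", SimpleNamespace(dtype="torch.float32")),
        ("self_attn_layer_norm.bias", SimpleNamespace(dtype="torch.float32")),
    ]
    print(group_params_by_dtype(fake_layer))
```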


stephenroller commented 2 years ago

Can you report your fairscale version?
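A quick way to answer this for every package in the stack (the Dockerfile below installs torch, apex, Megatron-LM, fairscale and metaseq). This is a sketch using the standard-library importlib.metadata (Python 3.8+); report_versions is a hypothetical helper, and the distribution names passed to it are assumptions, since editable installs can register under names that differ from the repository name.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def report_versions(distributions: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is not found."""
    found: Dict[str, Optional[str]] = {}
    for name in distributions:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    # Distribution names are assumptions; check `pip list` if one shows as not found.
    for name, ver in report_versions(["torch", "fairscale", "apex", "megatron-lm", "metaseq"]).items():
        print(f"{name}: {ver if ver is not None else 'not found'}")
```

Since the Dockerfile installs fairscale from the prefetch_fsdp_params_simple branch rather than a tagged release, the reported version number alone may not pin down the exact code in use.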

BlackSamorez commented 2 years ago

0.4.1. Here is the full Dockerfile of the container I work in:

FROM nvcr.io/nvidia/pytorch:21.06-py3
WORKDIR /workspace
RUN pip3 install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html

RUN git clone https://github.com/NVIDIA/apex.git
WORKDIR apex
RUN git checkout e2083df5eb96643c61613b9df48dd4eea6b07690
ENV TORCH_CUDA_ARCH_LIST="7.5"
RUN pip3 install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--deprecated_fused_adam" --global-option="--xentropy" --global-option="--fast_multihead_attn" ./

WORKDIR /workspace
RUN git clone --branch fairseq_v2 https://github.com/ngoyal2707/Megatron-LM.git
WORKDIR Megatron-LM
RUN pip3 install six regex
RUN pip3 install -e .

WORKDIR /workspace
RUN git clone https://github.com/facebookresearch/fairscale.git
WORKDIR fairscale
RUN git checkout prefetch_fsdp_params_simple
RUN pip3 install -e .

WORKDIR /workspace
RUN git clone https://github.com/facebookresearch/metaseq.git
WORKDIR metaseq
RUN pip3 install -e .

# turn on pre-commit hooks
RUN pre-commit install

RUN pip install --upgrade numpy

WORKDIR /workspace

Mrs-Hudson commented 2 years ago

I am facing the same issue with the 6.7B model. Is it okay to have MODEL_PARALLEL = 2 and TOTAL_WORLD_SIZE = 2?
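The arithmetic behind that setting, as a sketch: assuming metaseq follows the usual Megatron-style layout, the TOTAL_WORLD_SIZE ranks are split into groups of MODEL_PARALLEL ranks, each group holding one full copy of the model, so the world size has to be divisible by MODEL_PARALLEL. The helper below is hypothetical and not part of metaseq. With both values set to 2 it yields a single replica sharded across two GPUs; whether that suits a given checkpoint presumably also depends on how many model-parallel shards the checkpoint was saved with, which this check does not cover.

```python
def data_parallel_size(total_world_size: int, model_parallel: int) -> int:
    """Number of model replicas when each replica is sharded over model_parallel ranks.

    Hypothetical helper: assumes the Megatron-style rule that the world size is
    an exact multiple of the model-parallel size.
    """
    if total_world_size < 1 or model_parallel < 1:
        raise ValueError("both sizes must be positive integers")
    if total_world_size % model_parallel != 0:
        raise ValueError(
            f"TOTAL_WORLD_SIZE={total_world_size} is not divisible by MODEL_PARALLEL={model_parallel}"
        )
    return total_world_size // model_parallel


if __name__ == "__main__":
    print(data_parallel_size(2, 2))  # one replica split across two GPUs
    print(data_parallel_size(8, 2))  # four replicas of a two-way sharded model
```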

Mrs-Hudson commented 2 years ago

@BlackSamorez were you able to solve this?

BlackSamorez commented 2 years ago


As I said, I simply inserted .half() into the layer_norm creation (lines 35 and 36) and it started working, but that's a workaround rather than a solution.
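The .half() workaround amounts to casting the layer norm's parameters to fp16 at creation time, so that everything inside a decoder layer already shares one dtype when fsdp_wrap flattens it. A minimal PyTorch sketch of the idea, not metaseq's actual layer_norm code: half_layer_norm and param_dtypes are hypothetical helpers, and the toy fp16 Linear stands in for the model-parallel projections, which appear to be fp16 already at wrap time (in the traceback, fsdp_wrap runs inside task.build_model, before the outer .half() call).

```python
import torch
import torch.nn as nn


def half_layer_norm(normalized_shape, eps: float = 1e-5, elementwise_affine: bool = True) -> nn.LayerNorm:
    """Hypothetical LayerNorm factory that returns fp16 parameters.

    PyTorch creates LayerNorm weight/bias in fp32 by default; .half() casts them
    in place, which has the same effect as adding .half() where the norm is built.
    """
    return nn.LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).half()


def param_dtypes(*modules: nn.Module) -> set:
    """Collect the distinct parameter dtypes across several modules."""
    return {p.dtype for m in modules for p in m.parameters()}


if __name__ == "__main__":
    # Toy stand-in for a model-parallel projection that was created directly in fp16.
    proj = nn.Linear(8, 8).half()

    print(param_dtypes(proj, nn.LayerNorm(8)))     # mixed fp16 + fp32: would trip the flatten assertion
    print(param_dtypes(proj, half_layer_norm(8)))  # a single dtype: torch.float16
```

As noted in the comment above, this only hides the mismatch: it forces the norm to match whatever dtype the other layers were created in, rather than explaining why they differ.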

Mrs-Hudson commented 2 years ago

Thanks, this helped me get the CLI server up, but the output of the 6.7B model is still gibberish and the first token's logprob is positive. @stephenroller @suchenzang

Mrs-Hudson commented 2 years ago

Example output with 6.7B:

{'prompt': 'Today is a beautiful day and I want to', 'temperature': 0.9, 'max_tokens': 50, 'min_tokens': 4, 'top_p': 0.9, 'n': 1}
{"prompt": "Today is a beautiful day and I want to", "temperature": 0.9, "max_tokens": 50, "min_tokens": 4, "top_p": 0.9, "n": 1}
/anaconda/envs/azureml_py38/lib/python3.8/site-packages/urllib3/connectionpool.py:1013: InsecureRequestWarning: Unverified HTTPS request is being made to host '127.0.0.1'. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/1.26.x/advanced-usage.html#ssl-warnings
  warnings.warn(
{'choices': [{'logprobs': {'finish_reason': 'length', 'text_offset': [0, 13, 18, 31, 33, 39, 46, 59, 61, 65, 71, 75, 88, 93, 101, 105, 112, 123, 127, 134, 145, 149, 154, 167, 171, 177, 180, 186, 190, 194, 204, 208, 216, 219, 224, 237, 245, 249, 264, 277, 290, 293, 296, 304, 311, 312, 322, 330, 333, 337, 343], 'token_logprobs': [114.51757049560547, -7.202858924865723, -2.5938339233398438, -8.301254272460938, -10.47624397277832, -7.893651962280273, -1.8497161865234375, -8.071701049804688, -1.6325111389160156, -5.895839691162109, -8.90438461303711, -2.484539031982422, -9.372077941894531, -8.116035461425781, -1.735626220703125, -2.1011886596679688, -8.79058837890625, -2.2017059326171875, -9.367317199707031, -7.027656555175781, -8.895431518554688, -8.451400756835938, -1.3462677001953125, -9.93267822265625, -10.011550903320312, -10.632247924804688, -8.647872924804688, -10.264923095703125, -10.580062866210938, -10.812957763671875, -2.2966461181640625, -10.242782592773438, -8.285675048828125, -8.37872314453125, -1.4456329345703125, -4.07470703125, -9.0941162109375, -8.21484375, -6.838775634765625, -2.443695068359375, -8.213226318359375, -8.910308837890625, -6.286163330078125, -8.563323974609375, -10.56024169921875, -10.562957763671875, -7.673858642578125, -1.420654296875, -8.224517822265625, -5.957977294921875], 'tokens': [' Consequently', ' Sham', ' Consequently', 'LI', ' assum', ' decide', ' Consequently', 'ST', 'igun', ' dress', ' pet', ' Consequently', 'archs', ' Curious', 'igun', ' sitcom', ' mistakenly', 'igun', ' banter', ' strategist', ' EQU', ' Moss', ' Consequently', ' Git', 'cember', ' sd', ' quiet', ' Hyp', ' est', ' amplified', 'igun', ' marches', ' Bu', 'juven', ' Consequently', ' wrongly', ' Eng', ' administrators', ' geopolitical', ' Consequently', 'iston', 'sometimes', ' method', '�', ' Lithuania', ' Tempest', 'agh', 'igun', ' Admin', "''''"], 'top_logprobs': None}, 'text': " Consequently Sham ConsequentlyLI assum decide ConsequentlySTigun dress pet Consequentlyarchs Curiousigun sitcom mistakenlyigun banter strategist EQU Moss Consequently Gitcember sd quiet Hyp est amplifiedigun marches Bujuven Consequently wrongly Eng administrators geopolitical Consequentlyistonsometimes method� Lithuania Tempestaghigun Admin''''"}], 'created': 1652293935, 'id': '968c7b8e-e455-45c3-9084-f48fcb3e8f7f', 'model': '/home/azureuser/opt-models/175B/reshard_no_os/reshard.pt', 'object': 'text_completion'}
azureuser@rparik4:~/metaseq$ vi query_api.py
azureuser@rparik4:~/metaseq$ python query_api.py
{'prompt': 'The capital of France is ', 'temperature': 0.9, 'max_tokens': 50, 'min_tokens': 4, 'top_p': 0.9, 'n': 1}
{"prompt": "The capital of France is ", "temperature": 0.9, "max_tokens": 50, "min_tokens": 4, "top_p": 0.9, "n": 1}
/anaconda/envs/azureml_py38/lib/python3.8/site-packages/urllib3/connectionpool.py:1013: InsecureRequestWarning: Unverified HTTPS request is being made to host '127.0.0.1'. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/1.26.x/advanced-usage.html#ssl-warnings
  warnings.warn(
{'choices': [{'logprobs': {'finish_reason': 'length', 'text_offset': [0, 5, 18, 22, 35, 48, 59, 70, 83, 91, 97, 104, 111, 117, 121, 125, 129, 136, 142, 145, 147, 152, 154, 157, 167, 171, 176, 177, 180, 183, 189, 192, 194, 204, 207, 211, 214, 220, 223, 225, 231, 235, 239, 246, 252, 254, 260, 267, 274, 280, 287, 294, 299, 306, 312, 316, 321], 'token_logprobs': [80.31343841552734, -0.7621917724609375, -1.4765005111694336, -1.241154670715332, -1.2145938873291016, -9.56700611114502, -9.551746368408203, -0.8779163360595703, -9.419927597045898, -5.374179840087891, -3.9672622680664062, -6.8210906982421875, -2.652568817138672, -2.2682838439941406, -7.440319061279297, -0.7049484252929688, -1.1936416625976562, -2.1783599853515625, -2.3951492309570312, -11.061538696289062, -9.198898315429688, -9.94451904296875, -1.2152099609375, -5.336662292480469, -5.986625671386719, -0.7034378051757812, -0.923004150390625, -8.768692016601562, -1.239227294921875, -2.347686767578125, -1.4197845458984375, -1.463165283203125, -3.2377471923828125, -1.4078216552734375, -2.04815673828125, -0.8631134033203125, -3.471221923828125, -4.7039031982421875, -0.9492645263671875, -0.837005615234375, -1.57489013671875, -2.001678466796875, -0.9036407470703125, -2.2828369140625, -2.796630859375, -1.16583251953125, -0.9195098876953125, -1.4211883544921875, -1.052032470703125, -1.6554107666015625], 'tokens': ['rites', ' Consequently', 'igun', ' Consequently', ' Consequently', ' indigenous', ' ingredient', ' Consequently', ' realism', ' Shane', ' Campus', ' wicked', ' Shane', 'eneg', 'core', 'eneg', ' quests', ' quests', 'eneg', ' raids', 'hack', ' Bluetooth', 'igun', ' browse', 'shake', ' quests', 'eneg', ' slightest', ' flair', 'eneg', ' quests', 'eneg', ' Shane', 'eneg', 'eneg', ' quests', ' Shane', 'NW', ' Shane', ' quests', ' quests', ' Shane', ' quests', ' quests', 'Pract', ' quests', ' Shane', 'eneg', 'Pract', 'eneg'], 'top_logprobs': None}, 'text': 'rites Consequentlyigun Consequently Consequently indigenous ingredient Consequently realism Shane Campus wicked Shaneenegcoreeneg quests questseneg raidshack Bluetoothigun browseshake questseneg slightest flaireneg questseneg Shaneenegeneg quests ShaneNW Shane quests quests Shane quests questsPract quests ShaneenegPracteneg'}], 'created': 1652293983, 'id': '36bd5045-1ff1-47fa-a017-ba45fe4becf2', 'model': '/home/azureuser/opt-models/175B/reshard_no_os/reshard.pt', 'object': 'text_completion'}
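The positive first values in these responses (114.5 in the first, 80.3 in the second) cannot be real log-probabilities: log p is at most 0 for any probability p ≤ 1, so a positive value is a cheap signal that something in the model or its scoring is wrong, independent of how the text reads. A minimal client-side check, assuming the response layout shown above (choices, then logprobs, then token_logprobs); positive_logprobs is a hypothetical helper, not part of metaseq's API.

```python
from typing import Any, Dict, List, Tuple


def positive_logprobs(response: Dict[str, Any], tol: float = 1e-6) -> List[Tuple[int, int, float]]:
    """Return (choice index, token index, value) for every token log-probability above zero.

    A log-probability is log(p) with 0 < p <= 1, so valid values are <= 0;
    tol only absorbs floating-point noise around exactly 0.
    """
    bad: List[Tuple[int, int, float]] = []
    for c, choice in enumerate(response.get("choices", [])):
        token_logprobs = (choice.get("logprobs") or {}).get("token_logprobs") or []
        for t, lp in enumerate(token_logprobs):
            if lp is not None and lp > tol:
                bad.append((c, t, lp))
    return bad


if __name__ == "__main__":
    # First three values of the first response above, trimmed for brevity.
    response = {
        "choices": [
            {"logprobs": {"token_logprobs": [114.51757049560547, -7.202858924865723, -2.5938339233398438]}}
        ]
    }
    print(positive_logprobs(response))  # -> [(0, 0, 114.51757049560547)]
```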

suchenzang commented 2 years ago

Did your apex installation complete successfully? What do you see when you run "from apex.normalization import FusedLayerNorm" in your Python shell?
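That check can be scripted so it reports both the Python package and the compiled CUDA extension behind it. A sketch under stated assumptions: probe_imports is a hypothetical helper, and fused_layer_norm_cuda is assumed to be the extension module apex builds with --cuda_ext (a flag the Dockerfile above passes); the name may differ across apex versions. Assuming, as in apex releases of that period, that the extension is only loaded when a FusedLayerNorm is constructed, the Python import alone can succeed even if the extension failed to build, which is why both are probed.

```python
import importlib
from typing import Dict, Iterable


def probe_imports(module_names: Iterable[str]) -> Dict[str, str]:
    """Try to import each module and record "ok" or the exception that was raised."""
    results: Dict[str, str] = {}
    for name in module_names:
        try:
            importlib.import_module(name)
            results[name] = "ok"
        except Exception as exc:  # catch everything so one failure does not stop the report
            results[name] = f"{type(exc).__name__}: {exc}"
    return results


if __name__ == "__main__":
    # "fused_layer_norm_cuda" is an assumed extension name (apex built with --cuda_ext).
    for name, status in probe_imports(["apex.normalization", "fused_layer_norm_cuda"]).items():
        print(f"{name}: {status}")
```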

lorr1 commented 2 years ago

I had the same issue with the 1.3B model. To fix it, I checked out commit ae0b844c1f6725c3433a95e42cac760b3885170b of Megatron-LM and reinstalled it. Then I manually went through modules such as multiheaded_attention.py and removed every dtype=dtype init argument (if you don't, you'll get an error saying that dtype isn't an expected init argument). Once all the errors were gone, things worked for me and the model no longer generated garbage. I took inspiration from this post: https://github.com/facebookresearch/metaseq/pull/60#issuecomment-1120439309.
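The second step (removing every dtype=dtype init argument after downgrading Megatron-LM) is easier if you first list where that argument is passed. A hypothetical helper, assuming the argument is written literally as dtype=dtype; the default directory in the demo is a guess at where metaseq's model-parallel modules live. Review each hit by hand rather than deleting blindly, since the same text could appear in code that has nothing to do with the Megatron-LM init signatures.

```python
import re
from pathlib import Path
from typing import List, Tuple

# Heuristic: a keyword argument written literally as dtype=dtype (spaces around "=" allowed).
DTYPE_KWARG = re.compile(r"\bdtype\s*=\s*dtype\b")


def find_dtype_kwargs(root: str) -> List[Tuple[str, int, str]]:
    """Return (path, line number, stripped line) for each .py line under root that passes dtype=dtype."""
    hits: List[Tuple[str, int, str]] = []
    for path in sorted(Path(root).rglob("*.py")):
        text = path.read_text(encoding="utf-8", errors="replace")
        for lineno, line in enumerate(text.splitlines(), start=1):
            if DTYPE_KWARG.search(line):
                hits.append((str(path), lineno, line.strip()))
    return hits


if __name__ == "__main__":
    # Assumed location of metaseq's model-parallel modules; adjust to your checkout.
    for path, lineno, line in find_dtype_kwargs("metaseq/model_parallel"):
        print(f"{path}:{lineno}: {line}")
```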