microsoft / tutel

Tutel MoE: An Optimized Mixture-of-Experts Implementation
MIT License

New Tutel checkpoint loading is incompatible with old models #186

Closed jinga-lala closed 2 years ago

jinga-lala commented 2 years ago

Hi, I have been using the Swin-MoE pre-trained models with Tutel. However, after the recent update to the Tutel library's model-loading format, the pre-trained models have a different dict structure than the expert model now expects, which results in a loading error. Could you please release compatible versions of these pre-trained models, or a script to convert them? Any help would be highly appreciated.
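A conversion script of the kind asked for here would essentially rename keys in the checkpoint's state_dict. The sketch below is hypothetical: the `fc1_weight` → `batched_fc1_w` mapping is inferred only from the error message later in this thread, not from Tutel's actual old/new layouts, and `remap_moe_keys` is not a Tutel function.

```python
# Hypothetical sketch of a checkpoint key-conversion helper. The suffix
# mapping here is an assumption based on the AssertionError in this thread;
# the real old/new Tutel key layouts may differ.
def remap_moe_keys(state_dict, suffix_map):
    """Return a new state_dict with keys renamed according to suffix_map."""
    out = {}
    for key, value in state_dict.items():
        new_key = key
        for old_suffix, new_suffix in suffix_map.items():
            if key.endswith(old_suffix):
                # Replace only the trailing suffix, keeping the layer prefix.
                new_key = key[: -len(old_suffix)] + new_suffix
                break
        out[new_key] = value
    return out

# Example (hypothetical key names):
converted = remap_moe_keys(
    {"layers.0.mlp._moe_layer.experts.fc1_weight": "w"},
    {"fc1_weight": "batched_fc1_w"},
)
```

After remapping, the converted dict could be saved back with `torch.save` and loaded as usual.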

ghostplant commented 2 years ago

Is this commit bffee7c compatible with old checkpoint states?

jinga-lala commented 2 years ago

@ghostplant Thanks for replying. I am afraid it's not.

ghostplant commented 2 years ago

@zeliu98 is working on detecting and fixing this in the latest version. The Swin checkpoints have just been confirmed to be compatible with this earlier commit: 1d56b9b.

zeliu98 commented 2 years ago

Hi @jinga-lala, could you share the detailed error message? I have tried the latest Tutel and everything works for me.

jinga-lala commented 2 years ago

Hi @zeliu98, sorry for the late response. Here is the error message:

2022-09-18 21:05:14,093 - mmdet - INFO - load model from: $HOME/pretrained_models/swin_moe_small_patch4_window12_192_8expert_32gpu_22k/swin_moe_small_patch4_window12_192_8expert_32gpu_22k.pth
WARNING:root:You are loading a legacy format of checkpoint with at least one Tutel MoE layer inside, which wouldn't support new Tutel feature allowing the number of experts per checkpoint file to mutate.
WARNING:root:  The next time you overwrite it with new checkpoint, the recording format will be updated automatically.
WARNING:root:  However, the new format won't be compatible with early Tutel versions, unless you force loading it with `model.load_state_dict(.., strict=False)`.
Traceback (most recent call last):
  File "$HOME/miniconda3/envs/test_env/lib/python3.9/site-packages/mmcv/utils/registry.py", line 51, in build_from_cfg
    return obj_cls(**args)
  File "$HOME/Swin-Transformer-Object-Detection/mmdet/models/detectors/cascade_rcnn.py", line 18, in __init__
    super(CascadeRCNN, self).__init__(
  File "$HOME/Swin-Transformer-Object-Detection/mmdet/models/detectors/two_stage.py", line 50, in __init__
    self.init_weights(pretrained=pretrained)
  File "$HOME/Swin-Transformer-Object-Detection/mmdet/models/detectors/two_stage.py", line 70, in init_weights
    self.backbone.init_weights(pretrained=pretrained)
  File "$HOME/Swin-Transformer-Object-Detection/mmdet/models/backbones/swin_transformer_moe.py", line 997, in init_weights
    load_checkpoint(self, pretrained, strict=False, logger=logger, is_moe=True)
  File "$HOME/Swin-Transformer-Object-Detection/mmcv_custom/checkpoint.py", line 363, in load_checkpoint
    load_state_dict(model, state_dict, strict, logger)
  File "$HOME/Swin-Transformer-Object-Detection/mmcv_custom/checkpoint.py", line 82, in load_state_dict
    load(module)
  File "$HOME/Swin-Transformer-Object-Detection/mmcv_custom/checkpoint.py", line 80, in load
    load(child, prefix + name + '.')
  File "$HOME/Swin-Transformer-Object-Detection/mmcv_custom/checkpoint.py", line 80, in load
    load(child, prefix + name + '.')
  File "$HOME/Swin-Transformer-Object-Detection/mmcv_custom/checkpoint.py", line 80, in load
    load(child, prefix + name + '.')
  [Previous line repeated 3 more times]
  File "$HOME/Swin-Transformer-Object-Detection/mmcv_custom/checkpoint.py", line 75, in load
    module._load_from_state_dict(state_dict, prefix, local_metadata, True,
  File "$HOME/Swin-Transformer-Object-Detection/mmdet/models/backbones/moe.py", line 55, in _load_from_state_dict
    assert buff_name in state_dict, "Could not find parameter `%s` in state_dict." % buff_name
AssertionError: Could not find parameter `layers.2.blocks.1.mlp._moe_layer.experts.batched_fc1_w` in state_dict.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "$HOME/Swin-Transformer-Object-Detection/tools/train.py", line 380, in <module>
    main()
  File "$HOME/Swin-Transformer-Object-Detection/tools/train.py", line 160, in main
    model = build_detector(
  File "$HOME/Swin-Transformer-Object-Detection/mmdet/models/builder.py", line 77, in build_detector
    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
  File "$HOME/Swin-Transformer-Object-Detection/mmdet/models/builder.py", line 34, in build
    return build_from_cfg(cfg, registry, default_args)
  File "$HOME/miniconda3/envs/test_env/lib/python3.9/site-packages/mmcv/utils/registry.py", line 54, in build_from_cfg
    raise type(e)(f'{obj_cls.__name__}: {e}')
AssertionError: CascadeRCNN: Could not find parameter `layers.2.blocks.1.mlp._moe_layer.experts.batched_fc1_w` in state_dict.
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 11580) of binary: $HOME/miniconda3/envs/test_env/bin/python
Traceback (most recent call last):
  File "$HOME/miniconda3/envs/test_env/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "$HOME/miniconda3/envs/test_env/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "$HOME/miniconda3/envs/test_env/lib/python3.9/site-packages/torch/distributed/launch.py", line 193, in <module>
    main()
  File "$HOME/miniconda3/envs/test_env/lib/python3.9/site-packages/torch/distributed/launch.py", line 189, in main
    launch(args)
  File "$HOME/miniconda3/envs/test_env/lib/python3.9/site-packages/torch/distributed/launch.py", line 174, in launch
    run(args)
  File "$HOME/miniconda3/envs/test_env/lib/python3.9/site-packages/torch/distributed/run.py", line 710, in run
    elastic_launch(
  File "$HOME/miniconda3/envs/test_env/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 131, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "$HOME/miniconda3/envs/test_env/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 259, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
tools/train.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2022-09-18_21:05:38
  host      : zatopek.cc.gatech.edu
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 11580)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
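The assertion above fires when a key the model expects (`...experts.batched_fc1_w`) is absent from the checkpoint. A quick way to diagnose such mismatches, before touching any loading code, is to diff the two key sets. This helper is a hypothetical diagnostic, not part of Tutel or mmcv:

```python
# Hypothetical diagnostic helper: compare the keys a model expects with the
# keys actually present in a checkpoint, mirroring the failing assertion.
def diff_state_dict_keys(expected_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, each sorted."""
    expected, found = set(expected_keys), set(checkpoint_keys)
    missing = sorted(expected - found)      # model wants these, checkpoint lacks them
    unexpected = sorted(found - expected)   # checkpoint has these, model does not
    return missing, unexpected

# With a real model and checkpoint, the key sets would come from
# model.state_dict().keys() and torch.load(path)["model"].keys().
missing, unexpected = diff_state_dict_keys(
    {"layers.2.blocks.1.mlp._moe_layer.experts.batched_fc1_w"},
    {"layers.2.blocks.1.mlp._moe_layer.experts.fc1_weight"},
)
```

In this thread's case, such a diff would show the new-format `batched_*` key as missing and the legacy key as unexpected, pointing directly at the format change.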
zeliu98 commented 2 years ago

Hi @jinga-lala, you may need to modify your code for loading a MoE model according to: https://github.com/microsoft/Swin-Transformer/blob/afeb877fba1139dfbc186276983af2abb02c2196/utils_moe.py#L76-L83.
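For readers who cannot follow the link: the suggested fix is to adjust the loading code rather than the checkpoint. A minimal sketch of one plausible pattern for that, assuming (as the earlier warning suggests) that dense weights and per-rank expert weights may live in separate state_dicts which are merged before a non-strict `model.load_state_dict(..., strict=False)` call. All key names and the merge step itself are illustrative assumptions, not the actual code at the linked lines:

```python
# Minimal sketch (hypothetical, not the actual utils_moe.py code): merge a
# shared dense state_dict with this rank's expert state_dict before calling
# model.load_state_dict(merged, strict=False).
def merge_moe_state_dicts(dense_sd, expert_sd):
    """Overlay the expert shard's entries onto the dense state_dict."""
    merged = dict(dense_sd)
    merged.update(expert_sd)  # expert entries extend/override dense ones
    return merged

# Example with hypothetical key names:
merged = merge_moe_state_dicts(
    {"norm.weight": "w"},
    {"layers.0.mlp._moe_layer.experts.batched_fc1_w": "e"},
)
```

Passing `strict=False` then tolerates any remaining key differences, as the warning in the log above also notes.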

jinga-lala commented 2 years ago

@zeliu98 Thank you for your help! It worked 🎉