facebookresearch / optimizers

For optimization algorithm research and development.

Empty params in FSDP cause issue #23

Open odegeasslbc opened 1 week ago

odegeasslbc commented 1 week ago

Hi all, first of all, thanks for your great work! I ran into an issue when trying to use the optimizer with FSDP training.

The error is:

    optimizer = DistributedShampoo(
      File "/root/slurm/src/optimizers/distributed_shampoo/distributed_shampoo.py", line 484, in __init__
        self._instantiate_distributor()
      File "/root/slurm/conda/conda_sgm/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
        return func(*args, **kwargs)
      File "/root/slurm/src/optimizers/distributed_shampoo/distributed_shampoo.py", line 519, in _instantiate_distributor
        state_lists[DISTRIBUTOR] = distributor(group)
      File "/root/slurm/src/optimizers/distributed_shampoo/utils/shampoo_fsdp_distributor.py", line 57, in __init__
        super().__init__(param_group)
      File "/root/slurm/src/optimizers/distributed_shampoo/utils/shampoo_distributor.py", line 48, in __init__
        self._merge_and_block_parameters()
      File "/root/slurm/src/optimizers/distributed_shampoo/utils/shampoo_fsdp_distributor.py", line 127, in _merge_and_block_parameters
        self._param_to_metadata[flattened_param].shape,
    KeyError: Parameter containing: tensor([ 0.0232,  0.0045, -0.0105,  ..., -0.0004, -0.0194, -0.0039], device='cuda:0')

The error is triggered at this line: https://github.com/facebookresearch/optimizers/blob/1685444a6de5d028c21480ae960e0867fac86b0c/distributed_shampoo/utils/shampoo_fsdp_distributor.py#L127

I'm using it with the pytorch_lightning framework and FSDP training. Below is how I initialize the optimizer:

    optimizer = DistributedShampoo(
        self.model.parameters(),
        lr=lr,
        betas=cfg.params.betas,
        epsilon=cfg.params.epsilon,
        weight_decay=cfg.params.weight_decay,
        max_preconditioner_dim=cfg.params.max_preconditioner_dim,
        precondition_frequency=cfg.params.precondition_frequency,
        use_decoupled_weight_decay=cfg.params.use_decoupled_weight_decay,
        grafting_config=AdamGraftingConfig(
            beta2=cfg.params.betas[1],
            epsilon=cfg.params.epsilon,
        ),
        distributed_config=FSDPShampooConfig(
            param_to_metadata=compile_fsdp_parameter_metadata(self.model),
        ),
    )

Do you have any guess as to what could cause the error? Could it be due to pytorch-lightning initializing the FSDP model differently from the raw PyTorch setup shown in the example?
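For what it's worth, here is a small sanity check I sketched (hypothetical, reusing the same compile_fsdp_parameter_metadata helper already imported above) to see which of the parameters handed to the optimizer are missing from the metadata dict, since the KeyError comes from exactly that lookup:

```python
# Hypothetical sanity check: the FSDP distributor looks up every parameter it
# receives in param_to_metadata, so any parameter missing from that dict will
# reproduce the KeyError above.
param_to_metadata = compile_fsdp_parameter_metadata(self.model)

missing = [
    name
    for name, param in self.model.named_parameters()
    if param not in param_to_metadata  # same identity-based lookup the distributor does
]
print(f"Parameters without FSDP metadata: {missing}")
```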

Btw, below is my FSDP setting:

    trainer_kwargs["strategy"] = FSDPStrategy(
        sharding_strategy=sharding_strategy,
        precision_plugin=precision_plugin,
        auto_wrap_policy=llm_policy,
        activation_checkpointing_policy=activation_checkpointing_policy,
        state_dict_type="full",  # "full" or "sharded"
        limit_all_gathers=True,
        sync_module_states=True,  # must be True to sync model parameters from rank 0 to all ranks, because we load models only on rank 0!
        mixed_precision=bfSixteen,
        use_orig_params=True,
    )
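One thing I'm unsure about is whether the metadata should be compiled from the FSDP-wrapped module rather than the bare LightningModule. Below is a hypothetical sketch of what I mean (it assumes configure_optimizers runs after the strategy has wrapped the model and that self.trainer.model is the wrapped module; I have not confirmed this fixes anything):

```python
# Hypothetical sketch (not confirmed): build the optimizer from the FSDP-wrapped
# module so the keys of param_to_metadata match the parameters the optimizer
# actually receives.
def configure_optimizers(self):
    wrapped_model = self.trainer.model  # assumed to be the FSDP-wrapped module under FSDPStrategy
    return DistributedShampoo(
        wrapped_model.parameters(),
        lr=lr,
        betas=cfg.params.betas,
        epsilon=cfg.params.epsilon,
        weight_decay=cfg.params.weight_decay,
        max_preconditioner_dim=cfg.params.max_preconditioner_dim,
        precondition_frequency=cfg.params.precondition_frequency,
        use_decoupled_weight_decay=cfg.params.use_decoupled_weight_decay,
        grafting_config=AdamGraftingConfig(
            beta2=cfg.params.betas[1],
            epsilon=cfg.params.epsilon,
        ),
        distributed_config=FSDPShampooConfig(
            param_to_metadata=compile_fsdp_parameter_metadata(wrapped_model),
        ),
    )
```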

tsunghsienlee commented 1 day ago

Hi @odegeasslbc,

Unfortunately, we have not experimented with PyTorch Lightning FSDP. Just to confirm, could you run https://github.com/facebookresearch/optimizers/blob/main/distributed_shampoo/examples/fsdp_cifar10_example.py, which runs FSDP in a vanilla PyTorch setting? We used this to verify our setup on FSDP, so if this example runs for you, the problem is likely on the PyTorch Lightning side.