adap / flower

Flower: A Friendly Federated Learning Framework
https://flower.ai
Apache License 2.0

Facing issue with Flower Simulation with ResNet18 and MNIST dataset #3237

Open EzyHow opened 5 months ago

EzyHow commented 5 months ago

Describe the bug

I was trying an example project for Flower Simulation (Flower Simulation Step by Step Pytorch - Part II). Everything worked well until I changed the model to ResNet18, as shown below:

class Net(nn.Module):
    def __init__(self, num_classes: int) -> None:
        super(Net, self).__init__()
        self.model = models.resnet18()
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Linear(num_ftrs, num_classes)
        summary(self.model, input_size=(1, 28, 28)) # <<== THIS LINE

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.model(x)
        return x

If I add summary(self.model, input_size=(1, 28, 28)) at the end of the __init__() method, everything works. When I remove it, evaluate_fn in server.py fails with IndexError: index 0 is out of bounds for dimension 0 with size 0 (raised at input_param = input_param[0] inside PyTorch) on the following code:

params_dict = zip(model.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
model.load_state_dict(state_dict, strict=True)  # <= the error is raised here
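To rule out Flower itself, the failing step can be reproduced with plain PyTorch. This is a sketch under two assumptions: a small Conv2d + BatchNorm2d model stands in for ResNet18, and the client exports its weights as [val.cpu().numpy() for _, val in model.state_dict().items()], the pattern used in Flower's PyTorch examples. The conversion is the same torch.Tensor(v) call as in evaluate_fn above.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

# Stand-in for ResNet18 (assumption): any model with a BatchNorm layer that
# has never run a forward pass, so its num_batches_tracked buffer is still 0.
model = nn.Sequential(nn.Conv2d(1, 4, kernel_size=3), nn.BatchNorm2d(4))

# Assumed client-side export: every state_dict entry (weights and buffers)
# becomes a NumPy array, including the 0-d num_batches_tracked buffer.
parameters = [val.cpu().numpy() for _, val in model.state_dict().items()]

# Same conversion as in evaluate_fn.
params_dict = zip(model.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})

error = None
try:
    model.load_state_dict(state_dict, strict=True)
except IndexError as exc:
    error = exc
print(repr(error))
```

If the installed PyTorch behaves like the one in the traceback below, this prints an IndexError with the same "index 0 is out of bounds" message.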

Steps/Code to Reproduce

Clone the Flower Simulation Step by Step Pytorch Part-II repository and follow its instructions to set up the environment.

Then change the model in model.py to ResNet18, as shown below:

import torch
import torch.nn as nn
import torchvision.models as models
from flwr.common.parameter import ndarrays_to_parameters
from torchsummary import summary

class Net(nn.Module):
    def __init__(self, num_classes: int) -> None:
        super(Net, self).__init__()

        self.model = models.resnet18()
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Linear(num_ftrs, num_classes)
        summary(self.model, input_size=(1, 28, 28))  # with this line the run succeeds; remove it to reproduce the error

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.model(x)
        return x

The following packages are installed in the conda environment:

# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       2_gnu    conda-forge
absl-py                   2.1.0                    pypi_0    pypi
aiohttp                   3.9.3                    pypi_0    pypi
aiosignal                 1.3.1                    pypi_0    pypi
antlr4-python3-runtime    4.9.3                    pypi_0    pypi
asttokens                 2.4.1              pyhd8ed1ab_0    conda-forge
astunparse                1.6.3                    pypi_0    pypi
async-timeout             4.0.3                    pypi_0    pypi
attrs                     23.2.0                   pypi_0    pypi
blas                      1.0                         mkl  
brotli-python             1.0.9            py39h6a678d5_7  
bzip2                     1.0.8                h5eee18b_5  
ca-certificates           2024.3.11            h06a4308_0  
certifi                   2024.2.2           pyhd8ed1ab_0    conda-forge
cffi                      1.16.0                   pypi_0    pypi
charset-normalizer        2.0.4              pyhd3eb1b0_0  
click                     8.1.7                    pypi_0    pypi
comm                      0.2.2              pyhd8ed1ab_0    conda-forge
contourpy                 1.2.0                    pypi_0    pypi
cryptography              41.0.7                   pypi_0    pypi
cycler                    0.12.1                   pypi_0    pypi
datasets                  2.18.0                   pypi_0    pypi
debugpy                   1.6.7            py39h6a678d5_0  
decorator                 5.1.1              pyhd8ed1ab_0    conda-forge
dill                      0.3.8                    pypi_0    pypi
exceptiongroup            1.2.0              pyhd8ed1ab_2    conda-forge
executing                 2.0.1              pyhd8ed1ab_0    conda-forge
ffmpeg                    4.3                  hf484d3e_0    pytorch
filelock                  3.13.3                   pypi_0    pypi
flatbuffers               24.3.25                  pypi_0    pypi
flwr                      1.7.0                    pypi_0    pypi
flwr-datasets             0.1.0                    pypi_0    pypi
fonttools                 4.50.0                   pypi_0    pypi
freetype                  2.12.1               h4a9f257_0  
frozenlist                1.4.1                    pypi_0    pypi
fsspec                    2024.2.0                 pypi_0    pypi
gast                      0.5.4                    pypi_0    pypi
gmp                       6.2.1                h295c915_3  
gnutls                    3.6.15               he1e5248_0  
google-pasta              0.2.0                    pypi_0    pypi
grpcio                    1.62.1                   pypi_0    pypi
h5py                      3.10.0                   pypi_0    pypi
huggingface-hub           0.22.1                   pypi_0    pypi
hydra-core                1.3.2                    pypi_0    pypi
idna                      3.4              py39h06a4308_0  
importlib-metadata        7.1.0              pyha770c72_0    conda-forge
importlib-resources       6.4.0                    pypi_0    pypi
importlib_metadata        7.1.0                hd8ed1ab_0    conda-forge
intel-openmp              2023.1.0         hdb19cb5_46306  
ipykernel                 6.29.3             pyhd33586a_0    conda-forge
ipython                   8.18.1             pyh707e725_3    conda-forge
iterators                 0.0.2                    pypi_0    pypi
jedi                      0.19.1             pyhd8ed1ab_0    conda-forge
jpeg                      9e                   h5eee18b_1  
jsonschema                4.21.1                   pypi_0    pypi
jsonschema-specifications 2023.12.1                pypi_0    pypi
jupyter_client            8.6.1              pyhd8ed1ab_0    conda-forge
jupyter_core              5.7.2            py39hf3d152e_0    conda-forge
keras                     3.1.1                    pypi_0    pypi
kiwisolver                1.4.5                    pypi_0    pypi
lame                      3.100                h7b6447c_0  
lcms2                     2.12                 h3be6417_0  
ld_impl_linux-64          2.38                 h1181459_1  
lerc                      3.0                  h295c915_0  
libclang                  18.1.1                   pypi_0    pypi
libdeflate                1.17                 h5eee18b_1  
libffi                    3.4.4                h6a678d5_0  
libgcc-ng                 13.2.0               h807b86a_5    conda-forge
libgomp                   13.2.0               h807b86a_5    conda-forge
libiconv                  1.16                 h7f8727e_2  
libidn2                   2.3.4                h5eee18b_0  
libpng                    1.6.39               h5eee18b_0  
libsodium                 1.0.18               h36c2ea0_1    conda-forge
libstdcxx-ng              11.2.0               h1234567_1  
libtasn1                  4.19.0               h5eee18b_0  
libtiff                   4.5.1                h6a678d5_0  
libunistring              0.9.10               h27cfd23_0  
libwebp-base              1.3.2                h5eee18b_0  
lz4-c                     1.9.4                h6a678d5_0  
markdown                  3.6                      pypi_0    pypi
markdown-it-py            3.0.0                    pypi_0    pypi
markupsafe                2.1.5                    pypi_0    pypi
matplotlib                3.8.3                    pypi_0    pypi
matplotlib-inline         0.1.6              pyhd8ed1ab_0    conda-forge
mdurl                     0.1.2                    pypi_0    pypi
mkl                       2023.1.0         h213fc3f_46344  
mkl-service               2.4.0            py39h5eee18b_1  
mkl_fft                   1.3.8            py39h5eee18b_0  
mkl_random                1.2.4            py39hdb19cb5_0  
ml-dtypes                 0.3.2                    pypi_0    pypi
msgpack                   1.0.8                    pypi_0    pypi
multidict                 6.0.5                    pypi_0    pypi
multiprocess              0.70.16                  pypi_0    pypi
namex                     0.0.7                    pypi_0    pypi
ncurses                   6.4                  h6a678d5_0  
nest-asyncio              1.6.0              pyhd8ed1ab_0    conda-forge
nettle                    3.7.3                hbbd107a_1  
numpy                     1.26.4           py39h5f9d8c6_0  
numpy-base                1.26.4           py39hb5e798b_0  
omegaconf                 2.3.0                    pypi_0    pypi
openh264                  2.1.1                h4ff587b_0  
openjpeg                  2.4.0                h3ad879b_0  
openssl                   3.2.1                hd590300_1    conda-forge
opt-einsum                3.3.0                    pypi_0    pypi
optree                    0.11.0                   pypi_0    pypi
packaging                 24.0               pyhd8ed1ab_0    conda-forge
pandas                    2.2.1                    pypi_0    pypi
parso                     0.8.4              pyhd8ed1ab_0    conda-forge
pexpect                   4.9.0              pyhd8ed1ab_0    conda-forge
pickleshare               0.7.5                   py_1003    conda-forge
pillow                    10.2.0           py39h5eee18b_0  
pip                       23.3.1           py39h06a4308_0  
platformdirs              4.2.0              pyhd8ed1ab_0    conda-forge
prompt-toolkit            3.0.42             pyha770c72_0    conda-forge
protobuf                  4.25.3                   pypi_0    pypi
psutil                    5.9.8            py39hd1e30aa_0    conda-forge
ptyprocess                0.7.0              pyhd3deb0d_0    conda-forge
pure_eval                 0.2.2              pyhd8ed1ab_0    conda-forge
pyarrow                   15.0.2                   pypi_0    pypi
pyarrow-hotfix            0.6                      pypi_0    pypi
pycparser                 2.21                     pypi_0    pypi
pycryptodome              3.20.0                   pypi_0    pypi
pydantic                  1.10.14                  pypi_0    pypi
pygments                  2.17.2             pyhd8ed1ab_0    conda-forge
pyparsing                 3.1.2                    pypi_0    pypi
pysocks                   1.7.1            py39h06a4308_0  
python                    3.9.19               h955ad1f_0  
python-dateutil           2.9.0.post0              pypi_0    pypi
python_abi                3.9                      2_cp39    conda-forge
pytorch                   1.13.1              py3.9_cpu_0    pytorch
pytorch-mutex             1.0                         cpu    pytorch
pytz                      2024.1                   pypi_0    pypi
pyyaml                    6.0.1                    pypi_0    pypi
pyzmq                     25.1.2           py39h6a678d5_0  
ray                       2.6.3                    pypi_0    pypi
readline                  8.2                  h5eee18b_0  
referencing               0.34.0                   pypi_0    pypi
requests                  2.31.0           py39h06a4308_1  
rich                      13.7.1                   pypi_0    pypi
rpds-py                   0.18.0                   pypi_0    pypi
scipy                     1.12.0                   pypi_0    pypi
setuptools                68.2.2           py39h06a4308_0  
six                       1.16.0             pyh6c4a22f_0    conda-forge
sqlite                    3.41.2               h5eee18b_0  
stack_data                0.6.2              pyhd8ed1ab_0    conda-forge
tbb                       2021.8.0             hdb19cb5_0  
tensorboard               2.16.2                   pypi_0    pypi
tensorboard-data-server   0.7.2                    pypi_0    pypi
tensorflow-io-gcs-filesystem 0.36.0                   pypi_0    pypi
termcolor                 2.4.0                    pypi_0    pypi
tk                        8.6.12               h1ccaba5_0  
torchaudio                0.13.1                 py39_cpu    pytorch
torchsummary              1.5.1                    pypi_0    pypi
torchvision               0.14.1                 py39_cpu    pytorch
tornado                   6.4              py39hd1e30aa_0    conda-forge
tqdm                      4.66.2                   pypi_0    pypi
traitlets                 5.14.2             pyhd8ed1ab_0    conda-forge
typing_extensions         4.9.0            py39h06a4308_1  
tzdata                    2024.1                   pypi_0    pypi
urllib3                   2.1.0            py39h06a4308_1  
wcwidth                   0.2.13             pyhd8ed1ab_0    conda-forge
werkzeug                  3.0.2                    pypi_0    pypi
wheel                     0.41.2           py39h06a4308_0  
wrapt                     1.16.0                   pypi_0    pypi
xxhash                    3.4.1                    pypi_0    pypi
xz                        5.4.6                h5eee18b_0  
yarl                      1.9.4                    pypi_0    pypi
zeromq                    4.3.5                h6a678d5_0  
zipp                      3.18.1                   pypi_0    pypi
zlib                      1.2.13               h5eee18b_0  
zstd                      1.5.5                hc292b87_0 

requirement.txt file

datasets==2.18.0
flwr==1.7.0
hydra-core==1.3.2
omegaconf==2.3.0
torch==1.13.1
torchvision==0.14.1
flwr[simulation]>=1.0, <2.0
matplotlib==3.8.3
scipy==1.12.0
torchsummary==1.5.1

Expected Results

This is the output of a successful run (with the line summary(self.model, input_size=(1, 28, 28)) added):

{'history': History (loss, distributed):
    round 1: 6.738090056180954
    round 2: 3.8934330970048903
History (loss, centralized):
    round 0: 366.1482033729553
    round 1: 97.4027541577816
    round 2: 52.76616382226348
History (metrics, centralized):
{'accuracy': [(0, 0.1086), (1, 0.8021), (2, 0.8959)]}

Actual Results

When I remove the line summary(self.model, input_size=(1, 28, 28)), I get the following error:

[2024-04-08 09:43:34,760][flwr][INFO] - Initializing global parameters
[2024-04-08 09:43:34,761][flwr][INFO] - Requesting initial parameters from one random client
[2024-04-08 09:43:37,337][flwr][INFO] - Received initial parameters from one random client
[2024-04-08 09:43:37,338][flwr][INFO] - Evaluating initial parameters
[2024-04-08 09:43:37,644][flwr][ERROR] - index 0 is out of bounds for dimension 0 with size 0
[2024-04-08 09:43:37,646][flwr][ERROR] - Traceback (most recent call last):
  File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/flwr/simulation/app.py", line 308, in start_simulation
    hist = run_fl(
  File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/flwr/server/app.py", line 225, in run_fl
    hist = server.fit(num_rounds=config.num_rounds, timeout=config.round_timeout)
  File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/flwr/server/server.py", line 92, in fit
    res = self.strategy.evaluate(0, parameters=self.parameters)
  File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/flwr/server/strategy/fedavg.py", line 165, in evaluate
    eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {})
  File "/root/development/machine-learning-project/server.py", line 42, in evaluate_fn
    model.load_state_dict(state_dict, strict=True)
  File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1657, in load_state_dict
    load(self, state_dict)
  File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1645, in load
    load(child, child_state_dict, child_prefix)
  File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1645, in load
    load(child, child_state_dict, child_prefix)
  File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1639, in load
    module._load_from_state_dict(
  File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py", line 110, in _load_from_state_dict
    super(_NormBase, self)._load_from_state_dict(
  File "/root/miniconda3/envs/flower_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _load_from_state_dict
    input_param = input_param[0]
IndexError: index 0 is out of bounds for dimension 0 with size 0

[2024-04-08 09:43:37,648][flwr][ERROR] - Your simulation crashed :(. This could be because of several reasons. The most common are: 
     > Sometimes, issues in the simulation code itself can cause crashes. It's always a good idea to double-check your code for any potential bugs or inconsistencies that might be contributing to the problem. For example: 
         - You might be using a class attribute in your clients that hasn't been defined.
         - There could be an incorrect method call to a 3rd party library (e.g., PyTorch).
         - The return types of methods in your clients/strategies might be incorrect.
     > Your system couldn't fit a single VirtualClient: try lowering `client_resources`.
     > All the actors in your pool crashed. This could be because: 
         - You clients hit an out-of-memory (OOM) error and actors couldn't recover from it. Try launching your simulation with more generous `client_resources` setting (i.e. it seems {'num_cpus': 1, 'num_gpus': 0.0} is not enough for your run). Use fewer concurrent actors. 
         - You were running a multi-node simulation and all worker nodes disconnected. The head node might still be alive but cannot accommodate any actor with resources: {'num_cpus': 1, 'num_gpus': 0.0}.
Take a look at the Flower simulation examples for guidance <https://flower.dev/docs/framework/how-to-run-simulations.html>.
jafermarq commented 5 months ago

Hi @EzyHow, have you added that summary(self.model, input_size=(1, 28, 28)) call anywhere else, maybe also in the evaluation in server.py? I wonder if torchsummary is adding something to the state_dict...
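This hypothesis can be tested without torchsummary. As far as we can tell, summary() in torchsummary 1.5.1 builds a random batch of two inputs and calls the model once, without switching it to eval mode. The sketch below mimics that with a plain forward pass on a small Conv2d + BatchNorm2d stand-in (an assumption, not the issue's ResNet18) and compares the state_dict before and after.

```python
import torch
import torch.nn as nn

# Stand-in model (assumption): one Conv2d followed by one BatchNorm2d.
model = nn.Sequential(nn.Conv2d(1, 4, kernel_size=3), nn.BatchNorm2d(4))
keys_before = list(model.state_dict().keys())
count_before = model[1].num_batches_tracked.item()

# Mimic summary(): one forward pass on a random batch of two 1x28x28 images
# while the model is still in training mode (the default after construction).
model(torch.rand(2, 1, 28, 28))

keys_after = list(model.state_dict().keys())
count_after = model[1].num_batches_tracked.item()
print(keys_before == keys_after, count_before, count_after)
```

This prints True 0 1: no new entries appear in the state_dict, but BatchNorm's integer buffer num_batches_tracked goes from 0 to 1.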

EzyHow commented 5 months ago

Flower Simulation Step by Step Pytorch Part-II

Kindly check this repository for the full code: Testing Flower Simulation

In this repository, please look at the main.log files for the three scenarios in the output directory.

rhythm1827 commented 3 months ago

Hello,

I encountered the same issue and found a solution after noticing how the ndarrays_to_model function in src/model_utils.py builds its state dict:

from collections import OrderedDict
from typing import List
import numpy as np
import torch

def ndarrays_to_model(model: torch.nn.Module, params: List[np.ndarray]) -> None:
    """Set model weights from a list of NumPy ndarrays."""
    params_dict = zip(model.state_dict().keys(), params)
    state_dict = OrderedDict({k: torch.from_numpy(np.copy(v)) for k, v in params_dict})
    model.load_state_dict(state_dict, strict=True)

Therefore, I changed

state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})

to

state_dict = OrderedDict({k: torch.from_numpy(np.copy(v)) for k, v in params_dict})

in the set_parameters function in client.py and in evaluate_fn in server.py. You also need to import numpy:

import numpy as np

I hope it works for you as well.
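The fix can be checked with a round trip. This sketch reuses the body of ndarrays_to_model posted above and pairs it with model_to_ndarrays, a hypothetical export helper written only for this example; two small Conv2d + BatchNorm2d models stand in for ResNet18, and neither has run a forward pass, which is the situation that produced the IndexError.

```python
from collections import OrderedDict
from typing import List

import numpy as np
import torch
import torch.nn as nn


def model_to_ndarrays(model: nn.Module) -> List[np.ndarray]:
    """Hypothetical helper: export all weights and buffers as NumPy arrays."""
    return [val.cpu().numpy() for _, val in model.state_dict().items()]


def ndarrays_to_model(model: nn.Module, params: List[np.ndarray]) -> None:
    """Load NumPy arrays back, keeping each array's shape and dtype."""
    params_dict = zip(model.state_dict().keys(), params)
    state_dict = OrderedDict({k: torch.from_numpy(np.copy(v)) for k, v in params_dict})
    model.load_state_dict(state_dict, strict=True)


# Fresh models: num_batches_tracked is still 0 in both BatchNorm layers.
source = nn.Sequential(nn.Conv2d(1, 4, kernel_size=3), nn.BatchNorm2d(4))
target = nn.Sequential(nn.Conv2d(1, 4, kernel_size=3), nn.BatchNorm2d(4))

ndarrays_to_model(target, model_to_ndarrays(source))

# Every entry (randomly initialised conv weights included) should now match.
same = all(
    torch.equal(a, b)
    for a, b in zip(source.state_dict().values(), target.state_dict().values())
)
print(same)
```

Because torch.from_numpy keeps the array's shape and dtype, the 0-d int64 num_batches_tracked buffer arrives as a 0-d int64 tensor and load_state_dict(strict=True) accepts it. The np.copy gives each tensor its own memory, since torch.from_numpy otherwise shares memory with the NumPy array.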

TulioPolido commented 1 month ago

This worked for me. How did you come up with this solution? I can't figure out why it works.

rhythm1827 commented 1 month ago

I am not sure, but I noticed that one function builds the tensor with torch directly (torch.Tensor(v)) while the other goes through NumPy (torch.from_numpy(np.copy(v))). Maybe the two handle the arrays differently internally.
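A likely explanation, not confirmed in this thread: BatchNorm's num_batches_tracked buffer is a 0-d int64 tensor, so the client sends it as a 0-d NumPy array. The legacy torch.Tensor(...) constructor appears to treat a 0-d integer array as a size argument rather than as data, so a value of 0 becomes an empty 1-D tensor. PyTorch's backward-compatibility path in _load_from_state_dict then runs input_param[0] on it, which matches the batchnorm.py frames in the traceback. torch.from_numpy(np.copy(v)) and torch.tensor(v) both keep the 0-d shape and the int64 dtype. The sketch below compares the three conversions on a fresh BatchNorm2d layer (an illustrative stand-in).

```python
import numpy as np
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(4)  # never run, so num_batches_tracked == 0

# Export the buffer the way a Flower client typically does (assumption).
v = bn.state_dict()["num_batches_tracked"].cpu().numpy()
print(v.shape, v.dtype)  # () int64

legacy = torch.Tensor(v)              # conversion from the original evaluate_fn
fixed = torch.from_numpy(np.copy(v))  # conversion from the posted fix
alt = torch.tensor(v)                 # infers shape and dtype from the data

print("torch.Tensor:    ", legacy.shape, legacy.dtype)
print("torch.from_numpy:", fixed.shape, fixed.dtype)
print("torch.tensor:    ", alt.shape, alt.dtype)
```

This reading would also explain the summary() workaround: after its forward pass the counter is 1, so torch.Tensor(v) yields a one-element uninitialized tensor, indexing [0] succeeds, and num_batches_tracked is loaded with an unspecified value. With BatchNorm's default momentum=0.1 that counter is not used in the computation, which may be why the workaround appeared to train normally.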