
Flower: A Friendly Federated AI Framework
https://flower.ai
Apache License 2.0

Inplace update error in loading state_dict for YOLOV8 #3084

Closed shubhamjanu closed 8 months ago

shubhamjanu commented 8 months ago

What is your question?

I am facing an issue while implementing federated learning with YOLOv8. The code runs fine up to the first round of client training. After one round, when parameters are received from the server, the set_parameters() method throws the following error:

Traceback (most recent call last):
  File "C:\Users\Administrator\Desktop\UTokyo\Japan51\client.py", line 34, in <module>
    fl.client.start_numpy_client(
  File "C:\Users\Administrator\AppData\Local\Programs\Python\Python312\Lib\site-packages\flwr\client\app.py", line 500, in start_numpy_client
    start_client(
  File "C:\Users\Administrator\AppData\Local\Programs\Python\Python312\Lib\site-packages\flwr\client\app.py", line 248, in start_client
    _start_client_internal(
  File "C:\Users\Administrator\AppData\Local\Programs\Python\Python312\Lib\site-packages\flwr\client\app.py", line 382, in _start_client_internal
    out_message = app(message=message, context=context)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Administrator\AppData\Local\Programs\Python\Python312\Lib\site-packages\flwr\client\flower.py", line 76, in __call__
    return self._call(message, context)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Administrator\AppData\Local\Programs\Python\Python312\Lib\site-packages\flwr\client\flower.py", line 66, in ffn
    out_message = handle_legacy_message_from_tasktype(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Administrator\AppData\Local\Programs\Python\Python312\Lib\site-packages\flwr\client\message_handler\message_handler.py", line 135, in handle_legacy_message_from_tasktype
    fit_res = maybe_call_fit(
              ^^^^^^^^^^^^^^^
  File "C:\Users\Administrator\AppData\Local\Programs\Python\Python312\Lib\site-packages\flwr\client\client.py", line 234, in maybe_call_fit
    return client.fit(fit_ins)
           ^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Administrator\AppData\Local\Programs\Python\Python312\Lib\site-packages\flwr\client\numpy_client.py", line 238, in _fit
    results = self.numpy_client.fit(parameters, ins.config)  # type: ignore
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Administrator\Desktop\UTokyo\Japan51\client.py", line 22, in fit
    set_parameters(net, parameters)
  File "C:\Users\Administrator\Desktop\UTokyo\Japan51\client.py", line 10, in set_parameters
    net.load_state_dict(state_dict, strict=True)
  File "C:\Users\Administrator\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 2153, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for YOLO:
        While copying the parameter named "model.model.0.conv.weight", whose dimensions in the model are torch.Size([48, 3, 3, 3]) and whose dimensions in the checkpoint are torch.Size([48, 3, 3, 3]), an exception occurred : ('Inplace update to inference tensor outside InferenceMode is not allowed.You can make a clone to get a normal tensor before doing inplace update.See https://github.com/pytorch/rfcs/pull/17 for more details.',).

Basically it seems like an in-place update error, but I am not able to resolve it. I have tried several approaches, such as cloning the tensors and using deepcopy. I am also attaching my code files (centralized.py, client.py, and server.py) as well as the error logs: centralized.txt client.txt error.txt server.txt

Please help, I have been stuck on this issue for many days. If you need any additional information, let me know.
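For reference, this is roughly the kind of cloning workaround I have been experimenting with, following the hint in the error message itself. The set_parameters helper below is only an illustrative sketch, and the clone logic has not resolved the issue for me so far:

```python
from collections import OrderedDict
import torch

def set_parameters(net, parameters):
    # Illustrative sketch, not a verified fix: the PyTorch error suggests
    # cloning inference tensors before the in-place copy, so replace any
    # parameter or buffer that was created under torch.inference_mode()
    # (e.g. by YOLOv8's internal validation) with an ordinary clone first.
    for module in net.modules():
        for name, p in list(module.named_parameters(recurse=False)):
            if p.is_inference():
                setattr(module, name,
                        torch.nn.Parameter(p.clone(), requires_grad=p.requires_grad))
        for name, b in list(module.named_buffers(recurse=False)):
            if b is not None and b.is_inference():
                module.register_buffer(name, b.clone())

    # Standard Flower pattern: rebuild the state_dict from the NumPy arrays
    # received from the server and load it with strict key matching.
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)
```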

danieljanes commented 8 months ago

Hi @shubhamjanu, thanks for reporting this, we're looking into it.

In the meantime, can you try to disable in-place aggregation by using strategy = FedAvg(...[other arguments here]..., inplace=False)? I just noticed that we forgot to document this in the API reference; we'll fix this ASAP.
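For example, a minimal server.py with in-place aggregation disabled could look like the sketch below; the other FedAvg arguments are placeholders for whatever your current configuration uses:

```python
import flwr as fl

# Minimal sketch of a server with in-place aggregation disabled; replace the
# placeholder arguments with your own FedAvg settings.
strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,         # placeholder
    min_available_clients=2,  # placeholder
    inplace=False,            # disable in-place aggregation
)

fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=strategy,
)
```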

danieljanes commented 8 months ago

@shubhamjanu this PR adds the missing documentation: https://github.com/adap/flower/pull/3086

Let us know if this works without in-place aggregation.

shubhamjanu commented 8 months ago

@danieljanes Hello! After adding inplace=False to the FedAvg() strategy, the code still failed with the same in-place update error. I then removed the evaluate() method completely from my client.py, and the code started working. Now I am able to train my model for more than one round. After every round of training, the YOLOv8 model inherently runs self-validation with its best weights, so I think I can do away with the evaluate() method. What is your opinion?
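For context, my working client now looks roughly like the sketch below. It has no evaluate() method; the set_parameters helper is the one from my client.py, and the dataset path, epoch count, and number of examples returned from fit are placeholders:

```python
import flwr as fl
from ultralytics import YOLO

class YoloClient(fl.client.NumPyClient):
    # Simplified sketch of a client without an evaluate() method; YOLOv8's own
    # trainer already validates with the best weights after every .train() call.
    def __init__(self, net: YOLO, data_yaml: str):
        self.net = net
        self.data_yaml = data_yaml

    def get_parameters(self, config):
        return [val.cpu().numpy() for _, val in self.net.state_dict().items()]

    def fit(self, parameters, config):
        set_parameters(self.net, parameters)            # helper from client.py
        self.net.train(data=self.data_yaml, epochs=1)   # placeholder epochs
        return self.get_parameters(config={}), 1, {}    # placeholder num_examples

fl.client.start_numpy_client(
    server_address="127.0.0.1:8080",                    # placeholder address
    client=YoloClient(YOLO("yolov8n.pt"), "data.yaml"),
)
```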

shubhamjanu commented 8 months ago

@danieljanes Hi! I wish to report one strange observation. I was able to run my federated code on CPU. Today, I installed CUDA 11.8 and tried to run the code. It gets stuck and does not start training. It also runs "AMP: running Automatic Mixed Precision (AMP) checks with YOLOv8n..." and "Downloading https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8n.pt to 'yolov8n.pt'...", which was not observed when training on CPU. To reproduce the issue, I again tried running with CPU only and it worked fine. I installed CUDA and cuDNN versions compatible with PyTorch after checking the official PyTorch website, and they installed successfully without any warning or error message. I am using an AWS Windows Server 2022 instance. I am attaching my code files and the detailed error log. The training does not start, and the gRPC channel closes after waiting for some time: centralized.txt client.txt error.txt server.txt

Please guide me on how to resolve the above issue.

danieljanes commented 8 months ago

Great to hear it's working @shubhamjanu. I'm not familiar with the YOLOv8 codebase, so I can't help with the other issue. You can probably get help by posting on Slack (#questions) or Flower Discuss.

kvnptl commented 3 months ago

@shubhamjanu Were you able to fix the problem with GPU-based model training?