Open wittenator opened 1 month ago
Just to add that: I am talking about the fix that is mentioned in this comment: https://github.com/adap/flower/issues/3237#issuecomment-2145316689
Changing the line

    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})

to

    state_dict = OrderedDict({k: torch.from_numpy(np.copy(v)) for k, v in params_dict})

or to

    state_dict = OrderedDict({k: torch.from_numpy(v).detach().clone() for k, v in params_dict})

fixes the error. (The second option does not need another import, which is nice.) torch.Tensor does seem to copy the memory from the numpy buffer though, so I am not sure if memory ownership is actually the problem.
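On the memory-ownership question, a small sketch may help separate the two concerns. It only illustrates how the two proposed replacements relate to the source buffer (the array name `buf` is made up for illustration): a bare torch.from_numpy shares memory with the NumPy array, while both np.copy and .clone() decouple the tensor from it.

```python
import numpy as np
import torch

# Hypothetical buffer standing in for one entry of `parameters`.
buf = np.zeros(3, dtype=np.float32)

shared = torch.from_numpy(buf)                       # shares memory with buf
via_np_copy = torch.from_numpy(np.copy(buf))         # copy made on the NumPy side
via_clone = torch.from_numpy(buf).detach().clone()   # copy made on the torch side

# Mutate the original array after the tensors were created.
buf[0] = 1.0

# Only the tensor that shares memory sees the change.
print(shared[0].item(), via_np_copy[0].item(), via_clone[0].item())
```

Both variants from the comment above therefore end up with an independent tensor, so they should behave the same with respect to ownership.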
Hi @wittenator, yes, using torch.from_numpy(...) is the way to go. This is related to having a batchnorm layer that hasn't yet seen a single input. When using torch.Tensor(), the num_batches_tracked statistic will be in the form of:

    ('num_batches_tracked', tensor([]))

which isn't correct. But when using from_numpy(v), it has the expected representation:

    ('num_batches_tracked', tensor(0))
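The difference can be reproduced in isolation with a rough sketch like the one below. It assumes that a fresh batchnorm layer exports num_batches_tracked as a zero-dimensional int64 NumPy value, and the variable names are purely illustrative. It prints the shape and dtype each conversion produces so the two can be compared directly.

```python
import numpy as np
import torch

# Assumption: this mimics num_batches_tracked of a batchnorm layer
# that has not seen any input yet (a 0-d integer array holding 0).
v = np.array(0, dtype=np.int64)

legacy = torch.Tensor(v)      # constructor used in the original line
fixed = torch.from_numpy(v)   # conversion suggested in this thread

# Compare what each conversion produced.
print("torch.Tensor:    ", tuple(legacy.shape), legacy.dtype)
print("torch.from_numpy:", tuple(fixed.shape), fixed.dtype)
```

With from_numpy the value stays a scalar with its integer dtype, which is the shape load_state_dict expects for this entry.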
All this being said, this part of the code isn't part of "Flower" strictly speaking, since, depending on your model (or even ML framework of choice), you'd implement this functionality in one way or another.
Should we flag this issue as resolved? How about #3237?
[!NOTE] The recommended way of running Flower projects is via flwr run (e.g. as in [examples/quickstart-pytorch](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch) and many other examples). The Python "entrypoint" run_simulation() exists for now so simulations can run in setups like Colab/Jupyter. It supports a smaller set of features than flwr run.
Ah, that's very interesting! I wasn't aware of this intricate difference between torch.Tensor and torch.from_numpy. Thanks for looking into this!
While this piece of code is not strictly part of Flower, it still appears 79 times in 70 files across the code base (mainly old baselines and pretty much all examples for PyTorch). The new baseline contains the fix with np.copy, but I would agree that using torch.from_numpy is cleaner. Since most people will run into this issue at some point, would you consider accepting a PR that replaces said line with the better version across the code base?
I'm currently trying out flwr run, but I just wanted to demonstrate that already the very first tutorial from the website is broken once a model with a batchnorm is selected. :)
> Since most people will run into this issue at some point, would you consider accepting a PR that replaces said line with the better version across the code base?
People have encountered it a few times indeed. Let me loop in @danieljanes and @yan-gao-GY: should we change all instances of:

    def set_parameters(model, parameters):
        params_dict = zip(model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})  # replace with torch.from_numpy()
        model.load_state_dict(state_dict, strict=True)
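For reference, a minimal sketch of what the replaced helper could look like, assuming parameters is a list of NumPy arrays in state_dict order. The get_parameters helper and the toy model below are illustrative only and not taken from the Flower code base.

```python
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn


def get_parameters(model):
    # Illustrative helper: export every state_dict entry as a NumPy array.
    return [val.cpu().numpy() for _, val in model.state_dict().items()]


def set_parameters(model, parameters):
    params_dict = zip(model.state_dict().keys(), parameters)
    # from_numpy keeps shape and dtype; np.copy decouples the tensor
    # from the caller's array.
    state_dict = OrderedDict(
        (k, torch.from_numpy(np.copy(v))) for k, v in params_dict
    )
    model.load_state_dict(state_dict, strict=True)


# Toy model with a batchnorm layer that has not seen any input yet.
model = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))

# Round trip: export the parameters and load them back.
set_parameters(model, get_parameters(model))
print(model.state_dict()["1.num_batches_tracked"])
```

The round trip is the same pattern the examples use between client and server, just without any federated machinery around it.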
@wittenator Thanks for raising this issue. @jafermarq It makes sense to me to replace it with torch.from_numpy() due to the potential crash caused by batchnorm layers.
@yan-gao-GY are there any consequences, related to performance or anything non-obvious, that we should consider before changing torch.Tensor() to torch.from_numpy() everywhere?
Describe the bug
Running the baselines with other models, e.g. torchvision.models.resnet18 for fedprox/fednova/etc., fails with an out-of-bounds exception. This is the same problem that many people faced in e.g. #3237 when following the initial Flower tutorial. Applying the state dict fix across the whole code base seems to solve the problem, but I don't really see why it works. Since the same problem appears in the very first introductory tutorial, I would be really interested in discussing whether implementing this across the code base is possible and what exactly the problem is that this change fixes.
Steps/Code to Reproduce
Try following the tutorial at https://flower.ai/docs/framework/tutorial-series-get-started-with-flower-pytorch.html with resnet18 instead of the custom model.
Example code snippet from condensed Flower tutorial:
Expected Results
There should be no error when run with another model architecture.
Actual Results