adap / flower

Flower: A Friendly Federated AI Framework
https://flower.ai
Apache License 2.0

State dict creation bug for e.g. resnet18 #4344

Open wittenator opened 1 month ago

wittenator commented 1 month ago

Describe the bug

Running the baselines (e.g. fedprox, fednova) with other models, such as torchvision.models.resnet18, fails with an out-of-bounds exception. This is the same problem many people ran into in, e.g., #3237 when following the introductory Flower tutorial. Applying the state-dict fix from that issue across the whole code base seems to solve the problem, but I don't really see why it works. Since the same problem appears in the very first tutorial, I would be interested to discuss whether this fix could be applied across the code base and what exactly the underlying problem is.

Steps/Code to Reproduce

Try following the tutorial at https://flower.ai/docs/framework/tutorial-series-get-started-with-flower-pytorch.html with resnet18 instead of the custom model.

Example code (a condensed version of the Flower tutorial):

from collections import OrderedDict
from typing import List, Tuple
import torchvision

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from datasets.utils.logging import disable_progress_bar
from torch.utils.data import DataLoader

import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Metrics, Context
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAvg
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Train on GPU if available
print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")
disable_progress_bar()

NUM_CLIENTS = 10
BATCH_SIZE = 32

def load_datasets(partition_id: int):
    fds = FederatedDataset(dataset="cifar10", partitioners={"train": NUM_CLIENTS})
    partition = fds.load_partition(partition_id)
    # Divide data on each node: 80% train, 20% test
    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
    pytorch_transforms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    def apply_transforms(batch):
        # Instead of passing transforms to CIFAR10(..., transform=transform),
        # we pass this function to dataset.with_transform(apply_transforms).
        # The transforms object itself is exactly the same.
        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
        return batch

    # Create train/val for each partition and wrap it into DataLoader
    partition_train_test = partition_train_test.with_transform(apply_transforms)
    trainloader = DataLoader(
        partition_train_test["train"], batch_size=BATCH_SIZE, shuffle=True
    )
    valloader = DataLoader(partition_train_test["test"], batch_size=BATCH_SIZE)
    testset = fds.load_split("test").with_transform(apply_transforms)
    testloader = DataLoader(testset, batch_size=BATCH_SIZE)
    return trainloader, valloader, testloader

def train(net, trainloader, epochs: int, verbose=False):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for batch in trainloader:
            images, labels = batch["img"].to(DEVICE), batch["label"].to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metrics
            epoch_loss += loss.item()  # .item() avoids keeping the autograd graph alive
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        if verbose:
            print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")

def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch["img"].to(DEVICE), batch["label"].to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy

def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)

def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]

class FlowerClient(NumPyClient):
    def __init__(self, net, trainloader, valloader):
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        return get_parameters(self.net)

    def fit(self, parameters, config):
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}

def client_fn(context: Context) -> Client:
    """Create a Flower client representing a single organization."""

    # Load model
    net = torchvision.models.resnet18().to(DEVICE)

    # Load data (CIFAR-10)
    # Note: each client gets a different trainloader/valloader, so each client
    # will train and evaluate on their own unique data partition
    # Read the node_config to fetch data partition associated to this node
    partition_id = context.node_config["partition-id"]
    trainloader, valloader, _ = load_datasets(partition_id=partition_id)

    # Create a single Flower client representing a single organization
    # FlowerClient is a subclass of NumPyClient, so we need to call .to_client()
    # to convert it to a subclass of `flwr.client.Client`
    return FlowerClient(net, trainloader, valloader).to_client()

# Create the ClientApp
client = ClientApp(client_fn=client_fn)

# Create FedAvg strategy
strategy = FedAvg(
    fraction_fit=1.0,  # Sample 100% of available clients for training
    fraction_evaluate=0.5,  # Sample 50% of available clients for evaluation
    min_fit_clients=10,  # Never sample fewer than 10 clients for training
    min_evaluate_clients=5,  # Never sample fewer than 5 clients for evaluation
    min_available_clients=10,  # Wait until all 10 clients are available
)

def server_fn(context: Context) -> ServerAppComponents:
    """Construct components that set the ServerApp behaviour.

    You can use the settings in `context.run_config` to parameterize the
    construction of all elements (e.g the strategy or the number of rounds)
    wrapped in the returned ServerAppComponents object.
    """

    # Configure the server for 5 rounds of training
    config = ServerConfig(num_rounds=5)

    return ServerAppComponents(strategy=strategy, config=config)

# Create the ServerApp
server = ServerApp(server_fn=server_fn)

# Specify the resources each of your clients need
# By default, each client will be allocated 1x CPU and 0x GPUs
backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 0.0}}

# When running on GPU, assign a fraction (0.1) of a GPU to each client
if DEVICE.type == "cuda":
    backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 0.1}}
    # Refer to our Flower framework documentation for more details about Flower simulations
    # and how to set up the `backend_config`

# Run simulation
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_CLIENTS,
    backend_config=backend_config,
)

Expected Results

There should be no error when run with another model architecture.

Actual Results

ray::ClientAppActor.run() (pid=8542, ip=172.28.0.12, actor_id=9f091c66b19bc09302af02a801000000, repr=<flwr.simulation.ray_transport.ray_actor.ClientAppActor object at 0x12d6fcb32320>)
  File "/usr/local/lib/python3.10/dist-packages/flwr/simulation/ray_transport/ray_actor.py", line 63, in run
    raise ClientAppException(str(ex)) from ex
flwr.client.client_app.ClientAppException: 
Exception ClientAppException occurred. Message: index 0 is out of bounds for dimension 0 with size 0
ERROR :     Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/flwr/server/superlink/fleet/vce/vce_api.py", line 112, in worker
    out_mssg, updated_context = backend.process_message(message, context)
  File "/usr/local/lib/python3.10/dist-packages/flwr/server/superlink/fleet/vce/backend/raybackend.py", line 186, in process_message
    raise ex
  File "/usr/local/lib/python3.10/dist-packages/flwr/server/superlink/fleet/vce/backend/raybackend.py", line 174, in process_message
    ) = self.pool.fetch_result_and_return_actor_to_pool(future)
  File "/usr/local/lib/python3.10/dist-packages/flwr/simulation/ray_transport/ray_actor.py", line 477, in fetch_result_and_return_actor_to_pool
    _, out_mssg, updated_context = ray.get(future)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 2667, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 864, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(ClientAppException): ray::ClientAppActor.run() (pid=8542, ip=172.28.0.12, actor_id=9f091c66b19bc09302af02a801000000, repr=<flwr.simulation.ray_transport.ray_actor.ClientAppActor object at 0x12d6fcb32320>)
  File "/usr/local/lib/python3.10/dist-packages/flwr/client/client_app.py", line 143, in __call__
    return self._call(message, context)
  File "/usr/local/lib/python3.10/dist-packages/flwr/client/client_app.py", line 126, in ffn
    out_message = handle_legacy_message_from_msgtype(
  File "/usr/local/lib/python3.10/dist-packages/flwr/client/message_handler/message_handler.py", line 129, in handle_legacy_message_from_msgtype
    fit_res = maybe_call_fit(
  File "/usr/local/lib/python3.10/dist-packages/flwr/client/client.py", line 255, in maybe_call_fit
    return client.fit(fit_ins)
  File "/usr/local/lib/python3.10/dist-packages/flwr/client/numpy_client.py", line 259, in _fit
    results = self.numpy_client.fit(parameters, ins.config)  # type: ignore
  File "<ipython-input-5-5299d2ab935a>", line 118, in fit
  File "<ipython-input-5-5299d2ab935a>", line 102, in set_parameters
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2564, in load_state_dict
    load(self, state_dict)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2552, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2535, in load
    module._load_from_state_dict(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/batchnorm.py", line 132, in _load_from_state_dict
    super()._load_from_state_dict(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2385, in _load_from_state_dict
    input_param = input_param[0]
IndexError: index 0 is out of bounds for dimension 0 with size 0

wittenator commented 1 month ago

Just to add: I mean the fix mentioned in this comment: https://github.com/adap/flower/issues/3237#issuecomment-2145316689

wittenator commented 1 month ago

Changing the line

state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})

to either

state_dict = OrderedDict({k: torch.from_numpy(np.copy(v)) for k, v in params_dict})

or

state_dict = OrderedDict({k: torch.from_numpy(v).detach().clone() for k, v in params_dict})

fixes the error. (The second option does not need an extra import, which is nice.) torch.Tensor does seem to copy the memory from the NumPy buffer, though, so I am not sure whether memory ownership is actually the problem.
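On the memory-ownership question, here is a minimal sketch (assuming only NumPy and PyTorch; the three helper names are hypothetical) that contrasts plain torch.from_numpy with the two variants above. torch.from_numpy shares memory with the source array, while np.copy(...) and .detach().clone() give the tensor its own buffer. All three keep a 0-d int64 array (like a serialized num_batches_tracked value) 0-d, which suggests the copy is not what makes the fix work.

```python
import numpy as np
import torch


def to_tensor_shared(v: np.ndarray) -> torch.Tensor:
    # Zero-copy: the tensor and the NumPy array share one buffer.
    return torch.from_numpy(v)


def to_tensor_np_copy(v: np.ndarray) -> torch.Tensor:
    # Variant 1 from the thread: copy the array first, then wrap it.
    return torch.from_numpy(np.copy(v))


def to_tensor_clone(v: np.ndarray) -> torch.Tensor:
    # Variant 2 from the thread: wrap zero-copy, then clone into new memory.
    return torch.from_numpy(v).detach().clone()


if __name__ == "__main__":
    arr = np.zeros(3, dtype=np.float32)
    shared = to_tensor_shared(arr)
    copied = to_tensor_np_copy(arr)
    cloned = to_tensor_clone(arr)
    arr[0] = 1.0  # mutate the source array after conversion
    print(shared[0].item(), copied[0].item(), cloned[0].item())  # 1.0 0.0 0.0

    # A 0-d int64 array, like a serialized num_batches_tracked buffer.
    counter = np.array(0, dtype=np.int64)
    for convert in (to_tensor_shared, to_tensor_np_copy, to_tensor_clone):
        t = convert(counter)
        print(convert.__name__, tuple(t.shape), t.dtype)
```

Only the zero-copy version sees the later mutation; whether sharing matters in practice depends on whether the caller reuses the incoming arrays.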

jafermarq commented 3 weeks ago

Hi @wittenator, yes, using torch.from_numpy(...) is the way to go. This is related to having a batchnorm layer that hasn't seen a single input yet. When using torch.Tensor(), the num_batches_tracked statistic ends up in the form of:

('num_batches_tracked', tensor([]))

which isn't correct. With from_numpy(v), it has the expected representation:

('num_batches_tracked', tensor(0))
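The difference can be reproduced without Flower. The minimal sketch below, assuming only PyTorch and NumPy (build_state_dict and loads_ok are hypothetical helpers), round-trips a fresh nn.BatchNorm2d through NumPy the way get_parameters()/set_parameters() do. A plausible explanation, not confirmed in this thread, is that the legacy torch.Tensor(...) constructor treats a 0-d integer array like an integer size argument (as in torch.Tensor(0)), producing an empty tensor; load_state_dict then indexes it to fill the 0-d buffer, which matches the IndexError in the traceback above.

```python
from collections import OrderedDict
from typing import Callable, Dict, List

import numpy as np
import torch
import torch.nn as nn


def build_state_dict(
    net: nn.Module, params: List[np.ndarray], convert: Callable
) -> Dict[str, torch.Tensor]:
    # Same zip-over-keys pattern as set_parameters() in the tutorial code.
    return OrderedDict((k, convert(v)) for k, v in zip(net.state_dict().keys(), params))


def loads_ok(net: nn.Module, state_dict: Dict[str, torch.Tensor]) -> bool:
    # True if load_state_dict(strict=True) succeeds, False if it raises.
    try:
        net.load_state_dict(state_dict, strict=True)
        return True
    except (IndexError, RuntimeError):
        return False


if __name__ == "__main__":
    bn = nn.BatchNorm2d(4)  # fresh layer: it has never seen an input
    params = [v.cpu().numpy() for v in bn.state_dict().values()]  # like get_parameters()

    bad = build_state_dict(bn, params, torch.Tensor)       # tutorial pattern
    good = build_state_dict(bn, params, torch.from_numpy)  # proposed fix
    print(bad["num_batches_tracked"])   # tensor([])
    print(good["num_batches_tracked"])  # tensor(0)
    print(loads_ok(bn, good), loads_ok(bn, bad))  # True False
```

Catching RuntimeError as well is a defensive choice, in case other PyTorch versions report the shape mismatch differently.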

That said, strictly speaking this part of the code isn't part of Flower itself, since you would implement this functionality one way or another depending on your model (or even your ML framework of choice).

Should we mark this issue as resolved? What about #3237?

Note: The recommended way of running Flower projects is via flwr run (e.g. as in examples/quickstart-pytorch, https://github.com/adap/flower/tree/main/examples/quickstart-pytorch, and many other examples). The Python "entrypoint" run_simulation() exists for now so that simulations can run in setups like Colab/Jupyter, but it supports a smaller set of features.

wittenator commented 3 weeks ago

Ah, that's very interesting! I wasn't aware of this subtle difference between torch.Tensor and torch.from_numpy. Thanks for looking into this! While this code isn't strictly part of Flower, it still appears 79 times in 70 files across the code base (mainly old baselines and pretty much all PyTorch examples). The new baseline contains the fix with np.copy, but I agree that using torch.from_numpy is cleaner. Since most people will run into this issue at some point, would you consider accepting a PR that replaces said line with the better version across the code base? I'm currently trying out flwr run, but I wanted to show that even the very first tutorial on the website breaks once a model with a batchnorm layer is used. :)

jafermarq commented 3 weeks ago

> Since most people will run into this issue at some point, would you consider accepting a PR that replaces said line with the better version across the code base?

People have encountered it a few times indeed. Let me loop in @danieljanes and @yan-gao-GY: should we change all instances of:

def set_parameters(model, parameters):
    params_dict = zip(model.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) # replace with torch.from_numpy()
    model.load_state_dict(state_dict, strict=True)
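Along the lines proposed above, a corrected helper pair might look like the following minimal sketch. It assumes a CPU-only setup; make_net is a hypothetical stand-in for resnet18 (any model with a BatchNorm layer works), and the .clone() is one of the copy variants mentioned earlier in the thread rather than a requirement.

```python
from collections import OrderedDict
from typing import List

import numpy as np
import torch
import torch.nn as nn


def get_parameters(net: nn.Module) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_parameters(net: nn.Module, parameters: List[np.ndarray]) -> None:
    params_dict = zip(net.state_dict().keys(), parameters)
    # torch.from_numpy keeps each array's shape and dtype, so the 0-d int64
    # num_batches_tracked buffer stays 0-d; .clone() gives each tensor its own memory.
    state_dict = OrderedDict({k: torch.from_numpy(v).clone() for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)


def make_net() -> nn.Module:
    # Hypothetical small model with a BatchNorm layer, standing in for resnet18.
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3),
        nn.BatchNorm2d(8),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 10),
    )


if __name__ == "__main__":
    torch.manual_seed(0)
    src, dst = make_net(), make_net()
    src.train()
    with torch.no_grad():
        src(torch.randn(4, 3, 32, 32))  # one batch updates the BatchNorm running stats
    set_parameters(dst, get_parameters(src))
    same = all(
        torch.equal(a, b)
        for a, b in zip(src.state_dict().values(), dst.state_dict().values())
    )
    print(same, int(dst.state_dict()["1.num_batches_tracked"]))  # True 1
```

Since load_state_dict copies the incoming values into the model's existing tensors (unless assign=True is passed), dropping .clone() is likely safe too and saves one copy per tensor; that trade-off may be relevant to the performance question raised later in this thread.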

yan-gao-GY commented 3 weeks ago

@wittenator Thanks for raising this issue. @jafermarq It makes sense to me to replace it with torch.from_numpy(), given the potential crash caused by batchnorm layers.

jafermarq commented 3 weeks ago

@yan-gao-GY are there any consequences related to performance or something non-obvious when changing torch.Tensor() to torch.from_numpy() we should consider before making the change everywhere?