adap/flower · Flower: A Friendly Federated Learning Framework
https://flower.ai
Apache License 2.0

ZeroDivisionError: integer division or modulo by zero #1791

Open sauravtii opened 1 year ago

sauravtii commented 1 year ago

Describe the bug

I am trying out this code but am facing an error. Can I get some help?

My code:

from collections import OrderedDict
import warnings

import flwr as fl
import torch
import numpy as np

import random
from torch.utils.data import DataLoader

from datasets import load_dataset, load_metric

from transformers import AutoTokenizer, DataCollatorWithPadding
from transformers import AutoModelForSequenceClassification
from transformers import AdamW
#from transformers import tokenized_datasets

warnings.filterwarnings("ignore", category=UserWarning)
# DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

DEVICE = "cpu"

CHECKPOINT = "distilbert-base-uncased"  # transformer model checkpoint

def load_data():
    """Load IMDB data (training and eval)"""
    raw_datasets = load_dataset("yhavinga/imdb_dutch")
    raw_datasets = raw_datasets.shuffle(seed=42)

    # remove unnecessary data split
    del raw_datasets["unsupervised"]

    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)

    def tokenize_function(examples):
        return tokenizer(examples["text"], truncation=True)

    # random 100 samples
    population = random.sample(range(len(raw_datasets["train"])), 100)

    tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
    tokenized_datasets["train"] = tokenized_datasets["train"].select(populatio)nw
    tokenized_datasets["test"] = tokenized_datasets["test"].select(population)

    # tokenized_datasets = tokenized_datasets.remove_columns("text")
    # tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

    tokenized_datasets = tokenized_datasets.remove_columns("attention_mask")
    tokenized_datasets = tokenized_datasets.remove_columns("input_ids")
    tokenized_datasets = tokenized_datasets.remove_columns("label")
    # tokenized_datasets = tokenized_datasets.remove_columns("text_en")

    # tokenized_datasets = tokenized_datasets.remove_columns(raw_datasets["train"].column_names)

    tokenized_datasets = tokenized_datasets.remove_columns(["text", "text_en"])
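    # NOTE: after the removals above, no columns remain ("attention_mask",
    # "input_ids", "label", "text", and "text_en" have all been dropped),
    # which appears to leave the dataset's underlying Arrow table with zero
    # rows -- the likely trigger for the ZeroDivisionError reported below.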

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    trainloader = DataLoader(
        tokenized_datasets["train"],
        shuffle=True,
        batch_size=32,
        collate_fn=data_collator,
    )

    testloader = DataLoader(
        tokenized_datasets["test"], batch_size=32, collate_fn=data_collator
    )

    return trainloader, testloader

def train(net, trainloader, epochs):
    optimizer = AdamW(net.parameters(), lr=5e-4)
    net.train()
    for _ in range(epochs):
        for batch in trainloader:
            batch = {k: v.to(DEVICE) for k, v in batch.items()}
            outputs = net(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

def test(net, testloader):
    metric = load_metric("accuracy")
    loss = 0
    net.eval()
    for batch in testloader:
        batch = {k: v.to(DEVICE) for k, v in batch.items()}
        with torch.no_grad():
            outputs = net(**batch)
        logits = outputs.logits
        loss += outputs.loss.item()
        predictions = torch.argmax(logits, dim=-1)
        metric.add_batch(predictions=predictions, references=batch["labels"])
    loss /= len(testloader.dataset)
    accuracy = metric.compute()["accuracy"]
    return loss, accuracy

def main():
    net = AutoModelForSequenceClassification.from_pretrained(
        CHECKPOINT, num_labels=2
    ).to(DEVICE)

    trainloader, testloader = load_data()

    # Flower client
    class IMDBClient(fl.client.NumPyClient):
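        # NumPyClient exchanges model weights with the Flower server as a
        # list of NumPy ndarrays; the two helpers below convert between that
        # format and the PyTorch state_dict.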
        def get_parameters(self, config):
            return [val.cpu().numpy() for _, val in net.state_dict().items()]

        def set_parameters(self, parameters):
            params_dict = zip(net.state_dict().keys(), parameters)
            state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
            net.load_state_dict(state_dict, strict=True)

        def fit(self, parameters, config):
            self.set_parameters(parameters)
            print("Training Started...")
            train(net, trainloader, epochs=1)
            print("Training Finished.")
            return self.get_parameters(config={}), len(trainloader), {}

        def evaluate(self, parameters, config):
            self.set_parameters(parameters)
            loss, accuracy = test(net, testloader)
            return float(loss), len(testloader), {"accuracy": float(accuracy)}

    # Start client
    fl.client.start_numpy_client(server_address="localhost:8080", client=IMDBClient())

if __name__ == "__main__":
    main()

Steps/Code to Reproduce

Just run the code provided above.

Expected Results

I don't know; this is my first time running this code with a different setting.

Actual Results

Error:

Traceback (most recent call last):
  File "client_2.py", line 138, in <module>
    main()
  File "client_2.py", line 134, in main
    fl.client.start_numpy_client(server_address="localhost:8080", client=IMDBClient())
  File "/home/saurav/.local/lib/python3.8/site-packages/flwr/client/app.py", line 208, in start_numpy_client
    start_client(
  File "/home/saurav/.local/lib/python3.8/site-packages/flwr/client/app.py", line 142, in start_client
    client_message, sleep_duration, keep_going = handle(
  File "/home/saurav/.local/lib/python3.8/site-packages/flwr/client/grpc_client/message_handler.py", line 68, in handle
    return _fit(client, server_msg.fit_ins), 0, True
  File "/home/saurav/.local/lib/python3.8/site-packages/flwr/client/grpc_client/message_handler.py", line 157, in _fit
    fit_res = client.fit(fit_ins)
  File "/home/saurav/.local/lib/python3.8/site-packages/flwr/client/app.py", line 252, in _fit
    results = self.numpy_client.fit(parameters, ins.config)  # type: ignore
  File "client_2.py", line 124, in fit
    train(net, trainloader, epochs=1)
  File "client_2.py", line 78, in train
    for batch in trainloader:
  File "/home/saurav/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 652, in __next__
    data = self._next_data()
  File "/home/saurav/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 692, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/saurav/.local/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/saurav/.local/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/saurav/.local/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 1525, in __getitem__
    return self._getitem(
  File "/home/saurav/.local/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 1517, in _getitem
    pa_subtable = query_table(self._data, key, indices=self._indices if self._indices is not None else None)
  File "/home/saurav/.local/lib/python3.8/site-packages/datasets/formatting/formatting.py", line 373, in query_table
    pa_subtable = _query_table_with_indices_mapping(table, key, indices=indices)
  File "/home/saurav/.local/lib/python3.8/site-packages/datasets/formatting/formatting.py", line 55, in _query_table_with_indices_mapping
    return _query_table(table, key)
  File "/home/saurav/.local/lib/python3.8/site-packages/datasets/formatting/formatting.py", line 79, in _query_table
    return table.fast_slice(key % table.num_rows, 1)
ZeroDivisionError: integer division or modulo by zero
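
The failing call is table.fast_slice(key % table.num_rows, 1), so table.num_rows must be 0 at that point. Below is a minimal, hypothetical sketch of the suspected cause, assuming the problem is that load_data removes every column, which empties the dataset's underlying Arrow table while the select() indices mapping still reports rows:

from datasets import Dataset

ds = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 0]})
ds = ds.select([0, 1])                     # creates an indices mapping; len(ds) == 2
ds = ds.remove_columns(["text", "label"])  # drops every remaining column; the
                                           # underlying Arrow table is now empty
ds[0]  # ZeroDivisionError: integer division or modulo by zero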
tanertopal commented 1 year ago

@sauravtii does this relate to your question here: https://github.com/adap/flower/issues/1796

Unfortunately, we don't have the capacity to debug issues that are unrelated to Flower. I am keeping this open so you can clarify the question, and in case someone from the community wants to help out.