matteuscruz opened 3 days ago
Which example did you run when you encountered this error? I'm not able to reproduce it; the following code runs without raising that error:
```python
import torch
from fl_sim.models import ResNet10
from fl_sim.optimizers.base import ProxSGD_VR
model = ResNet10(10)
optim = ProxSGD_VR(model.parameters())
model_cache = ResNet10(10)
criterion = torch.nn.CrossEntropyLoss()
images, labels = torch.rand((2,3,224,224)), torch.tensor([2,4])
loss = criterion(model(images), labels)
loss.backward()
optim.step(local_weights=model_cache.parameters())
print("Test passed")
```
I'm receiving the following error:
TypeError: ProxSGD_VR.step() got multiple values for argument 'variance_buffer'
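For context, Python raises `got multiple values for argument` when a single call supplies the same parameter both positionally and by keyword. One plausible way this happens is that a caller, or a wrapper around `step()`, passes `variance_buffer` positionally while the call site also passes it as a keyword. The sketch below uses a hypothetical `HypotheticalOptimizer.step` with an assumed signature. It is not the actual `fl_sim` `ProxSGD_VR.step`, whose signature does not appear in this thread, so it only illustrates the error mechanism, not its confirmed cause here.

```python
class HypotheticalOptimizer:
    """Illustrative stand-in; the real ProxSGD_VR.step signature is NOT shown in this thread."""

    def step(self, local_weights, variance_buffer=None, closure=None):
        # Assumed signature. Echo the arguments so the binding is visible.
        return local_weights, variance_buffer, closure


opt = HypotheticalOptimizer()

# 1) 'variance_buffer' is filled positionally (second argument) AND by keyword,
#    so Python refuses to bind it twice and raises TypeError.
try:
    opt.step([1.0], [0.0], variance_buffer=[0.5])
    message = ""
except TypeError as err:
    message = str(err)
print(message)

# 2) Supplying each argument exactly once (here, by keyword) binds cleanly.
ok_result = opt.step(local_weights=[1.0], variance_buffer=[0.5])
print(ok_result)
```

If the real traceback points at a line like case 1, checking how `step()` is being called (including by any training loop or wrapper in the example that was run) would be a reasonable next step.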