burlachenkok / flpytorch

FL_PyTorch: Optimization Research Simulator for Federated Learning
Apache License 2.0

Fix #12: PyTorch expects Long rather than Double or Float when computing CrossEntropyLoss #14

Closed: techwizrd closed this pull request 1 year ago

techwizrd commented 1 year ago

This fixes the issue in utils.models_funcs.compute_loss where simulations crash while computing the cross-entropy loss. CrossEntropyLoss expects class-index labels as Long (int64) tensors, but the labels are passed into the function as floating-point tensors by default.
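The sketch below shows the kind of cast this fix describes. It assumes logits of shape (N, C) and class labels of shape (N,) that hold whole-number values stored as floats. The helper name compute_loss_sketch is hypothetical and is not the repository's actual utils.models_funcs.compute_loss; it only illustrates converting the targets to torch.long before calling torch.nn.functional.cross_entropy.

```python
import torch
import torch.nn.functional as F


def compute_loss_sketch(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper, not the repository's compute_loss.

    logits: (N, C) raw class scores.
    labels: (N,) class indices, possibly stored as float (e.g. 2.0).
    """
    # Class-index targets must be integer (torch.long). Casting truncates,
    # so this assumes the float labels already hold whole numbers.
    labels = labels.to(torch.long)
    return F.cross_entropy(logits, labels)
```

Casting at the loss boundary keeps the rest of the data pipeline unchanged; a possible alternative is to produce Long labels when the dataset is built.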

Fixes #12