HewlettPackard / swarm-learning

A simplified library for decentralized, privacy preserving machine learning
Apache License 2.0

pyt.py updates to reset model back to train mode after loss computation #213

Closed: RadhakrishnaJ closed this 9 months ago

RadhakrishnaJ commented 9 months ago

The pyt.py client module sets the model to test (eval) mode to compute the loss, but does not set it back to train mode afterwards. It needs to restore train mode; otherwise, user code that keeps training without calling model.train() itself can run into problems.
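A minimal sketch of the fix described above, assuming only the standard `torch.nn.Module` mode API (`model.training`, `model.train(mode)`, `model.eval()`). The helper name `eval_mode` is hypothetical and is not part of swarm-learning. It records the current mode, switches to eval mode for the loss computation, and restores the previous mode in a `finally` block, so the model returns to train mode even if the computation raises.

```python
from contextlib import contextmanager
from typing import Iterator, Protocol


class TrainableModel(Protocol):
    """Anything exposing the torch.nn.Module mode API (assumption: duck-typed)."""

    training: bool

    def train(self, mode: bool = True): ...

    def eval(self): ...


@contextmanager
def eval_mode(model: TrainableModel) -> Iterator[TrainableModel]:
    """Hypothetical helper: temporarily put `model` in eval mode.

    The previous mode is saved first and restored on exit, including
    when the body raises, so training code that follows sees the same
    mode it had before the loss computation.
    """
    was_training = model.training
    model.eval()
    try:
        yield model
    finally:
        # Restore the saved flag instead of forcing train mode.
        model.train(was_training)
```

In a PyTorch client this would wrap the loss computation, e.g. `with eval_mode(model), torch.no_grad(): loss = loss_fn(model(x), y)`. On exit `model.training` is back to its earlier value, so a training loop (or a differential-privacy optimizer) that runs next still sees train mode. Restoring the saved flag, rather than calling `model.train()` unconditionally, also leaves a model that was already in eval mode untouched.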
I observed this while testing differential privacy in PyTorch. The differential privacy optimizer expects the model to be in train mode, but because the client code left the model in test mode, the backward pass threw an error:
  File "model/cifar_first_stage.py", line 129, in doTrainBatchPvc
    loss.backward()
  File "/opt/conda/lib/python3.7/site-packages/torch/_tensor.py", line 489, in backward