Open MakGulati opened 2 years ago
Thank you for filing & describing your use case @MakGulati. We'll keep this issue open and update it whenever we make progress.
An unsupported & hacky workaround you can try in the meantime: you can access the generator model using `model._generator`. This will get you a PyTorch `Module` object. From there on out, you'd have to refer to the PyTorch user guides & API to extract any desired parameters.
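A minimal sketch of that workaround, assuming `model` has already been fitted and that `model._generator` is a `torch.nn.Module`, as described above. The helper names `read_generator_weights` and `write_generator_weights` are hypothetical and not part of SDV or CTGAN; `state_dict()` and `load_state_dict()` are the standard `torch.nn.Module` methods. Because `_generator` is a private attribute, this may break between releases.

```python
import copy


def read_generator_weights(model):
    """Return a copy of the generator's weights as a state dict.

    Assumes `model._generator` is a fitted torch.nn.Module. This is a
    private attribute, so it is unsupported and may change.
    """
    generator = getattr(model, "_generator", None)
    if generator is None:
        raise RuntimeError("No generator found; call fit() on the model first.")
    # Deep-copy so later training steps do not mutate the returned weights.
    return copy.deepcopy(generator.state_dict())


def write_generator_weights(model, weights):
    """Load a state dict (e.g. aggregated weights) into the generator."""
    generator = getattr(model, "_generator", None)
    if generator is None:
        raise RuntimeError("No generator found; call fit() on the model first.")
    generator.load_state_dict(weights)
```

Usage would look like `weights = read_generator_weights(model)` after fitting, and `write_generator_weights(model, weights)` to restore them later.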
I would also need access to the discriminator. As it is a local variable now, I cannot access it outside the class. @npatki If the dimensions of the discriminator change, the discriminator also needs to be reloaded.
Hi everyone, would assigning the discriminator to `self`, like `self._generator`, be a potential solution here? It could maybe be optional, with `False` as the default?
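A rough sketch of what this suggestion could look like. Everything here is hypothetical: `SketchSynthesizer` is a stand-in rather than the actual CTGAN class, the `keep_discriminator` flag does not exist in the library, and the training loop is elided.

```python
class SketchSynthesizer:
    """Hypothetical stand-in illustrating the proposal; not the CTGAN class."""

    def __init__(self, keep_discriminator=False):
        # Hypothetical opt-in flag, False by default as suggested above.
        self._keep_discriminator = keep_discriminator
        self._generator = None
        self._discriminator = None

    def fit(self, build_generator, build_discriminator):
        generator = build_generator()
        # The discriminator starts out as a local variable, as it is today.
        discriminator = build_discriminator()

        # ... adversarial training loop would go here ...

        self._generator = generator
        if self._keep_discriminator:
            # Only retain the discriminator when the caller opted in.
            self._discriminator = discriminator
```

With the default, the discriminator is discarded once `fit` returns, so existing behavior and memory use would stay the same.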
Problem Description
Ability to access the underlying deep learning model written in PyTorch. It would be nice to have an API function call to read and write the weights of a network (e.g. the generator and discriminator networks in the case of CTGAN).
Expected behavior
API access to variables in the CTGAN package through the SDV package.
Additional context
Such a feature would make it possible to train federated learning models. After deep learning networks are trained on different local clients, each with its own local dataset, using SDV's `fit` method, the model weights can be read (through this new API) and sent to a global server for aggregation. The global server then sends back the aggregated model, which is loaded on the selected clients at each round. Local training then starts again and the process repeats. Finally, after training for the desired number of rounds, SDV's `sample` method can be used to generate synthetic datasets from the underlying aggregated deep learning model, which captures behaviors from the different dataset-holding local clients.
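The aggregation step on the global server could be sketched roughly as follows. This assumes a FedAvg-style weighted average, which the issue does not prescribe, and the function name `federated_average` is hypothetical. Each client state is a mapping from parameter name to value, such as the output of a PyTorch module's `state_dict()`. The code only relies on `*` and `+`, so it works on tensors as well as plain floats.

```python
def federated_average(client_states, client_sizes):
    """Average per-client weights, weighted by local dataset size.

    client_states: list of mappings {parameter_name: value}, one per client.
    client_sizes:  list of local dataset sizes, in the same order.
    """
    if not client_states:
        raise ValueError("At least one client state is required.")
    if len(client_states) != len(client_sizes):
        raise ValueError("client_states and client_sizes must have equal length.")

    total = float(sum(client_sizes))
    averaged = {}
    for key in client_states[0]:
        acc = None
        for state, size in zip(client_states, client_sizes):
            # Each client contributes in proportion to its share of the data.
            term = state[key] * (size / total)
            acc = term if acc is None else acc + term
        averaged[key] = acc
    return averaged
```

The result could then be loaded back into each selected client's network before the next round. One caveat: any integer-valued entries in a state dict (for example, counters kept by normalization layers) would come out of this averaging as floats and may need separate handling.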