sdv-dev / SDV

Synthetic data generation for tabular data
https://docs.sdv.dev/sdv
Other

Access to underlying model weights of CTGAN and TVAE for manipulation #682

Open MakGulati opened 2 years ago

MakGulati commented 2 years ago

Problem Description

The ability to access the underlying deep learning model, which is written in PyTorch. It would be nice to have an API call to read and write the network weights (e.g. the generator and discriminator networks in the case of CTGAN).

Expected behavior

API access to the model variables in the CTGAN package through the SDV package.

Additional context

Such a feature would make it possible to train federated learning models. The deep learning networks would first be trained with SDV's fit method on different local clients, each holding its own local dataset. The model weights could then be read (through this new API) and sent to a global server for aggregation. At each round, the global server sends the aggregated model back, and it is loaded onto the selected clients. Local training then starts again and the process repeats. Finally, after training for the desired number of rounds, SDV's sample method can be used to generate synthetic datasets from the aggregated deep learning model, which captures behaviors from the datasets held by the different local clients.
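The round-based loop described above resembles federated averaging (FedAvg). Here is a minimal sketch in plain Python, so it runs without PyTorch. Each client's parameters are a dict mapping names to values. The names federated_average, run_rounds and local_train are hypothetical and are not part of SDV or CTGAN. Each client is weighted by its dataset size, as FedAvg does; an unweighted mean would be another option. The averaging only uses * and +, so the same function should also accept dicts of tensors, such as those returned by a PyTorch module's state_dict().

```python
from typing import Any, Callable, Dict, List, Mapping, Sequence

State = Mapping[str, Any]


def federated_average(client_states: Sequence[State], client_sizes: Sequence[int]) -> Dict[str, Any]:
    """Hypothetical helper: size-weighted average of per-client parameter dicts."""
    if not client_states:
        raise ValueError("need at least one client")
    if len(client_states) != len(client_sizes):
        raise ValueError("need one dataset size per client")
    total = float(sum(client_sizes))
    averaged = {}
    for name in client_states[0]:
        # Each client contributes in proportion to its share of all rows.
        averaged[name] = sum(state[name] * (n / total) for state, n in zip(client_states, client_sizes))
    return averaged


def run_rounds(global_state: State, client_datasets: List[list],
               local_train: Callable[[Dict[str, Any], list], Dict[str, Any]],
               rounds: int) -> Dict[str, Any]:
    """Hypothetical loop: broadcast, train locally, aggregate, repeat."""
    state = dict(global_state)
    for _ in range(rounds):
        # Every client starts the round from a copy of the current global weights.
        local_states = [local_train(dict(state), data) for data in client_datasets]
        state = federated_average(local_states, [len(data) for data in client_datasets])
    return state


# Toy demo: "training" simply sets the weight to the client's data mean.
clients = [[1.0, 1.0], [3.0] * 6]
result = run_rounds({"w": 0.0}, clients, lambda state, data: {"w": sum(data) / len(data)}, rounds=3)
print(result)  # {'w': 2.5}
```

With two rows averaging 1.0 and six rows averaging 3.0, the aggregate is 0.25 * 1.0 + 0.75 * 3.0 = 2.5.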

npatki commented 2 years ago

Thank you for filing & describing your use case @MakGulati. We'll keep this issue open and update it whenever we make progress.

An unsupported & hacky workaround you can try in the meantime: you can access the generator through model._generator. This gives you a PyTorch Module object. From there, you'd have to refer to the PyTorch user guides & API docs to extract any parameters you need.
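Here is a minimal sketch of reading and writing a module's weights in PyTorch. To keep it self-contained, it uses a small stand-in nn.Sequential network instead of a fitted SDV model. With a real model, the module would be whatever model._generator returns, which is a private, unsupported attribute. state_dict() and load_state_dict() are standard torch.nn.Module methods.

```python
import copy

import torch
from torch import nn

torch.manual_seed(0)

# Stand-in generator for illustration only; not CTGAN's actual architecture.
generator = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))

# Read: state_dict() maps parameter names to tensors. deepcopy makes the
# snapshot independent of the live module's storage.
snapshot = copy.deepcopy(generator.state_dict())
print(list(snapshot))  # ['0.weight', '0.bias', '2.weight', '2.bias']

# Simulate further local training changing the weights in place.
with torch.no_grad():
    for param in generator.parameters():
        param.add_(1.0)

# Write: load_state_dict() copies the given tensors back into the module,
# e.g. aggregated weights received from a server.
generator.load_state_dict(snapshot)
```

The deepcopy matters. state_dict() returns tensors that share storage with the live parameters, so without it the in-place update would also change the snapshot.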

MakGulati commented 2 years ago

I would also need access to the discriminator. Since it is currently a local variable, I cannot access it outside the class. @npatki If the discriminator's dimensions change, the discriminator would also need to be reloaded.

AndresAlgaba commented 2 years ago

Hi everyone, would assigning the discriminator to self, like self._generator, be a potential solution here? Maybe it could be optional, with a flag that defaults to False?
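Here is a minimal sketch of that proposal, using toy classes rather than CTGAN's real code. The names ToyNetwork, ToySynthesizer, keep_discriminator and _discriminator are hypothetical. The discriminator is still created inside fit(), but it is kept on self only when the flag is set, so the default behavior stays the same.

```python
class ToyNetwork:
    """Placeholder for a PyTorch module; holds a single weight value."""

    def __init__(self, weight: float = 0.0):
        self.weight = weight


class ToySynthesizer:
    """Hypothetical synthesizer with an opt-in discriminator attribute."""

    def __init__(self, keep_discriminator: bool = False):
        self.keep_discriminator = keep_discriminator
        self._generator = None
        self._discriminator = None

    def fit(self, data):
        # Reuse a stored (or externally loaded) discriminator if present;
        # otherwise create a fresh one, as a purely local variable would be.
        discriminator = self._discriminator if self._discriminator is not None else ToyNetwork()
        if self._generator is None:
            self._generator = ToyNetwork()
        # Stand-in for the training loop: move both weights halfway to the data mean.
        target = sum(data) / len(data)
        discriminator.weight += 0.5 * (target - discriminator.weight)
        self._generator.weight += 0.5 * (target - self._generator.weight)
        if self.keep_discriminator:
            # Opt-in: expose the discriminator so callers can read or overwrite it.
            self._discriminator = discriminator
        return self


default = ToySynthesizer().fit([2.0, 2.0])
print(default._discriminator)  # None

kept = ToySynthesizer(keep_discriminator=True).fit([2.0, 2.0])
print(kept._discriminator.weight)  # 1.0
```

With the flag on, a federated client could overwrite _discriminator (and _generator) with aggregated weights before calling fit() again, and training would continue from them.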