sdv-dev / CTGAN

Conditional GAN for generating synthetic tabular data.

Return loss values as float values not PyTorch objects #332

Closed: srinify closed this issue 8 months ago.

srinify commented 8 months ago

Duplicate of https://github.com/sdv-dev/SDV/issues/1811; close both at the same time.

Problem Description

After fitting a model, accessing the loss_values attribute returns a DataFrame whose loss columns contain PyTorch tensor objects instead of plain float values.

[Screenshot: 2024-02-22 at 4:20:36 PM]

This means that plotting these values requires an extra step to extract the numbers with apply(), which I feel adds unnecessary friction.

[Screenshot: 2024-02-22 at 4:58:56 PM]
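For reference, a minimal sketch of that extra step, assuming the loss columns hold 0-dim PyTorch tensors as described above. The DataFrame is built by hand as a stand-in for ctgan.loss_values, and the loss values are made up for illustration; it needs pandas and torch installed.

```python
import pandas as pd
import torch

# Hand-built stand-in for ctgan.loss_values: one row per epoch, with the
# loss columns holding 0-dim tensors (illustrative values, not real output).
loss_df = pd.DataFrame({
    'Epoch': [0, 1, 2],
    'Generator Loss': [torch.tensor(1.25), torch.tensor(0.75), torch.tensor(0.5)],
    'Discriminator Loss': [torch.tensor(-0.25), torch.tensor(-0.5), torch.tensor(0.125)],
})

# The extra step: convert every cell to a Python float before plotting.
# float() works on 0-dim tensors as well as on numbers that are already floats.
for column in ['Generator Loss', 'Discriminator Loss']:
    loss_df[column] = loss_df[column].apply(float)

# Only after this conversion does the one-line plot call behave as expected:
# loss_df.plot(x='Epoch', y=['Generator Loss', 'Discriminator Loss'])
```

Using float() rather than calling .item() inside the lambda keeps the conversion safe even if a column already contains plain numbers.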

Expected behavior

Ideally, the returned DataFrame would contain plain float values for the Generator and Discriminator losses. That would remove the extra step when plotting the loss values:

loss_df = ctgan.loss_values
loss_df.plot(x='Epoch', y=['Generator Loss', 'Discriminator Loss'])

Additional context

Relevant code is here:

https://github.com/sdv-dev/CTGAN/blob/main/ctgan/synthesizers/ctgan.py#L426
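A minimal sketch of the kind of change being requested, assuming the per-epoch losses are 0-dim tensors that may still carry autograd history. This is not CTGAN's actual code: the helper record_epoch_losses and the fake training loop are hypothetical, and only the column names come from this issue. The idea is to convert each loss to a Python float at the point where it is recorded, so the stored DataFrame never holds tensors.

```python
import pandas as pd
import torch


def record_epoch_losses(loss_df, epoch, loss_g, loss_d):
    """Hypothetical helper: append one epoch's losses as Python floats.

    .detach() drops the autograd graph, .cpu() moves the value off any GPU,
    and .item() turns the 0-dim tensor into a Python float.
    """
    epoch_loss_df = pd.DataFrame({
        'Epoch': [epoch],
        'Generator Loss': [loss_g.detach().cpu().item()],
        'Discriminator Loss': [loss_d.detach().cpu().item()],
    })
    if loss_df is None or loss_df.empty:
        return epoch_loss_df
    return pd.concat([loss_df, epoch_loss_df], ignore_index=True)


# Fake losses that still require grad, standing in for a real training loop.
weight = torch.tensor(2.0, requires_grad=True)
loss_df = None
for epoch in range(3):
    loss_g = weight * (epoch + 1)
    loss_d = -weight / (epoch + 1)
    loss_df = record_epoch_losses(loss_df, epoch, loss_g, loss_d)

# With float columns, the plotting call from the issue works directly:
# loss_df.plot(x='Epoch', y=['Generator Loss', 'Discriminator Loss'])
```

Converting at record time, rather than when the attribute is read, means users never see tensors and the DataFrame does not keep references to GPU memory or autograd graphs.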