sdv-dev / CTGAN

Conditional GAN for generating synthetic tabular data.

Add verbose parameter to TVAE #269

Closed: candalfigomoro closed this issue 1 year ago.

candalfigomoro commented 1 year ago

There seems to be no way to track training progress when fitting a TVAE model.

This is a serious problem because, especially with large datasets, there is no way to tell which epoch training has reached. Is it an hour from finishing? A hundred years? Who knows.

Please add a "verbose" parameter to TVAE (as CTGAN already has) so that the current epoch and loss are printed during training.

Thank you.
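As a rough illustration of what such a flag could do, here is a minimal pure-Python sketch of a training loop that, when verbose is enabled, prints the epoch, the loss, and an estimate of the remaining time. The names train_with_progress and run_epoch are hypothetical and are not part of CTGAN's API; run_epoch stands in for one pass over the data and returns that epoch's loss.

```python
import time
from typing import Callable, List


def train_with_progress(
    epochs: int,
    run_epoch: Callable[[int], float],
    verbose: bool = False,
) -> List[float]:
    """Run `epochs` epochs, optionally printing progress after each one.

    Hypothetical helper: `run_epoch(i)` performs one pass over the data
    and returns a representative loss for epoch `i`.
    """
    losses = []
    start = time.perf_counter()
    for i in range(epochs):
        loss = run_epoch(i)
        losses.append(loss)
        if verbose:
            elapsed = time.perf_counter() - start
            # Naive ETA: assume each remaining epoch takes as long as the average so far.
            remaining = elapsed / (i + 1) * (epochs - i - 1)
            print(
                f"Epoch {i + 1}/{epochs}, Loss: {loss:.4f}, "
                f"elapsed {elapsed:.1f}s, ~{remaining:.1f}s remaining",
                flush=True,
            )
    return losses


if __name__ == "__main__":
    # Toy stand-in for a real epoch: the "loss" simply decays.
    train_with_progress(3, lambda i: 1.0 / (i + 1), verbose=True)
```

Averaging over completed epochs gives only a rough ETA, but even a rough figure answers the "one hour or 100 years" question.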

candalfigomoro commented 1 year ago

A workaround is to monkey-patch TVAE.fit in the CTGAN library, for example:

# Explicit imports instead of a star import (module paths may differ between CTGAN versions)
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

from ctgan.data_transformer import DataTransformer
from ctgan.synthesizers.base import random_state
from ctgan.synthesizers.tvae import TVAE, Decoder, Encoder, _loss_function


@random_state
def fit(self, train_data, discrete_columns=()):
    # Copy of the library's TVAE.fit with one extra print per epoch
    self.transformer = DataTransformer()
    self.transformer.fit(train_data, discrete_columns)
    train_data = self.transformer.transform(train_data)
    dataset = TensorDataset(torch.from_numpy(train_data.astype('float32')).to(self._device))
    loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, drop_last=False)

    data_dim = self.transformer.output_dimensions
    encoder = Encoder(data_dim, self.compress_dims, self.embedding_dim).to(self._device)
    self.decoder = Decoder(self.embedding_dim, self.decompress_dims, data_dim).to(self._device)
    optimizerAE = Adam(
        list(encoder.parameters()) + list(self.decoder.parameters()),
        weight_decay=self.l2scale)

    for i in range(self.epochs):
        for id_, data in enumerate(loader):
            optimizerAE.zero_grad()
            real = data[0].to(self._device)
            mu, std, logvar = encoder(real)
            eps = torch.randn_like(std)
            emb = eps * std + mu
            rec, sigmas = self.decoder(emb)
            loss_1, loss_2 = _loss_function(
                rec, real, sigmas, mu, logvar,
                self.transformer.output_info_list, self.loss_factor
            )
            loss = loss_1 + loss_2
            loss.backward()
            optimizerAE.step()
            self.decoder.sigma.data.clamp_(0.01, 1.0)
        # Patch: report progress after each epoch (loss of the epoch's last batch)
        print(f'Epoch {i + 1}/{self.epochs}, Loss: {loss.item():.4f}', flush=True)


# Replace the method on the class so every TVAE instance uses the patched fit
TVAE.fit = fit

Note the additional print call marked "# Patch" (original method: https://github.com/sdv-dev/CTGAN/blob/master/ctgan/synthesizers/tvae.py#L137).

npatki commented 1 year ago

Hi @candalfigomoro, thanks for the feature request. I can see how this would be useful for large datasets trained over many epochs! We'll keep this issue open to track updates and communicate any progress.

I'm moving this issue to the CTGAN repository, where the change needs to be made (as shown in your code). After that, we can expose the verbose parameter through the SDV library's wrappers.
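One common way for a higher-level wrapper to expose such an option is to accept it in its own constructor and forward it unchanged to the underlying model. The classes below (CoreModel and Wrapper) are hypothetical, not SDV's or CTGAN's actual classes, and only illustrate that forwarding pattern.

```python
class CoreModel:
    """Hypothetical low-level model that supports a verbose flag."""

    def __init__(self, epochs=10, verbose=False):
        self.epochs = epochs
        self.verbose = verbose

    def fit(self, data):
        for i in range(self.epochs):
            loss = 1.0 / (i + 1)  # placeholder for a real training step
            if self.verbose:
                print(f"Epoch {i + 1}/{self.epochs}, Loss: {loss:.4f}")
        return self


class Wrapper:
    """Hypothetical high-level wrapper that forwards the flag unchanged."""

    def __init__(self, epochs=10, verbose=False):
        # Store the settings; the underlying model is built at fit time.
        self._model_kwargs = {"epochs": epochs, "verbose": verbose}
        self._model = None

    def fit(self, data):
        # Build the underlying model with the user's settings, then train it.
        self._model = CoreModel(**self._model_kwargs).fit(data)
        return self


if __name__ == "__main__":
    Wrapper(epochs=3, verbose=True).fit([])
```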

npatki commented 1 year ago

Hi everyone, issue #300 describes the functionality we plan to add, so I'll mark this issue as a duplicate of it.

Feel free to look over #300 and let us know whether that functionality suits your needs. Thanks!