sdv-dev / CTGAN

Conditional GAN for generating synthetic tabular data.

Torch 2.0 fails with cuda=False #288

Closed: amontanez24 closed this issue 1 year ago

amontanez24 commented 1 year ago

Error Description

With Torch 2.0, the demo code fails whenever the `cuda` parameter is set to `False`. This is a problem because `False` is the default on Linux.
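The failure condition described above (Torch 2.0 combined with `cuda=False`) can be written as a small check, for example to flag or skip affected runs in a test suite. This is a minimal sketch, not part of CTGAN or Torch: the helper names `major_minor` and `hits_reported_failure` are hypothetical, and it assumes version strings start with `MAJOR.MINOR`, as `torch.__version__` does (e.g. `2.0.1+cu117`).

```python
import re
from typing import Tuple


def major_minor(version: str) -> Tuple[int, int]:
    """Return (major, minor) from a version string such as '2.0.1+cu117'."""
    # Hypothetical helper: only the leading MAJOR.MINOR digits are inspected,
    # so suffixes like '+cu117', 'rc1' or '.dev2023...' are ignored.
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def hits_reported_failure(torch_version: str, cuda: bool) -> bool:
    """True when the setup matches this report: Torch 2.0 with cuda=False."""
    # Hypothetical helper: mirrors the condition stated in the issue,
    # nothing more; other Torch versions are not claimed to be affected.
    return major_minor(torch_version) == (2, 0) and cuda is False
```

In practice the first argument would be `torch.__version__`, e.g. `hits_reported_failure(torch.__version__, cuda=False)`.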

Steps to reproduce

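```python
from ctgan import CTGAN, load_demo

# Load the census demo dataset bundled with CTGAN
data = load_demo()

# Columns that CTGAN should model as categorical
discrete_columns = [
    'workclass',
    'education',
    'marital-status',
    'occupation',
    'relationship',
    'race',
    'sex',
    'native-country',
    'income'
]

# Training with cuda=False is what triggers the failure under Torch 2.0
ctgan = CTGAN(epochs=10, cuda=False)
ctgan.fit(data, discrete_columns)

# Synthetic copy
samples = ctgan.sample(1000)
```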