tsrobinson / SyGNet

Synthetic data using Generative Adversarial Networks
GNU General Public License v3.0

Tune #7

Closed: tsrobinson closed this issue 2 years ago

tsrobinson commented 2 years ago

Create a function that performs hyperparameter tuning for a GAN, given input data.

Should have the following arguments:

It should also come with a warning about run time: tuning could take hours or even days!
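A minimal sketch of what such a function might look like, assuming a simple random search. The name `tune`, its arguments and the `build_and_score` callback (which would train a GAN with the sampled hyperparameters and return a score) are all hypothetical, not SyGNet's actual API. The runtime warning is raised up front with Python's standard `warnings` module.

```python
import random
import warnings
from typing import Any, Callable, Dict, List, Sequence


def tune(
    parameter_dict: Dict[str, Sequence[Any]],
    data: Any,
    build_and_score: Callable[[Any, Dict[str, Any]], float],
    runs: int = 10,
    seed: int = 0,
) -> List[Dict[str, Any]]:
    """Hypothetical random-search tuner.

    Samples `runs` hyperparameter combinations from `parameter_dict`,
    trains and scores each one via `build_and_score`, and returns one
    result row per trial.
    """
    # Warn before doing any work: every trial trains a full GAN.
    warnings.warn(
        "Hyperparameter tuning trains many GANs and may take hours or days.",
        UserWarning,
    )
    rng = random.Random(seed)
    results = []
    for trial in range(1, runs + 1):
        # Draw one candidate value per hyperparameter, independently and uniformly.
        params = {name: rng.choice(list(values)) for name, values in parameter_dict.items()}
        score = build_and_score(data, params)
        results.append({"trial": trial, **params, "score": score})
    return results
```

Random search is used here only because it is the simplest strategy to show; a grid search would instead iterate over `itertools.product` of the candidate values.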

tsrobinson commented 2 years ago

An untested version of this function was added in f8e7551f1bf7dfd96f3b720fc315b072bff85c10. @ayn2 to test.

tsrobinson commented 2 years ago

Re. #11, we could add a default option that evaluates model performance using the critic score when the model mode is "wgan" or "cgan".
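One way such a default could work, sketched under assumptions: `critic` is taken to be any callable that maps a single row (with any conditioning labels already attached, in the "cgan" case) to a scalar score, and `critic_gap` and `default_metric` are hypothetical helper names. The gap between the mean critic score on real rows and on synthetic rows is the usual WGAN estimate of the Wasserstein-1 distance, so smaller values suggest closer synthetic data. Comparing this gap across trials assumes the critics are comparably well trained.

```python
from statistics import fmean
from typing import Callable, Optional, Sequence

# Hypothetical type: a critic scores one row of (possibly conditioned) data.
Critic = Callable[[Sequence[float]], float]


def critic_gap(critic: Critic, real_rows, fake_rows) -> float:
    """Mean critic score on real rows minus mean critic score on synthetic rows.

    With a well-trained, Lipschitz-constrained WGAN critic this estimates the
    Wasserstein-1 distance between the two distributions (smaller is better).
    """
    real_mean = fmean([critic(row) for row in real_rows])
    fake_mean = fmean([critic(row) for row in fake_rows])
    return real_mean - fake_mean


def default_metric(mode: str) -> Optional[Callable[..., float]]:
    """Return the critic-based metric for critic-based modes, else None."""
    # Only "wgan" and "cgan" models have a critic whose scores mean this;
    # for any other mode the caller has to supply their own metric.
    if mode in ("wgan", "cgan"):
        return critic_gap
    return None
```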

tsrobinson commented 2 years ago

We could also evaluate performance mid-training and adjust the output dataframe to report metrics per reporting interval rather than per tested model.

For example:

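| Trial | Epoch | Dropout | ... | Layer_struc | Test output |
|-------|-------|---------|-----|-------------|-------------|
| 1     | 10    | 0.1     | ... | [256,256]   | 35          |
| ...   | ...   | ...     | ... | ...         | ...         |
| 1     | 100   | 0.1     | ... | [256,256]   | 32          |
| 2     | 10    | 0.3     | ... | [512,256]   | 47          |
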
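A minimal sketch of how per-interval results could be flattened into rows like the table above, assuming each trial yields a list of `(epoch, score)` pairs from its reporting intervals. `interval_rows` is a hypothetical helper, and the column names simply mirror the table.

```python
from typing import Any, Dict, Iterable, List, Tuple


def interval_rows(
    trial: int,
    params: Dict[str, Any],
    checkpoints: Iterable[Tuple[int, float]],
) -> List[Dict[str, Any]]:
    """Expand one trial into one row per reporting interval.

    `params` holds the trial's hyperparameters (e.g. dropout, layer structure);
    `checkpoints` holds (epoch, test output) pairs recorded during training.
    """
    return [
        {"trial": trial, "epoch": epoch, **params, "test_output": score}
        for epoch, score in checkpoints
    ]


# Example using the values from the table: trial 1 reported at epochs 10 and 100.
rows = interval_rows(1, {"dropout": 0.1, "layer_struc": [256, 256]}, [(10, 35), (100, 32)])
rows += interval_rows(2, {"dropout": 0.3, "layer_struc": [512, 256]}, [(10, 47)])
```

Passing the combined `rows` to `pandas.DataFrame` would then give one row per trial and reporting interval.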
tsrobinson commented 2 years ago

Progress:

Still to do:

tsrobinson commented 2 years ago

@ayn2, with respect to checkpointing, I thought the basic pipeline would be something like:

# User specifies:
epochs = 100
checkpoints = 4

# Calculate epochs per cycle (integer division, so any remainder epochs are dropped)
cycle_epochs = epochs // checkpoints

# Then we loop through the cycles, one per checkpoint:
for cycle in range(checkpoints):
    ...
    model.fit(..., epochs=cycle_epochs)
    ...

We need to make sure this works with the k-fold cross-validation: we don't want the models from different folds to "contaminate" each other. I presume we can do this by having the cycle loop inside the k loop, with a fresh model built at the start of each fold.
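A runnable sketch of that nesting, assuming a model object with a `fit(data, epochs=...)` method as in the pseudocode above. `kfold_indices`, `cv_with_checkpoints`, `make_model`, `evaluate` and the `CountingModel` stand-in are hypothetical names, not SyGNet code. `make_model()` is called once per fold, so no weights carry over between folds, while the inner cycle loop keeps training the same model and records a score after each checkpoint. Any leftover epochs (when `epochs` is not divisible by `checkpoints`) go to the last cycle so the total still matches.

```python
from typing import Any, Callable, Dict, List, Sequence


def kfold_indices(n: int, k: int) -> List[List[int]]:
    """Split range(n) into k contiguous folds whose sizes differ by at most one."""
    sizes = [n // k + (1 if i < n % k else 0) for i in range(k)]
    folds, start = [], 0
    for size in sizes:
        folds.append(list(range(start, start + size)))
        start += size
    return folds


def cv_with_checkpoints(
    data: Sequence[Any],
    make_model: Callable[[], Any],
    evaluate: Callable[[Any, List[Any]], float],
    epochs: int,
    checkpoints: int,
    k: int,
) -> List[Dict[str, Any]]:
    """Train a fresh model per fold, scoring it on the held-out fold after each cycle."""
    cycle_epochs = epochs // checkpoints
    remainder = epochs % checkpoints
    results = []
    for fold, test_idx in enumerate(kfold_indices(len(data), k), start=1):
        held_out = set(test_idx)
        train = [row for i, row in enumerate(data) if i not in held_out]
        test = [data[i] for i in test_idx]
        model = make_model()  # new model for every fold, so folds cannot contaminate each other
        trained = 0
        for cycle in range(checkpoints):
            # The last cycle absorbs any leftover epochs so the total equals `epochs`.
            n_epochs = cycle_epochs + (remainder if cycle == checkpoints - 1 else 0)
            model.fit(train, epochs=n_epochs)
            trained += n_epochs
            results.append({"fold": fold, "epoch": trained, "score": evaluate(model, test)})
    return results


class CountingModel:
    """Hypothetical stand-in for a GAN: it only counts the epochs it has been trained for."""

    def __init__(self) -> None:
        self.epochs_seen = 0

    def fit(self, data: Sequence[Any], epochs: int) -> None:
        self.epochs_seen += epochs


results = cv_with_checkpoints(
    list(range(9)), CountingModel, lambda m, test: m.epochs_seen,
    epochs=10, checkpoints=4, k=3,
)
```

With `epochs=10, checkpoints=4`, each fold is scored at epochs 2, 4, 6 and 10, and every fold's counter restarts from zero, which is exactly the "no contamination" property we want.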