LarsKue / lightning-trainable

A default trainable module for pytorch lightning.
MIT License

Added Gaussian Mixture Models as a toy distribution #28

Closed · thelostscout closed 4 months ago

thelostscout commented 4 months ago

I thought it might be useful to have access to Gaussian mixture model toy distributions.

LarsKue commented 4 months ago

Thank you for the addition! Can you briefly explain in the docstring what this distribution/dataset looks like (e.g. in 2D), and how it differs from the Hypersphere dataset?

thelostscout commented 4 months ago

I don't think it's too comparable to the hyperspheres, from what I understand. In this case, the user controls the placement of all Gaussian blobs as well as their weights and standard deviations. Do you think it would be sensible to simplify the creation of the distributions by reducing it to high-level arguments such as the number of mixture components?
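For illustration, such a fully user-specified mixture can be built directly with torch.distributions (a minimal sketch, not the PR's actual implementation):

import torch
from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal

# a 2D mixture of three Gaussian blobs: the user specifies means,
# standard deviations, and mixture weights explicitly
means = torch.tensor([[0.0, 0.0], [3.0, 3.0], [-3.0, 3.0]])
stds = torch.tensor([[0.5, 0.5], [1.0, 1.0], [0.3, 0.3]])
weights = torch.tensor([0.5, 0.3, 0.2])

gmm = MixtureSameFamily(
    mixture_distribution=Categorical(probs=weights),
    component_distribution=Independent(Normal(means, stds), 1),
)

samples = gmm.sample((1000,))  # shape: (1000, 2)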

LarsKue commented 4 months ago

Yes, controlling the datasets via high-level hyperparameters in a similar fashion to how we construct models is the core philosophy of this library. Would you like to add this?
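For concreteness, a high-level interface could look like this (the class name and arguments are hypothetical, not the library's actual API):

import torch

class GaussianMixtureDataset:
    # hypothetical sketch: the user gives only a few hyperparameters,
    # and the component means are generated internally
    def __init__(self, dimensions: int = 2, n_components: int = 5,
                 std: float = 0.1, n_samples: int = 10_000):
        means = torch.randn(n_components, dimensions)
        components = torch.randint(n_components, (n_samples,))
        self.samples = means[components] + std * torch.randn(n_samples, dimensions)

dataset = GaussianMixtureDataset(dimensions=2, n_components=5)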

thelostscout commented 4 months ago

I think random generation would move it much closer to the hyperspheres dataset, depending on how the generation of the means and stddevs is implemented. It might make the addition obsolete.

LarsKue commented 4 months ago

In that case, let's stick to the hyperspheres dataset. You are welcome to add more generation modes to the hyperspheres dataset, though. For instance, we could replace it with something like

class MixtureDataset:
    def __init__(self, mode: str = "spheres"):
        match mode:
            case "cubes": ...    # e.g. means on hypercube corners
            case "spheres": ...  # e.g. means on a hypersphere
            case "random": ...   # e.g. means drawn at random
            case _:
                raise ValueError(f"unknown mode: {mode!r}")

where each mode changes how the means and stds are generated.
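As a rough sketch of that idea (the helper name and the placement rules are only illustrative):

import torch

def generate_means(mode: str, n_components: int, dim: int) -> torch.Tensor:
    # illustrative only: mode selects the rule for placing component means
    match mode:
        case "spheres":
            # means uniformly at random on the unit hypersphere
            means = torch.randn(n_components, dim)
            return means / means.norm(dim=1, keepdim=True)
        case "random":
            # means drawn freely from a standard normal
            return torch.randn(n_components, dim)
        case _:
            raise ValueError(f"unknown mode: {mode!r}")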

thelostscout commented 4 months ago

My issue with random creation is the lack of reproducibility. Say I wanted to learn a distribution with two Gaussians and compare the loss values for different network types. The loss will usually differ depending on the overlap and the positions of the means.

LarsKue commented 4 months ago

For this, you can either copy the dataset directly or use a seed (see lightning.seed_everything) before sampling.
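For example (a minimal sketch; seed_everything resets the global Python, NumPy, and torch RNGs):

import lightning
import torch

lightning.seed_everything(42)
a = torch.randn(1000, 2)

lightning.seed_everything(42)
b = torch.randn(1000, 2)

assert torch.equal(a, b)  # re-seeding reproduces the same draws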

thelostscout commented 4 months ago

Hmm, I was thinking of a method that uses its own RNG, so that the dataset generation can be seeded without affecting the rest of the process. But maybe that's not necessary.
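Something like this, using a dedicated torch.Generator (a sketch of the idea, not an actual implementation):

import torch

# a dedicated generator seeds the dataset generation without
# touching the global RNG state used by the rest of the process
generator = torch.Generator().manual_seed(42)
means = torch.randn(5, 2, generator=generator)  # reproducible
noise = torch.randn(1000, 2)                    # global RNG, unaffected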

thelostscout commented 4 months ago

> In that case, let's stick to the hyperspheres dataset. You are welcome to add more generation modes to the hyperspheres dataset, though. For instance, we could replace it with something like
>
>     class MixtureDataset:
>         def __init__(self, mode: str = "spheres"):
>             match mode:
>                 case "cubes": ...
>                 case "spheres": ...
>                 case "random": ...
>
> where each mode changes how the means and stds are generated.

Yes, I can implement a hypercube version. I think I would place the blobs on the corners (and hence have an upper bound on the number of centers). Would you consider renaming the dataset to something like make_blobs or gmm? When I looked at the different datasets, I assumed that hyperspheres would do something similar to hypershells, rather than build a Gaussian mixture model according to a certain placement rule.
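For the corner placement, something like this (a hypothetical helper, with 2 ** dim as the natural upper bound on the number of centers):

import itertools
import torch

def cube_corner_means(n_components: int, dim: int) -> torch.Tensor:
    # at most 2 ** dim corners exist on a dim-dimensional hypercube
    if n_components > 2 ** dim:
        raise ValueError(f"at most {2 ** dim} centers fit in {dim} dimensions")
    corners = torch.tensor(list(itertools.product([-1.0, 1.0], repeat=dim)))
    return corners[:n_components]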

LarsKue commented 4 months ago

> Would you consider a name change of the dataset to something like make_blobs or gmm?

I would welcome a name change consistent with the implemented changes. However, let's stick to ML jargon and use Dataset in place of Model where possible.

thelostscout commented 4 months ago

Ok, the trivial way would be to name it GaussianMixtureDataset. Or maybe MultiGaussianDataset? Or DistributedBlobsDataset?

thelostscout commented 4 months ago

Ok, see the new PR.