facebookresearch / ppuda

Code for Parameter Prediction for Unseen Deep Architectures (NeurIPS 2021)
MIT License

Training the GHN on other datasets #3

Open pivettamarcos opened 2 years ago

pivettamarcos commented 2 years ago

Is there a way to train the GHN on other types of datasets, such as with 1D inputs?

bknyaz commented 2 years ago

The overall GHN approach should work on 1D inputs as well, but this codebase does not support them.

The main steps to achieve that would be to:

  1. Write your own net_generator.py, based on ours, that generates training architectures able to process 1D inputs.
  2. Write a Network class that processes 1D inputs.
  3. Set the spatial dimensions in max_shape to 1: https://github.com/facebookresearch/ppuda/blob/main/ppuda/config.py#L160

Other minor changes may be required, since the code assumes 2D inputs.
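To see why step 3 matters, here is a minimal sketch of how 1D-network parameter shapes could be padded to 4D and checked against `max_shape`. It assumes `max_shape` is a 4-tuple `(out_channels, in_channels, height, width)`. The helpers `to_4d` and `exceeds_max_shape` are hypothetical and are not part of ppuda.

```python
from typing import List, Sequence, Tuple

Shape = Tuple[int, ...]


def to_4d(shape: Shape) -> Tuple[int, int, int, int]:
    """Pad a parameter shape to 4D (out, in, h, w) with trailing 1s.

    Linear weight (out, in)     -> (out, in, 1, 1)
    Conv1d weight (out, in, k)  -> (out, in, k, 1)
    Bias (out,)                 -> (out, 1, 1, 1)
    """
    if not 1 <= len(shape) <= 4:
        raise ValueError(f"unsupported parameter shape {shape}")
    return tuple(shape) + (1,) * (4 - len(shape))


def exceeds_max_shape(shape: Shape, max_shape: Sequence[int]) -> List[int]:
    """Return the indices of dimensions where the padded shape is larger than max_shape."""
    padded = to_4d(shape)
    return [i for i, (dim, limit) in enumerate(zip(padded, max_shape)) if dim > limit]


if __name__ == "__main__":
    # Hypothetical max_shape with spatial dimensions set to 1.
    max_shape = (64, 64, 1, 1)
    for shape in [(128, 20), (128,), (10, 128), (10,)]:
        print(shape, "->", to_4d(shape), "too large in dims:", exceeds_max_shape(shape, max_shape))
```

With both spatial dimensions set to 1, MLP weights and biases map cleanly onto `(out, in, 1, 1)`. A Conv1d kernel of size `k` becomes `(out, in, k, 1)`, so its first spatial dimension would have to be at least `k`.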

SpaceDorgi commented 2 years ago

Will you be providing more examples with different modalities like text or audio?

bknyaz commented 2 years ago

I just created a pull request https://github.com/facebookresearch/ppuda/pull/5 with an example of predicting parameters for a generic MLP, which should be possible to adapt to 1D inputs, text or audio. However, the predicted parameters are very likely to be meaningless in this case because the GHN was trained on images; this is just an example.
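For a sense of what a GHN would have to predict for a generic MLP on 1D inputs, here is a small sketch that lists the parameter shapes layer by layer. It assumes PyTorch's `nn.Linear` convention, where weights are stored as `(out_features, in_features)`. The function `mlp_param_shapes` and the `layers.{i}` naming are hypothetical illustrations and are not taken from the pull request.

```python
from math import prod
from typing import List, Sequence, Tuple


def mlp_param_shapes(in_dim: int, hidden: Sequence[int], out_dim: int) -> List[Tuple[str, Tuple[int, ...]]]:
    """List (name, shape) for each weight and bias of an MLP on 1D inputs."""
    dims = [in_dim, *hidden, out_dim]
    shapes = []
    for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
        # nn.Linear-style layout: weight is (out_features, in_features).
        shapes.append((f"layers.{i}.weight", (d_out, d_in)))
        shapes.append((f"layers.{i}.bias", (d_out,)))
    return shapes


if __name__ == "__main__":
    params = mlp_param_shapes(in_dim=20, hidden=[64, 32], out_dim=5)
    for name, shape in params:
        print(name, shape)
    print("total parameters:", sum(prod(s) for _, s in params))
```

Every one of these tensors is a node the GHN must fill in, so the architectures seen during GHN training should cover similar layer types and sizes.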

To make predicted parameters useful for 1D inputs, text or audio, GHN must be trained on such data. This requires research, but I'm hopeful that it will be possible in the near future.

Feel free to close this issue if your questions are resolved.

rsk97 commented 2 years ago

@bknyaz - If I want to train on the CelebA dataset, would I have to generate new NNs using the generator and edit the Network class to handle CelebA inputs?

bknyaz commented 2 years ago

You can generate new NNs to handle CelebA, but it may be easier to modify the existing CIFAR-10/ImageNet graphs on the fly in the graph loader (for example, somewhere in this function: https://github.com/facebookresearch/ppuda/blob/main/ppuda/deepnets1m/loader.py#L167) by replacing the classification nodes with ones appropriate for CelebA.
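The on-the-fly replacement can be sketched as follows. The `Node` class, the `classifier.` name prefix and `swap_classifier` are hypothetical stand-ins, not ppuda's actual graph representation. The sketch also assumes the CelebA task is predicting its 40 binary attributes, so the classifier needs 40 outputs instead of CIFAR-10's 10.

```python
from dataclasses import dataclass, replace
from typing import List, Tuple


@dataclass(frozen=True)
class Node:
    name: str               # parameter name (hypothetical naming scheme)
    shape: Tuple[int, ...]  # shape of the parameter this node holds


def swap_classifier(nodes: List[Node], num_outputs: int, prefix: str = "classifier") -> List[Node]:
    """Return a copy of the graph with classifier weight/bias resized to num_outputs."""
    swapped = []
    for node in nodes:
        if node.name.startswith(prefix + "."):
            # The first dimension of the classifier weight and bias is the number of outputs.
            node = replace(node, shape=(num_outputs,) + node.shape[1:])
        swapped.append(node)
    return swapped


if __name__ == "__main__":
    cifar_graph = [
        Node("conv.weight", (16, 3, 3, 3)),
        Node("classifier.weight", (10, 64)),
        Node("classifier.bias", (10,)),
    ]
    for node in swap_classifier(cifar_graph, num_outputs=40):
        print(node.name, node.shape)
```

Only the output layer changes, so the rest of each training architecture (and the GHN's ability to predict its parameters) carries over unchanged.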

rsk97 commented 2 years ago

I see, yes that makes sense! Thanks for the super quick response!