facebookresearch / vissl

VISSL is FAIR's library of extensible, modular and scalable components for SOTA Self-Supervised Learning with images.
https://vissl.ai
MIT License
3.26k stars, 334 forks

Register custom models other than BaseSSLMultiInputOutputModel #411

Closed: mayalenE closed this issue 3 years ago

mayalenE commented 3 years ago

🚀 Feature

Ability to register custom models other than BaseSSLMultiInputOutputModel.

When creating a custom model, we can register a new trunk and head(s), but the forward pass of BaseSSLMultiInputOutputModel does not seem to handle all cases.

Motivation & Examples

For instance, I would like the trunk to receive a single input and return a dictionary, and a single head to receive this dictionary and return a dictionary as well. In this case, heads_forward(self, feats, heads) raises an error because len(feats) > 1 and len(heads) == 1. A simple use case that needs this is a VAE, where the trunk must return a (z, mu, logvar) triplet but there is only one decoder head. A more advanced use case I have in mind is a head with connection layers to the trunk; again, this needs a trunk that produces multiple outputs from different layers and a head that receives this dict/list of features.
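The mismatch described above can be illustrated with a small, self-contained sketch. It uses neither VISSL nor PyTorch: paired_heads_forward is a hypothetical stand-in for a forward pass that expects one head per feature (it is not the library's heads_forward), and toy_trunk, toy_decoder_head, and dict_heads_forward are made-up names used only for illustration.

```python
from typing import Callable, Dict, List, Sequence

Features = Dict[str, float]


def toy_trunk(x: float) -> Features:
    """Hypothetical VAE-style trunk: one input, a dict of named outputs."""
    mu, logvar = 2.0 * x, -1.0
    z = mu  # a real VAE would sample z from N(mu, exp(logvar))
    return {"z": z, "mu": mu, "logvar": logvar}


def toy_decoder_head(feats: Features) -> Features:
    """Hypothetical single head that consumes the whole dict."""
    return {"recon": feats["z"] / 2.0, "mu": feats["mu"], "logvar": feats["logvar"]}


def paired_heads_forward(feats: Sequence, heads: List[Callable]) -> list:
    """Stand-in (assumption) for a forward pass that expects one head per feature."""
    if len(feats) != len(heads):
        raise ValueError(f"got {len(feats)} features for {len(heads)} head(s)")
    return [head(feat) for head, feat in zip(heads, feats)]


def dict_heads_forward(feats: Features, head: Callable[[Features], Features]) -> Features:
    """Alternative: hand the full dict to the single head."""
    return head(feats)


feats = toy_trunk(3.0)

# Three trunk outputs and one head: the one-to-one pairing cannot be satisfied.
try:
    paired_heads_forward(list(feats.values()), [toy_decoder_head])
    paired_failed = False
except ValueError:
    paired_failed = True

# Passing the dict as a single unit works with a single head.
out = dict_heads_forward(feats, toy_decoder_head)
```

In this sketch the head treats the trunk output as one structured value rather than as a list of independent features, which is the behavior the request asks for.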

I am able to modify the heads_forward method of BaseSSLMultiInputOutputModel to make it work, but I would rather not modify the base classes that VISSL provides and instead write my custom modules separately.

Thank you!

iseessel commented 3 years ago

Hi @mayalenE thanks for the request!

We will look into implementing this -- but if you are interested in contributing, I think this should be relatively straightforward to do.

I would recommend following the pattern in https://github.com/facebookresearch/vissl/blob/master/vissl/trainer/train_steps/__init__.py#L22 for the build_model method: https://github.com/facebookresearch/vissl/blob/master/vissl/models/__init__.py#L13. Let me know if you are interested; I am happy to guide you.
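As a rough illustration of the registry pattern referred to above, here is a minimal standalone sketch. It is not the VISSL or Classy Vision implementation: the names register_model and build_model echo those used in this thread, but the signatures, the MODEL_REGISTRY dict, and the ToyVAE class are assumptions made for this example.

```python
from typing import Callable, Dict, Type

# Hypothetical global registry mapping a model name to its class.
MODEL_REGISTRY: Dict[str, Type] = {}


def register_model(name: str) -> Callable[[Type], Type]:
    """Class decorator that stores a model class under a unique name."""

    def decorator(cls: Type) -> Type:
        if name in MODEL_REGISTRY:
            raise ValueError(f"model {name!r} is already registered")
        MODEL_REGISTRY[name] = cls
        return cls

    return decorator


def build_model(name: str, **kwargs):
    """Look up a registered model class by name and instantiate it."""
    if name not in MODEL_REGISTRY:
        raise KeyError(f"unknown model {name!r}; known: {sorted(MODEL_REGISTRY)}")
    return MODEL_REGISTRY[name](**kwargs)


@register_model("toy_vae")
class ToyVAE:
    """Made-up model class; a real one would define a forward pass."""

    def __init__(self, latent_dim: int = 8):
        self.latent_dim = latent_dim


# The builder never references ToyVAE directly, only its registered name.
model = build_model("toy_vae", latent_dim=4)
```

The point of the pattern is that the builder depends only on a name, so a custom base model can live in user code without edits to the library's own classes.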

mayalenE commented 3 years ago

Hi @iseessel and thank you for the response,

Sure, I'll be very happy to do that. I will try to make a pull request soon and come back to you if I have any trouble. Let me know if you have specific good practices or a guide to follow.

iseessel commented 3 years ago

@mayalenE Great! Take a look at https://github.com/facebookresearch/vissl/blob/master/.github/CONTRIBUTING.md. You'll have to pass the CircleCI build, which includes some tests and a linter. Here's info on how to run the linter locally: https://github.com/facebookresearch/vissl/blob/master/dev/README.md#option-3-use-python-devlint_commitpy

mayalenE commented 3 years ago

@iseessel Thank you, I'm almost done. I am just not sure about the documentation part:

Q1: Should I change the documentation? I've made small changes in models/__init__.py so that the build_model() function can call any registered model (registered with the register_model function from Classy Vision). Potential changes to the current documentation would be in vissl_modules/models and extend_modules/models.

Q2: If yes, how should I do that? I don't see any guidelines in CONTRIBUTING.md.

iseessel commented 3 years ago

Hi @mayalenE sorry for the delay.

Yup, a documentation change would be great. The docs are source controlled in the same repo. See, for example, models.rst.

iseessel commented 3 years ago

We have added this here: https://github.com/facebookresearch/vissl/pull/416, as well as documentation here: https://vissl.readthedocs.io/en/v0.1.6/extend_modules/models.html#adding-new-base-model.

We only suggest adding a new base model if adding a new TRUNK or a new HEAD does not work for your use case.