Genentech / gReLU

gReLU is a python library to train, interpret, and apply deep learning models to DNA sequences.
https://genentech.github.io/gReLU/
MIT License
228 stars, 23 forks

Making train_params and model_params functional to add documentation and avoid duplication #38

Open gokceneraslan opened 3 months ago

gokceneraslan commented 3 months ago

Right now it's hard to document model_params and train_params because they are simply big dictionaries we copy from notebook to notebook, e.g.:

model_params = {
    'model_type':'BorzoiPretrainedModel', # Type of model
    'n_tasks': 1, # Number of cell types to predict
    'crop_len':0, # No cropping of the model output
    'n_transformers': 0, # Number of transformer layers; the published Enformer model has 11
}

experiment = 'experiment' # Output directory name; replace with your own

train_params = {
    'task':'binary', # binary classification
    'lr':1e-4, # learning rate
    'logger': 'csv', # Logs will be written to a CSV file
    'batch_size': 1024,
    'num_workers': 8,
    'devices': 1, # GPU index
    'save_dir': experiment,
    'optimizer': 'adam',
    'max_epochs': 5,
    'checkpoint': True, # Save checkpoints
}

import grelu.lightning
model = grelu.lightning.LightningModel(model_params=model_params, train_params=train_params)

This also makes the notebooks highly redundant. I was wondering if we could simply move these into very simple functions like make_train_params() and make_model_params(), whose default arguments are the values in the dictionaries above and which simply return those dictionaries. Users could override any argument, and the functions would have docstrings just like other functions. That would solve both the documentation problem and the redundancy problem.
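
A minimal sketch of what these helpers might look like, assuming they simply mirror the two dictionaries above. make_model_params and make_train_params are the hypothetical names from this proposal, not existing gReLU functions, and save_dir gets a placeholder default:

```python
def make_model_params(
    model_type: str = "BorzoiPretrainedModel",
    n_tasks: int = 1,
    crop_len: int = 0,
    n_transformers: int = 0,
) -> dict:
    """Return a model_params dictionary for LightningModel (hypothetical helper).

    Args:
        model_type: Name of the model class to build.
        n_tasks: Number of cell types to predict.
        crop_len: Number of output positions to crop; 0 means no cropping.
        n_transformers: Number of transformer layers.
    """
    return {
        "model_type": model_type,
        "n_tasks": n_tasks,
        "crop_len": crop_len,
        "n_transformers": n_transformers,
    }


def make_train_params(
    task: str = "binary",
    lr: float = 1e-4,
    logger: str = "csv",
    batch_size: int = 1024,
    num_workers: int = 8,
    devices: int = 1,
    save_dir: str = "experiment",
    optimizer: str = "adam",
    max_epochs: int = 5,
    checkpoint: bool = True,
) -> dict:
    """Return a train_params dictionary for LightningModel (hypothetical helper).

    Args:
        task: Type of task, e.g. "binary" classification.
        lr: Learning rate.
        logger: Logger type; "csv" writes logs to a CSV file.
        batch_size: Training batch size.
        num_workers: Number of data loader workers.
        devices: GPU index.
        save_dir: Directory for logs and checkpoints.
        optimizer: Optimizer name.
        max_epochs: Maximum number of training epochs.
        checkpoint: Whether to save checkpoints.
    """
    return {
        "task": task,
        "lr": lr,
        "logger": logger,
        "batch_size": batch_size,
        "num_workers": num_workers,
        "devices": devices,
        "save_dir": save_dir,
        "optimizer": optimizer,
        "max_epochs": max_epochs,
        "checkpoint": checkpoint,
    }
```

Because every default lives in a function signature, help(make_train_params) and IDE tooltips document it, and a misspelled keyword raises a TypeError instead of silently adding a new key to the dictionary.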

The final usage would look like this:

model = grelu.lightning.LightningModel(
    model_params=make_model_params(model_type='BorzoiPretrainedModel'),
    train_params=make_train_params(lr=1e-5, devices=2),
)

Let me know if this makes sense.

gokceneraslan commented 3 months ago

It's probably better to have a separate model_params function per model_type, because different models take different params; e.g., DilatedConvModel won't have n_transformers.
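
A minimal sketch of per-model helpers plus a small lookup table keyed by model_type. All function names and MODEL_PARAM_FACTORIES are hypothetical, and only the parameters mentioned in this thread are included; a real DilatedConvModel helper would expose that model's own architecture arguments:

```python
def make_borzoi_pretrained_model_params(
    *, n_tasks: int = 1, crop_len: int = 0, n_transformers: int = 0
) -> dict:
    """model_params for BorzoiPretrainedModel (hypothetical helper)."""
    return {
        "model_type": "BorzoiPretrainedModel",
        "n_tasks": n_tasks,
        "crop_len": crop_len,
        "n_transformers": n_transformers,
    }


def make_dilated_conv_model_params(*, n_tasks: int = 1, crop_len: int = 0) -> dict:
    """model_params for DilatedConvModel (hypothetical helper).

    There is deliberately no n_transformers argument.
    """
    return {"model_type": "DilatedConvModel", "n_tasks": n_tasks, "crop_len": crop_len}


# Map each model_type to its helper so callers can still select a model by name.
MODEL_PARAM_FACTORIES = {
    "BorzoiPretrainedModel": make_borzoi_pretrained_model_params,
    "DilatedConvModel": make_dilated_conv_model_params,
}


def make_model_params(model_type: str, **kwargs) -> dict:
    """Dispatch to the helper for model_type; unsupported arguments raise TypeError."""
    try:
        factory = MODEL_PARAM_FACTORIES[model_type]
    except KeyError:
        raise ValueError(f"Unknown model_type: {model_type!r}") from None
    return factory(**kwargs)
```

With this split, make_model_params("DilatedConvModel", n_transformers=2) fails immediately with a TypeError, because that helper simply has no such argument.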

ekageyama commented 3 months ago

I see the reasoning for this, but is there a reason why you want to have them in a function? Wouldn't it be simpler to have a JSON/YAML file per model in a directory, load the params directly from there, and overwrite them as needed? If you make a function, it becomes gReLU-specific instead of a more generic solution.
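
A minimal sketch of the file-based alternative, assuming one JSON file of defaults per model type named <model_type>.json in a directory. The layout and the load_model_params name are hypothetical; YAML would work the same way with a YAML parser in place of the standard json module:

```python
import json
import tempfile
from pathlib import Path


def load_model_params(model_type: str, defaults_dir: str, **overrides) -> dict:
    """Read <defaults_dir>/<model_type>.json and apply keyword overrides.

    Override keys that are absent from the defaults file are rejected,
    so that typos surface immediately.
    """
    path = Path(defaults_dir) / f"{model_type}.json"
    with path.open() as f:
        params = json.load(f)
    unknown = set(overrides) - set(params)
    if unknown:
        raise ValueError(f"Unknown params for {model_type}: {sorted(unknown)}")
    params.update(overrides)
    return params


if __name__ == "__main__":
    # Write an example defaults file mirroring model_params from the issue.
    with tempfile.TemporaryDirectory() as tmp:
        defaults = {"model_type": "BorzoiPretrainedModel", "n_tasks": 1,
                    "crop_len": 0, "n_transformers": 0}
        (Path(tmp) / "BorzoiPretrainedModel.json").write_text(json.dumps(defaults))
        print(load_model_params("BorzoiPretrainedModel", tmp, n_tasks=4))
```

The defaults stay readable by any tool, but nothing here validates value types or ranges; that still has to be added separately.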

gokceneraslan commented 3 months ago
ekageyama commented 3 months ago

The way I would do it is to create a pydantic model for each of your models, which would validate your model params. If you want a function that gives said params, the function reads from a default file and returns the new params; you could add functionality that replaces some elements if you pass key-value pairs, and then runs validation to ensure the new params are correct. The nice thing about this approach is that the pydantic models can be exported and you guarantee that the params are valid.
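
A minimal sketch of this approach, assuming pydantic v2 is installed. DilatedConvModelParams, DEFAULTS, SCHEMAS, and get_model_params are hypothetical names, the fields are only those mentioned in this thread, and the defaults are kept in a dict rather than a file so the example is self-contained:

```python
from typing import Literal

from pydantic import BaseModel, ConfigDict, Field, ValidationError


class DilatedConvModelParams(BaseModel):
    """Schema for DilatedConvModel params (hypothetical; fields from this thread only)."""

    # extra="forbid" rejects keys the model does not accept, e.g. n_transformers.
    # protected_namespaces=() allows a field named model_type without a warning.
    model_config = ConfigDict(extra="forbid", protected_namespaces=())

    model_type: Literal["DilatedConvModel"] = "DilatedConvModel"
    n_tasks: int = Field(default=1, ge=1)
    crop_len: int = Field(default=0, ge=0)


# Defaults that would normally be read from a file shipped with the package.
DEFAULTS = {"DilatedConvModel": {"model_type": "DilatedConvModel", "n_tasks": 1, "crop_len": 0}}
SCHEMAS = {"DilatedConvModel": DilatedConvModelParams}


def get_model_params(model_type: str, **overrides) -> dict:
    """Merge overrides into the defaults, validate, and return a plain dict."""
    merged = {**DEFAULTS[model_type], **overrides}
    # model_validate raises ValidationError on wrong types, out-of-range values, or extra keys.
    return SCHEMAS[model_type].model_validate(merged).model_dump()
```

The export mentioned above corresponds to DilatedConvModelParams.model_json_schema(), which returns a JSON Schema dictionary that other tools can use to check a params file.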

dagarfield commented 3 months ago

My two cents, for what it's worth: one very nice feature of gReLU is that it is self-contained and easy to use. I worry that moving to, e.g., JSON files adds more steps where users can introduce errors.

(Email reply to ekageyama's comment above, posted Mon 5 Aug 2024 at 10:35.)

ekageyama commented 3 months ago

In theory, a new user wouldn't touch the JSON files. As gokceneraslan mentioned, it would be something like get_mode_param("DilatedConvModel"), which would read and validate the JSON and give you your param dictionary. But if you give someone else a model, they can ingest it and validate it. This enforces that models are complete and that their values are correct. I think a validation error is better than a PyTorch one.
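
A dependency-free sketch of the "ingest and validate" step, using a plain schema dict as a stand-in for the pydantic model suggested earlier in the thread. PARAM_SCHEMAS, ParamValidationError, validate_model_params, and load_and_validate are hypothetical names, and only params mentioned in this thread are listed. It checks a shared params JSON for missing keys, unknown keys, and wrong types before any PyTorch code runs:

```python
import json

# Expected keys and types per model_type (hypothetical; only params from this thread).
PARAM_SCHEMAS = {
    "DilatedConvModel": {"model_type": str, "n_tasks": int, "crop_len": int},
    "BorzoiPretrainedModel": {
        "model_type": str, "n_tasks": int, "crop_len": int, "n_transformers": int,
    },
}


class ParamValidationError(ValueError):
    """Raised when a params dictionary does not match its schema."""


def validate_model_params(params: dict) -> dict:
    """Check that params are complete and correctly typed for their model_type."""
    model_type = params.get("model_type")
    if model_type not in PARAM_SCHEMAS:
        raise ParamValidationError(f"Unknown model_type: {model_type!r}")
    schema = PARAM_SCHEMAS[model_type]
    errors = []
    missing = schema.keys() - params.keys()
    unknown = params.keys() - schema.keys()
    if missing:
        errors.append(f"missing: {sorted(missing)}")
    if unknown:
        errors.append(f"unknown: {sorted(unknown)}")
    for key, expected in schema.items():
        if key not in params:
            continue
        value = params[key]
        # bool is a subclass of int, so reject it explicitly.
        if not isinstance(value, expected) or isinstance(value, bool):
            errors.append(f"{key} should be {expected.__name__}, got {type(value).__name__}")
    if errors:
        raise ParamValidationError(f"{model_type}: " + "; ".join(errors))
    return params


def load_and_validate(json_text: str) -> dict:
    """Parse params shared as JSON (e.g. alongside a model) and validate them."""
    return validate_model_params(json.loads(json_text))
```

A shared params file with a typo or a missing key then fails with a message naming the offending field, rather than with an error from deep inside model construction.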