SforAiDl / jeta

A JAX-based meta-learning library
MIT License

MAML is expecting more parameters than being passed #41

Open AnirudhM1 opened 2 years ago

AnirudhM1 commented 2 years ago

https://github.com/SforAiDl/jeta/blob/e55bf83c9f89c662872f50a1ef8885c1085403ef/jeta/maml.py#L14-L15

The lines above cause an error when MAML is run with OptiTrainer.


OptiTrainer calls maml_adapt, and the arguments for that call are supplied by OptiTrainer itself.

Note: maml_adapt is passed to OptiTrainer as the adapt_fn parameter: https://github.com/SforAiDl/jeta/blob/e55bf83c9f89c662872f50a1ef8885c1085403ef/jeta/opti_trainer.py#L145

The error is thrown because maml_adapt expects the two arguments from the lines linked above (fas and maml_lr), and OptiTrainer does not pass them.
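A minimal reproduction of the mismatch, written in plain Python rather than JAX. The signatures here are assumptions for illustration (the real maml_adapt and OptiTrainer in jeta may take different arguments); the point is only that a trainer calling adapt_fn with fewer arguments than the function requires raises a TypeError.

```python
# Hypothetical stand-ins: jeta's real maml_adapt / OptiTrainer may differ.
def maml_adapt(params, batch, fas, maml_lr):
    """Stand-in for maml_adapt: needs fas and maml_lr in addition to params and batch."""
    return params, fas, maml_lr


class OptiTrainer:
    """Stand-in trainer that calls adapt_fn with only the arguments it knows about."""

    def __init__(self, adapt_fn):
        self.adapt_fn = adapt_fn

    def step(self, params, batch):
        # Only params and batch are passed; fas and maml_lr are never supplied.
        return self.adapt_fn(params, batch)


trainer = OptiTrainer(adapt_fn=maml_adapt)
try:
    trainer.step(params=1.0, batch=None)
except TypeError as err:
    # Reports that the two required arguments fas and maml_lr are missing.
    print(err)
```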


fas can be moved into OptiTrainer, since this parameter is common to many optimization-based approaches, but maml_lr cannot, because it is specific to MAML. A different mechanism is therefore needed to pass maml_lr.
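One possible way to keep maml_lr out of OptiTrainer, sketched under the same assumptions (hypothetical signatures, plain Python, a toy scalar loss): the trainer owns fas, and the MAML-specific maml_lr is bound beforehand with functools.partial. This is an alternative illustration, not the approach settled on in this thread.

```python
import functools


def maml_adapt(params, grad_fn, fas, maml_lr):
    """Hypothetical MAML inner loop: fas gradient steps of size maml_lr."""
    for _ in range(fas):
        params = params - maml_lr * grad_fn(params)
    return params


class OptiTrainer:
    """Stand-in trainer that owns fas, a setting shared by optimization-based methods."""

    def __init__(self, adapt_fn, fas):
        self.adapt_fn = adapt_fn
        self.fas = fas

    def adapt(self, params, grad_fn):
        # The trainer supplies only the shared argument; method-specific ones are pre-bound.
        return self.adapt_fn(params, grad_fn, fas=self.fas)


# Bind the MAML-only learning rate before handing the function to the trainer.
trainer = OptiTrainer(adapt_fn=functools.partial(maml_adapt, maml_lr=0.1), fas=3)

# The loss 0.5 * p**2 has gradient p, so each step scales params by (1 - 0.1).
print(trainer.adapt(1.0, grad_fn=lambda p: p))
```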

veds12 commented 2 years ago

Looks like a bug introduced in #39. We can add kwargs to OptiTrainer, which will solve the problem of passing maml_lr.
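A sketch of the kwargs suggestion, again with hypothetical signatures rather than jeta's actual API: OptiTrainer collects any extra keyword arguments at construction time and forwards them unchanged to adapt_fn, so maml_lr reaches maml_adapt without the trainer having to know about it.

```python
def maml_adapt(params, grad_fn, fas, maml_lr):
    """Hypothetical MAML inner loop: fas gradient steps of size maml_lr."""
    for _ in range(fas):
        params = params - maml_lr * grad_fn(params)
    return params


class OptiTrainer:
    """Stand-in trainer that forwards unrecognized keyword arguments to adapt_fn."""

    def __init__(self, adapt_fn, fas, **adapt_kwargs):
        self.adapt_fn = adapt_fn
        self.fas = fas
        # Method-specific settings (e.g. maml_lr) are stored without the trainer inspecting them.
        self.adapt_kwargs = adapt_kwargs

    def adapt(self, params, grad_fn):
        # Shared fas plus whatever extra keywords were given at construction.
        return self.adapt_fn(params, grad_fn, fas=self.fas, **self.adapt_kwargs)


trainer = OptiTrainer(adapt_fn=maml_adapt, fas=2, maml_lr=0.5)
print(trainer.adapt(4.0, grad_fn=lambda p: p))  # prints 1.0
```

Compared with binding arguments up front, the trade-off is that a misspelled keyword (say, maml_lrr) is accepted when the trainer is built and only fails when adapt_fn is first called.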