unit8co / darts

A Python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0

Make models wandb compatible #943

Closed: timmermansjoy closed this issue 2 years ago

timmermansjoy commented 2 years ago

Is your feature request related to a current problem? Please describe.

When using Weights & Biases (WandB), it is not possible to log the topology of a model:

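from darts.models import NBEATSModel
from pytorch_lightning.loggers import WandbLogger

wandb_logger = WandbLogger( ... )
model_nbeats = NBEATSModel( ... )
wandb_logger.watch(model_nbeats)  # fails: WandB expects a torch.nn.Module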

As a result, the model's parameters are not logged to WandB.

Describe proposed solution

I dug into the code a little and don't see a clear solution yet. WandB checks whether the model is a torch.nn.Module. When I disabled that check, I saw that the model also needs a named_parameters method. I did not dig any further than that. I suggest adding this to the model and then asking WandB to add a flag for darts forecasting models.

[Screenshot 2022-05-04 at 12:44:41 PM]

An alternative way of getting these parameters into WandB would also be sufficient.
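The check and the proposed fix described above can be sketched in plain Python. This is a hypothetical stand-in, not WandB's or darts' actual code: watch_parameters mimics a watcher that rejects objects lacking named_parameters, Network stands in for a torch.nn.Module, ForecastingWrapper stands in for a forecasting model that holds its network as an attribute rather than being one, and WatchableWrapper shows the suggested delegation.

```python
from typing import Iterator, List, Tuple

Param = Tuple[str, List[float]]


class Network:
    """Hypothetical stand-in for a torch.nn.Module: it exposes named_parameters()."""

    def __init__(self) -> None:
        self._params = {"fc.weight": [0.5, -0.2], "fc.bias": [0.1]}

    def named_parameters(self) -> Iterator[Param]:
        yield from self._params.items()


class ForecastingWrapper:
    """Hypothetical stand-in for a forecasting model that holds a network as an attribute."""

    def __init__(self) -> None:
        self.model = Network()


class WatchableWrapper(ForecastingWrapper):
    """The proposed fix: delegate named_parameters() to the wrapped network."""

    def named_parameters(self) -> Iterator[Param]:
        return self.model.named_parameters()


def watch_parameters(model: object) -> List[str]:
    """Hypothetical duck-typed check: reject objects without named_parameters()."""
    if not hasattr(model, "named_parameters"):
        raise ValueError(f"Expected a torch.nn.Module-like model, got {type(model).__name__}")
    return [name for name, _ in model.named_parameters()]


try:
    watch_parameters(ForecastingWrapper())
except ValueError as err:
    print(err)  # Expected a torch.nn.Module-like model, got ForecastingWrapper
print(watch_parameters(WatchableWrapper()))  # ['fc.weight', 'fc.bias']
```

Delegating named_parameters only satisfies this one attribute check. A real watcher may also register hooks on the module, so it may need more of the torch.nn.Module interface than this sketch shows.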

dennisbader commented 2 years ago

Our TorchForecastingModels are built on top of PyTorch Lightning, which supports the WandB logger: https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.loggers.wandb.html#weights-and-biases-logger

Can you try passing the logger at model creation with pl_trainer_kwargs={"logger": my_wandb_logger}?
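A minimal sketch of that suggestion, assuming a darts version whose NBEATSModel accepts pl_trainer_kwargs and forwards them to a pytorch_lightning Trainer, with darts and wandb installed. The helper name fit_with_wandb, the project name, the chunk lengths, and the epoch count are illustrative choices, not values from this thread; series is any darts TimeSeries.

```python
def fit_with_wandb(series, project="darts-nbeats"):
    """Train an N-BEATS model whose Lightning trainer logs to Weights & Biases.

    Hypothetical helper: `series` is assumed to be a darts TimeSeries and
    `project` an illustrative WandB project name.
    """
    # Imported lazily so this helper can be defined without darts/wandb installed
    from darts.models import NBEATSModel
    from pytorch_lightning.loggers import WandbLogger

    wandb_logger = WandbLogger(project=project)
    model = NBEATSModel(
        input_chunk_length=24,  # illustrative window sizes
        output_chunk_length=12,
        n_epochs=5,
        # Passed through to the Lightning Trainer that darts creates internally
        pl_trainer_kwargs={"logger": wandb_logger},
    )
    model.fit(series)  # losses logged during training are sent to the WandB run
    return model


# Example (requires darts and wandb):
# from darts.datasets import AirPassengersDataset
# model = fit_with_wandb(AirPassengersDataset().load())
```

The imports sit inside the helper only so the snippet loads where darts or wandb is absent; in a regular script you would normally import them at the top.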

timmermansjoy commented 2 years ago

Yes, thank you. I was stupid and had not noticed all the results in WandB. My apologies 😅