jdb78 / pytorch-forecasting

Time series forecasting with PyTorch
https://pytorch-forecasting.readthedocs.io/
MIT License
3.87k stars, 611 forks

Nbeats forward() method needs and returns a dict(). That goes against PL ability to log the computation graph #529

Closed: pnmartinez closed this issue 3 years ago

pnmartinez commented 3 years ago

Expected behavior

I am trying to log the computation graph with the TensorBoardLogger class, which is done by passing log_graph=True.

Actual behavior

Error: the TensorBoardLogger needs the model (Nbeats in my case) to define self.example_input_array. I did that (easy peasy), but more convoluted errors then arise, all the way down to the PyTorch API.

The cause seems to be that the .forward() method returns a dict() (with keys "prediction", "target_scale"..., the expected format in pytorch_forecasting), whereas the graph can only be built properly if forward() returns a Tensor.

PyTorch also recommends (see here, or here) that the forward method accept and return Tensors, or tuples of them.
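For readers who want to reproduce the underlying restriction outside of Lightning, here is a minimal sketch. It assumes that TensorBoard graph logging relies on tracing the model with torch.jit.trace, and that a recent PyTorch version is installed. The two toy modules, DictOutput and TupleOutput, are hypothetical stand-ins and not part of pytorch_forecasting.

```python
import torch
from torch import nn


class DictOutput(nn.Module):
    """Hypothetical toy network whose forward() returns a dict."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return {"prediction": self.linear(x), "target_scale": x.mean(dim=1)}


class TupleOutput(DictOutput):
    """Same computation, but forward() returns a tuple of Tensors."""

    def forward(self, x):
        return self.linear(x), x.mean(dim=1)


example = torch.randn(3, 4)

# With default (strict) tracing, a dict at the output is rejected.
try:
    torch.jit.trace(DictOutput(), example)
    dict_traced = True
except RuntimeError:
    dict_traced = False

# A tuple of Tensors can be traced and called like the original module.
traced = torch.jit.trace(TupleOutput(), example)
prediction, target_scale = traced(example)
```

Passing strict=False to torch.jit.trace is another way around the dict restriction, but it only helps when you control the tracing call yourself, which is not the case when a logger does the tracing.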

I know this is a major change to the library's intermediate objects, and maybe there are other ways to get the computation graph, but I think it would be a good change for the long term.

Cheers!

jdb78 commented 3 years ago

This is being fixed in the upcoming release. It's actually sufficient to use a namedtuple.
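To illustrate why a namedtuple is enough, the sketch below converts a dict of outputs into a namedtuple using only the standard library. The helper name to_namedtuple and the sample values are hypothetical and are not the library's actual implementation. The result is a real tuple, so code that expects tuples accepts it, while the named fields of the old dict format stay available.

```python
from collections import namedtuple


def to_namedtuple(output: dict, name: str = "Output"):
    """Hypothetical helper: turn a dict of outputs into a namedtuple
    whose fields are the dict keys, in insertion order."""
    output_type = namedtuple(name, list(output.keys()))
    return output_type(**output)


# Placeholder values; in a real model these would be Tensors.
raw = {"prediction": [0.1, 0.2], "target_scale": (0.0, 1.0)}
out = to_namedtuple(raw)

# Access by name (as with the dict) or by position (as with a tuple).
by_name = out.prediction
by_index = out[1]
```

The keys must be valid Python identifiers for this to work, which holds for names such as "prediction" and "target_scale".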

pnmartinez commented 3 years ago

Nice!

This will be awesome in the case of custom models.


(Sent by email in reply to Jan Beitner's comment of Thursday, June 3, 2021, 9:30:23 PM.)


pnmartinez commented 3 years ago

Hi @jdb78

So, in the new version, is the log_graph=True flag supposed to work with the TensorBoardLogger?

I am testing in a fresh environment and can't seem to log the graph properly.