Closed: pnmartinez closed this issue 3 years ago
This is being fixed in the upcoming release. It's actually sufficient to use a namedtuple.
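A minimal, dependency-free sketch of why a namedtuple return can work where a dict does not: graph tracers typically flatten `forward()` outputs by treating them as Tensors or (nested) tuples, and a namedtuple is still a plain tuple. The `flatten` helper below is a toy stand-in for that logic, not the actual PyTorch implementation.

```python
from collections import namedtuple

# Toy stand-in for the flattening a graph tracer performs on forward()
# outputs; the real logic lives inside PyTorch's tracing machinery.
def flatten(output):
    if isinstance(output, tuple):  # namedtuples pass this check
        return [leaf for item in output for leaf in flatten(item)]
    if isinstance(output, dict):   # a plain dict falls through
        raise TypeError("cannot flatten a dict output for tracing")
    return [output]

# Field names mirror the dict keys mentioned in this thread.
Output = namedtuple("Output", ["prediction", "target_scale"])

print(flatten(Output(prediction=1.0, target_scale=2.0)))  # [1.0, 2.0]
```

Because the namedtuple unpacks like an ordinary tuple, tracing machinery can consume it positionally while downstream code keeps attribute access (`out.prediction`).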
Nice!
This will be awesome for custom models.
Hi @jdb78,
So, in the new version, is the `log_graph=True` flag supposed to be working with the `TensorBoardLogger`? I am testing in a fresh environment and can't seem to log the graph properly.
Expected behavior
Log the computation graph with the `TensorBoardLogger` class, which is done by passing `log_graph=True`.

Actual behavior
Error: the `TensorBoardLogger` needs the model (NBeats in my case) to define `self.example_input_array`. I did that (easy peasy), but then more convoluted errors arise, all the way down to the PyTorch API. It seems that with the `forward()` method returning a `dict()` (with keys "prediction", "target_scale"..., the expected format in `pytorch_forecasting`), the graph can only be built properly if a Tensor is returned. PyTorch also recommends (see here, or here) that the forward method accept and return Tensors, or tuples of them.
I know this is a major modification to the library's intermediate objects, and maybe there are other ways to get the computation graph, but I think it would be a good change for the long term.
Cheers!