pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

[feature request] Implement Exponential Smoothing RNN #1287

Closed jperl closed 5 years ago

jperl commented 6 years ago

Issue Description

I am very excited about using @slaweku's ES-RNN and want to know if I can help with any parts of the re-implementation?

fritzo commented 6 years ago

Sure, @slaweku has already ported his implementation to PyTorch, and we are working on batching before integrating into Pyro. We'll be sure to cc you on the first PR (that sets up structure), then we could use help with subsequent PRs like tutorials, batching, JIT support etc.

jperl commented 6 years ago

Awesome thanks! Is there any way to view the PyTorch implementation so I can get up to speed?

JoshuaC3 commented 6 years ago

@fritzo any advances on this? kind regards :)

pjgaudre commented 6 years ago

@fritzo, I'm excited to use his model for my use case as well. Please keep me updated.

panaali commented 6 years ago

So excited to see this!

andrewcz commented 6 years ago

would be excited to :)

ngoodman commented 6 years ago

there's lots of interest in this, so we should make it happen... @slaweku would it be reasonable to put code so far in this / another repo so people can pitch in?

slaweku commented 6 years ago

Hi, I did port the M4 code to Python and PyTorch, but experienced some technical issues (in Python multiprocessing) that make accuracy lower than in the original Dynet/C++ code. Also, the current code is written as a tool to forecast the M4 competition, not as a well-structured library. Recently I have started working with Neeraj, and it looks promising; I hope we can overcome all the issues.

Regards, Slawek

vshulyak commented 6 years ago

@slaweku Any chance you could open source what you have? I have some extra time, so I could help with debugging.

I replicated some parts of ES-RNN myself, but found that either the architecture or my implementation is not really performing well on my dataset (not M4). I really want to figure out why.

neerajprad commented 5 years ago

After discussing with everyone, we came to the conclusion that there is no clear path towards integrating ESRNN into Pyro at this stage. I am closing this issue as this is not being actively worked on by Pyro developers.

Given the interest from the community, however, @slaweku is interested in open sourcing his implementation in PyTorch as a separate library, after existing issues have been fixed. There is also a fair amount of work that needs to be done to provide a more general interface. This will be a separate open source effort, with its own home. For the time being, it is best to refer to his original C++/Dynet implementation which is available here.

aredd-cmu commented 5 years ago

@neerajprad @slaweku I'm working on a term project where we would like to implement the algorithm in pytorch. Can you offer a few more details on the issues you've seen?

slaweku commented 5 years ago

Hi, the accuracy of the trained models is worse than that of the original ones in the C++/Dynet code. In the M4 Competition I had two types of models: ones that used a simple ensemble, and an "ensemble of specialists", where I concurrently train a number (5-7) of RNNs that are forced to specialize in a subset of the data. This second approach usually works better and was the first I converted from C++. However, due to the slowness of Python, I had to implement it with multiprocessing, and that causes another issue that I do not know how to deal with: the inability to have one trainer that updates all per-series parameters.

Slawek


aredd-cmu commented 5 years ago

Thanks @slaweku! Does your python implementation mirror the Dynet implementation? Based on my analysis of your code, this is being run on the CPU, one-series-at-a-time? Based on your experience is batching and running on a GPU possible?

slaweku commented 5 years ago

Hi, yes, one series at a time on CPU. Batching is possible, but my attempts to use the clever "auto batching" of Dynet were detrimental to the accuracy and did not really improve the speed. In PyTorch you need to take care of creating a proper batch yourself. It is certainly possible, but the code is complicated, as you need to deal with updates of the seasonality components and levels; doing that on a bunch of series at the same time is not straightforward, e.g. because they have different lengths. Doing it separately will most likely defeat the potential speed benefits.
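To make the batching difficulty concrete, here is a minimal plain-Python sketch of updating levels and seasonality for several series of different lengths in lock-step. It is not taken from the ES-RNN code: the function name `smooth_batch`, its parameters, and the multiplicative update formulas (textbook exponential smoothing without a trend term) are assumptions chosen for illustration. A real PyTorch version would replace the inner loop over series with tensor operations and a mask.

```python
def smooth_batch(series, alphas, gammas, init_seasonality):
    """Multiplicative level/seasonality smoothing over a ragged batch.

    series:           list of non-empty lists of positive floats (lengths may differ)
    alphas, gammas:   per-series smoothing coefficients in (0, 1)
    init_seasonality: per-series list of m initial seasonal factors (m >= 1)
    Returns (levels, seasonality), both as per-series lists.
    """
    n = len(series)
    max_len = max(len(y) for y in series)
    # Copy the per-series state so the caller's lists are not modified.
    seasonality = [list(s) for s in init_seasonality]
    # Assumed initialisation: the first level is the first de-seasonalised value.
    prev_level = [y[0] / s[0] for y, s in zip(series, seasonality)]
    levels = [[] for _ in range(n)]

    for t in range(max_len):          # one time step for the whole batch
        for i in range(n):            # a tensor op plus mask in a real framework
            if t >= len(series[i]):
                continue              # "masked": this series has already ended
            y = series[i][t]
            s = seasonality[i][t]
            level = alphas[i] * y / s + (1.0 - alphas[i]) * prev_level[i]
            # Seasonal factor for step t + m, appended to the per-series state.
            seasonality[i].append(gammas[i] * y / level + (1.0 - gammas[i]) * s)
            prev_level[i] = level
            levels[i].append(level)
    return levels, seasonality


# Two series of different lengths, each with its own coefficients.
batch = [[12.0, 8.0, 13.0, 9.0, 14.0], [5.0, 7.0, 5.5]]
levels, seasonality = smooth_batch(
    batch, [0.5, 0.3], [0.1, 0.2], [[1.2, 0.8], [0.9, 1.1]]
)
```

The shorter series simply stops receiving updates once it runs out of data, which is the bookkeeping a padded tensor batch has to reproduce with a mask.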


JHBalaji commented 5 years ago

@slaweku Is this issue still open?

slaweku commented 5 years ago

Hi, From what I remember, try to redefine the squash function, line 254, into:

Expression squash(const Expression& x) {
    return log(x);
}

Regards,

Slawek

On Thu, Apr 25, 2019 at 1:31 PM andmib wrote:

@neerajprad @slaweku I apologize if this is not the proper place, but if we have basic questions on compilation of the M4 code, where would be the most appropriate place for that? I.e., if I'm getting the error:

ES_RNN_E_PI.cc: In function ‘int main(int, char**)’:
ES_RNN_E_PI.cc:970:58: error: cannot bind non-const lvalue reference of type ‘dynet::Expression&’ to an rvalue of type ‘dynet::Expression’
joinedInput_ex.emplace_back(noise(squash(cdiv(input1_ex, levels_exVect[i])), NOISE_STD)); //input normalization+noise

andmib commented 5 years ago

@slaweku Thank you for this. I deleted my original post because I didn't want to bother you with minutiae, but just noticed you answered. Your suggestion solved that error!

andmib commented 5 years ago

@slaweku Did you ever run into a "magnitude of gradient is bad: -nan" issue when running the ES_RNN module?

slaweku commented 5 years ago

Hi, I did not, but I heard of someone having it. The issue was that in one case a data record ended with some weird combination of \n and \r, and some versions of the C++ libraries read the next record as an empty one, so zeros. And then you do log(0). The solution was to replace, in M4TS(string category, stringstream &line_stream),

if (c != '\"') { //remove quotes

with

if (c != '\"' && c != '\r') { //remove quotes and very occasional double end of line

Regards,

Slawek
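For anyone porting the data loader, a rough Python analogue of the same guard might look like the sketch below. The record layout (a quoted id followed by quoted, comma-separated values) and the helper names `parse_record` and `load_series` are assumptions for illustration, not the M4 loader itself. The idea is the one described above: drop stray carriage returns while stripping quotes, so that a doubled line ending does not turn into an empty series that later feeds log(0).

```python
import math


def parse_record(line):
    """Parse one record into (series_id, values), or None if it is empty."""
    # Remove quote characters and stray carriage returns, as in the C++ fix.
    cleaned = "".join(c for c in line if c not in ('"', "\r"))
    fields = [f for f in cleaned.split(",") if f]
    if len(fields) < 2:
        return None  # blank line, or an id with no observations
    return fields[0], [float(v) for v in fields[1:]]


def load_series(text):
    """Split raw text on newlines and keep only the non-empty records."""
    records = []
    for line in text.split("\n"):
        record = parse_record(line)
        if record is not None:
            records.append(record)
    return records


# The doubled "\r\n" between the two rows produces a line holding only "\r".
raw = '"D1","10.0","12.5"\r\n\r\n"D2","7.0","8.0"\r\n'
series = load_series(raw)

# Every remaining value is positive, so the log transform is well defined.
logged = [[math.log(v) for v in values] for _, values in series]
```

Skipping the empty record at load time keeps the failure close to its cause, instead of surfacing later as a NaN gradient during training.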


slaweku commented 5 years ago

I will try to update the competition github.

Regards, Slawek


andmib commented 5 years ago

@slaweku Thanks a ton - I'll give that a shot. I'm first trying to get it running on the M4 competition ts set, then I'll likely move on to try and modify it so it runs on a different set of series (admittedly a much more sparse set of data, will be curious to see how it performs).

cpoptic commented 5 years ago

Is there an actual branch with @slaweku's ES-RNN code in Python/PyTorch?

I'd be willing to help contribute, but cannot find any relevant repo where this Python port is being stored. @slaweku @neerajprad @aredd-cmu do we have a dedicated repo for collaborating on the ES-RNN code? Thanks

slaweku commented 5 years ago

Hi, no, there is not. I have Python/PyTorch code that forecasts some Uber data, but it has evolved a bit from its M4 origins. I keep promising myself to clean up the code, re-attach it to the M4 data set, and open source it, but, mostly because I do not think too highly of Python, there are always more urgent things to do :-)

Regards, Slawek


andmib commented 5 years ago

@slaweku I'd love to steal some of your time offline to chat about your forecast applicability to some of the problems I face, particularly with intermittent hierarchical time series

damitkwr commented 5 years ago

We have a PyTorch version partially implemented at https://github.com/damitkwr/ESRNN-GPU and have tested it on the M4 Competition dataset. We mostly replicate the results from Slawek's original implementation and do a little better on some intervals; the results are detailed in the paper in the repo.

slaweku commented 5 years ago

Thank you guys, I will look at it.

Slawek


ebonetti commented 5 years ago

@slaweku @damitkwr I find your work invaluable for figuring out neural nets applied to time series: thanks a lot for open-sourcing it 👍