ucals / cvae

Error computing log_prob at site 'y' #1

Open ly15681910119 opened 2 years ago

ly15681910119 commented 2 years ago

When the program trains the CVAE, it reports an error:

    computing log_prob at site 'y'

The problem arises at this line of code:

    loss = svi.step(inputs, outputs) / inputs.size(0)

xlnn commented 1 year ago

Me too! I have the same problem.

ucals commented 1 year ago

Hey guys... this smells like a compatibility issue. Pyro is in active development. When I created this example as a contribution to them, I used Pyro 1.4.x; Pyro is now past 1.8. Let me investigate and get back to you. Cheers!
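A likely explanation, not confirmed in this thread: newer PyTorch and Pyro releases validate distribution arguments by default, and Bernoulli.log_prob then rejects observed values outside its support {0, 1}. If the observations at site 'y' are grayscale pixel intensities in [0, 1] rather than strictly binary values (an assumption about the data), most of them fail that check and Pyro reports the failure as an error computing log_prob at site 'y'. The pure-Python sketch below mimics only the support check; `bernoulli_support_violations` and the sample pixel values are hypothetical, not part of Pyro or PyTorch.

```python
def bernoulli_support_violations(values):
    """Return the observed values that fall outside Bernoulli's support {0, 1}.

    Hypothetical helper that mimics the kind of check argument validation
    performs before Bernoulli.log_prob; it is not a Pyro/PyTorch API.
    """
    return [v for v in values if v not in (0.0, 1.0)]


# Hypothetical row of MNIST-style pixel intensities scaled to [0, 1].
pixels = [0.0, 0.0, 0.3294, 0.9922, 1.0, 0.0]

print(bernoulli_support_violations(pixels))  # [0.3294, 0.9922]
```

Under this assumption, a single intermediate intensity in the batch is enough for validation to fail, which would explain why code that ran under Pyro 1.4.x errors out on newer versions.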

xlnn commented 1 year ago

Thank you!

shoyua commented 1 year ago

I have the same problem. Is there a solution now?

shoyua commented 1 year ago

Setting validate_args=False on the distribution at site 'y' seems to work:

    pyro.sample('y', dist.Bernoulli(mask_loc, validate_args=False).to_event(1), obs=mask_ys)
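For context on what this workaround changes: with validate_args=False, Bernoulli.log_prob no longer rejects non-binary observations and simply evaluates y*log(p) + (1 - y)*log(1 - p) for each pixel, and .to_event(1) sums that over the pixel dimension. The pure-Python sketch below reproduces that arithmetic under those assumptions; `bernoulli_event_log_prob`, its `eps` clamp, and the sample values are hypothetical stand-ins, not Pyro or PyTorch APIs.

```python
import math


def bernoulli_event_log_prob(probs, ys, eps=1e-7):
    """Sum y*log(p) + (1 - y)*log(1 - p) over one event, like .to_event(1).

    No support check is done, so each y may be any value in [0, 1], which is
    the situation validate_args=False allows. The eps clamp keeps log() finite
    and is a choice of this sketch, not necessarily PyTorch's exact behavior.
    """
    if len(probs) != len(ys):
        raise ValueError("probs and ys must have the same length")
    total = 0.0
    for p, y in zip(probs, ys):
        p = min(max(p, eps), 1.0 - eps)  # keep p strictly inside (0, 1)
        total += y * math.log(p) + (1.0 - y) * math.log(1.0 - p)
    return total


# Hypothetical decoder probabilities and grayscale targets for three pixels.
probs = [0.9, 0.2, 0.6]
ys = [1.0, 0.0, 0.3294]
print(bernoulli_event_log_prob(probs, ys))
```

With continuous targets this is no longer a normalized Bernoulli likelihood, but it equals the negative summed binary cross-entropy, the same reconstruction term a BCE loss would optimize, so training can still behave sensibly with validation turned off.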