pymc-devs / pymc-examples

Examples of PyMC models, including a library of Jupyter notebooks.
https://www.pymc.io/projects/examples/en/latest/
MIT License

Bayesian neural network notebook does not work #504

Closed: earlbellinger closed this issue 1 year ago

earlbellinger commented 1 year ago

Notebook title: Variational Inference: Bayesian Neural Networks
Notebook url: https://github.com/pymc-devs/pymc-examples/blob/main/examples/variational_inference/bayesian_neural_network_advi.ipynb

Issue description

Cell 14:

    pm.set_data(new_data={"ann_input": grid_2d, "ann_output": dummy_out})
    ppc = pm.sample_posterior_predictive(trace)

yields

ValueError: size does not match the broadcast shape of the parameters. (100,), (100,), (10000,)
Apply node that caused the error: bernoulli_rv{0, (0,), int64, True}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F779E3502E0>), MakeVector{dtype='int64'}.0, TensorConstant{4}, Elemwise{Sigmoid}[(0, 0)].0)
Toposort index: 10
Inputs types: [RandomGeneratorType, TensorType(int64, (1,)), TensorType(int64, ()), TensorType(float64, (?,))]
Inputs shapes: ['No shapes', (1,), (), (10000,)]
Inputs strides: ['No strides', (8,), (), (8,)]
Inputs values: [Generator(PCG64) at 0x7F779E3502E0, array([100]), array(4), 'not shown']
Outputs clients: [['output'], ['output']]

Expected output

 100.00% [5000/5000 00:30<00:00]
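The traceback points to a shape clash, not a problem with the network itself. The Bernoulli draw is requested with size (100,), but the sigmoid output passed in as its probability has shape (10000,), presumably one entry per row of grid_2d. Below is a minimal NumPy sketch of the same clash. It uses a plain binomial(n=1) draw as a stand-in for PyMC's Bernoulli sampler, and the probability values are made up.

```python
import numpy as np

rng = np.random.default_rng(0)

# One probability per grid point, standing in for the sigmoid output act_out
p = np.full(10_000, 0.5)


def draw(size):
    """Draw Bernoulli(p) samples as binomial(n=1, p); return None on a shape clash."""
    try:
        return rng.binomial(1, p, size=size)
    except ValueError:
        return None


stale = draw((100,))   # size fixed at 100, as in the traceback
fresh = draw(p.shape)  # size that follows the probabilities

print(stale is None)   # True: (100,) cannot broadcast against (10000,)
print(fresh.shape)     # (10000,)
```

Declaring the likelihood with a named dimension that follows the input data, as proposed in the replies, should let its size track the new data instead of staying at the old length.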

OriolAbril commented 1 year ago

Does it work if you do:

ann_input = pm.Data("ann_input", X_train, mutable=True, dims="obs_id")
ann_output = pm.Data("ann_output", Y_train, mutable=True, dims="obs_id")
...
# Binary classification -> Bernoulli likelihood
out = pm.Bernoulli(
    "out",
    act_out,
    observed=ann_output,
    total_size=Y_train.shape[0],  # IMPORTANT for minibatches
    dims="obs_id",
)

when defining the model?

If not, I think the fix will need `model.set_dim`.
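The `total_size` comment in the suggested model refers to minibatch ADVI. As I understand it, PyMC uses `total_size` to rescale the minibatch log-likelihood by total_size / batch_size, so the scaled value can stand in for the full-data log-likelihood. Here is a small NumPy sketch of that rescaling with made-up labels and a fixed probability. The helper `bernoulli_loglik` is hypothetical and is not a PyMC function.

```python
import numpy as np

rng = np.random.default_rng(1)
total_size = 1_000                       # plays the role of Y_train.shape[0]
y = rng.integers(0, 2, size=total_size)  # hypothetical binary labels
p = 0.3                                  # hypothetical success probability


def bernoulli_loglik(y, p):
    """Sum of Bernoulli log-probabilities of the labels y (hypothetical helper)."""
    return np.sum(y * np.log(p) + (1 - y) * np.log1p(-p))


full = bernoulli_loglik(y, p)

# Split the shuffled data into 20 minibatches of 50 and rescale each batch's log-likelihood
batches = np.split(rng.permutation(y), 20)
scaled = [total_size / len(b) * bernoulli_loglik(b, p) for b in batches]

# Averaged over a partition of the data, the rescaled estimates recover the full value
print(np.isclose(np.mean(scaled), full))  # True
```

Without the rescaling, each minibatch would count as only 50 observations. The likelihood would then be underweighted relative to the priors.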

earlbellinger commented 1 year ago

Thanks for your reply. Unfortunately not; this yields:

ShapeError: Length of `dims` must match the dimensions of the dataset. (actual 1 != expected 2)

OriolAbril commented 1 year ago

Oh, sorry, X is two-dimensional; I forgot about that. Updated:

ann_input = pm.Data("ann_input", X_train, mutable=True, dims=("obs_id", "train_cols"))
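A toy registry of named dimensions can sketch why the first suggestion raised a ShapeError and why the corrected one works. `DimRegistry` below is hypothetical and only mimics the idea. PyMC's own bookkeeping lives in the model's coords and dim lengths and is more involved. Each data array needs one dim name per axis, and a variable declared with dims="obs_id" takes its length from whatever data currently defines "obs_id". The array sizes are made up.

```python
import numpy as np


class DimRegistry:
    """Toy stand-in for named model dimensions (hypothetical, not PyMC code)."""

    def __init__(self):
        self.lengths = {}

    def register(self, value, dims):
        """Record the length of each named axis of a data array."""
        value = np.asarray(value)
        # One name per axis: a 2-D X with only ("obs_id",) is rejected
        if len(dims) != value.ndim:
            raise ValueError(
                "Length of dims must match the dimensions of the dataset "
                f"(actual {len(dims)} != expected {value.ndim})"
            )
        for name, length in zip(dims, value.shape):
            self.lengths[name] = length

    def size_of(self, dims):
        """Shape of a variable declared with these dims, given the current data."""
        return tuple(self.lengths[name] for name in dims)


X_train = np.zeros((500, 2))     # made-up 2-D training inputs
grid_2d = np.zeros((10_000, 2))  # made-up prediction grid

dims = DimRegistry()
try:
    dims.register(X_train, ("obs_id",))  # first suggestion: too few names
except ValueError as err:
    print(err)

dims.register(X_train, ("obs_id", "train_cols"))  # corrected suggestion
print(dims.size_of(("obs_id",)))                  # (500,)

# Swapping in new data, loosely analogous to pm.set_data
dims.register(grid_2d, ("obs_id", "train_cols"))
print(dims.size_of(("obs_id",)))                  # (10000,)
```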

earlbellinger commented 1 year ago

That fixed it! Thanks!

OriolAbril commented 1 year ago

Do you want to rerun the notebook and send a PR to fix the website?

earlbellinger commented 1 year ago

Sure thing, I made one here: https://github.com/pymc-devs/pymc-examples/pull/506