bayesflow-org / bayesflow

A Python library for amortized Bayesian workflows using generative neural networks.
https://bayesflow.org/
MIT License

ContinuousApproximator.sample() fails without previous adapter calls (e.g., when loading data) #255

Open elseml opened 1 day ago

elseml commented 1 day ago

I noticed that after switching from generating bf.datasets on the fly to loading pre-simulated data, ContinuousApproximator.sample() fails because the adapter is no longer called before sampling. Concretely, in line 141 of continuous_approximator.py, the adapter is called with strict=False to process the observed data (so that parameter keys are not required):

conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs) 

When working with loaded data, this raises the following error in the adapter's forward() method:

ValueError: Cannot call `forward` with `strict=False` before calling `forward` with `strict=True`.

The error is easily worked around by manually calling the adapter on the data before sampling, but users will not expect to have to do this, so it should be handled internally. @LarsKue @stefanradev93: what do you think would be a principled way to handle this behavior?
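To make the error concrete, here is a minimal sketch of how such a guard can arise. `ToyConcatAdapter` is hypothetical and is not BayesFlow's Adapter. It assumes, for illustration only, that a strict forward pass records state (here, the size of each concatenated key) that a later non-strict pass relies on. The non-strict pass may be missing keys, such as the parameters at inference time.

```python
import numpy as np


class ToyConcatAdapter:
    """Hypothetical adapter that concatenates several keys into one array.

    Not BayesFlow's Adapter: it only illustrates how a strict pass can record
    state (per-key sizes) that a later non-strict pass depends on.
    """

    def __init__(self, keys):
        self.keys = list(keys)
        self.sizes = None  # recorded by the first strict forward pass

    def forward(self, data, *, strict=True):
        if strict:
            # Strict mode: every key (e.g. parameters and observables) must be present.
            missing = [k for k in self.keys if k not in data]
            if missing:
                raise KeyError(f"missing keys: {missing}")
            self.sizes = {k: data[k].shape[-1] for k in self.keys}
        elif self.sizes is None:
            # Non-strict mode has nothing recorded yet, so it refuses to run.
            raise ValueError(
                "Cannot call `forward` with `strict=False` "
                "before calling `forward` with `strict=True`."
            )
        # Absent keys are skipped; present keys are checked against recorded sizes.
        present = [k for k in self.keys if k in data]
        for k in present:
            if data[k].shape[-1] != self.sizes[k]:
                raise ValueError(f"unexpected size for key {k!r}")
        return np.concatenate([data[k] for k in present], axis=-1)

    def __call__(self, data, **kwargs):
        return self.forward(data, **kwargs)


adapter = ToyConcatAdapter(["x", "theta"])
observed = {"x": np.ones((4, 3))}  # inference time: no "theta" available

try:
    adapter(observed, strict=False)
except ValueError as err:
    print(err)  # same message as in the issue

adapter({"x": np.ones((4, 3)), "theta": np.zeros((4, 2))})  # strict pass
print(adapter(observed, strict=False).shape)  # (4, 3)
```

In this toy model, the state lives on the adapter instance. Any object that has never seen a strict pass fails in non-strict mode, even if an equivalent object elsewhere has seen one.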

paul-buerkner commented 1 day ago

Based on my understanding of what you are doing, I agree that this should be handled differently. Just to make sure I understand you correctly, could you add a small example here that includes only the relevant parts of the code?

elseml commented 1 day ago

I looked further into the issue. As far as I can see, it is caused by the OfflineDataset and the approximator no longer referring to the same adapter object in memory.

Here is some reduced pseudocode to keep things concise:

Simulating at the beginning does not fail:

adapter = Adapter()
data = OfflineDataset(simulate(), adapter)
approximator = ContinuousApproximator(summary_net, inference_net, adapter)
approximator.fit(data)
approximator.sample(data)

When the data is loaded from an external source (where the adapter was also supplied to OfflineDataset), sampling fails:

adapter = Adapter()
data = load_data(path)  # saved earlier together with its own adapter
approximator = ContinuousApproximator(summary_net, inference_net, adapter)
approximator.fit(data)
approximator.sample(data)

Calling the adapter manually before sampling fixes the error:

adapter = Adapter()
data = load_data(path)
approximator = ContinuousApproximator(summary_net, inference_net, adapter)
approximator.fit(data)
_ = adapter(data)  # a strict=True call initializes the approximator's adapter
approximator.sample(data)

Creating a new dataset before sampling (i.e., simply constructing an OfflineDataset) does not fix it, since the adapter is not called during OfflineDataset construction:

adapter = Adapter()
data = load_data(path)
approximator = ContinuousApproximator(summary_net, inference_net, adapter)
approximator.fit(data)
data_2 = OfflineDataset(simulate(), adapter)  # construction does not call the adapter
approximator.sample(data_2)

paul-buerkner commented 1 day ago

Thank you! This is very helpful! @LarsKue and @stefanradev93 what are your takes on how to fix this?

elseml commented 1 day ago

Indeed, when OfflineDataset.adapter is passed to the approximator, the error disappears (so this is less a bug than unexpected behavior). However, this is an unintuitive workaround that users should not need.

data = load_data(path)
approximator = ContinuousApproximator(summary_net, inference_net, data.adapter)  # share the dataset's adapter object
approximator.fit(data)
approximator.sample(data)
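The behavior described in this thread can be reproduced without BayesFlow in a small toy model. `ToyAdapter`, `ToyDataset`, and `ToyApproximator` are hypothetical classes, not BayesFlow's API. The sketch assumes, as the thread suggests, that fit() only calls the dataset's own adapter (with strict=True) on each batch, while sample() calls the approximator's adapter with strict=False. It uses `copy.deepcopy` as a stand-in for saving the dataset with its adapter and loading it back.

```python
import copy


class ToyAdapter:
    """Hypothetical stand-in for an adapter that needs a strict pass first."""

    def __init__(self):
        self.initialized = False

    def __call__(self, data, *, strict=True):
        if strict:
            self.initialized = True  # state lives on this object only
        elif not self.initialized:
            raise ValueError(
                "Cannot call `forward` with `strict=False` "
                "before calling `forward` with `strict=True`."
            )
        return data


class ToyDataset:
    """Hypothetical dataset that applies its own adapter lazily, per batch."""

    def __init__(self, batches, adapter):
        self.batches = batches
        self.adapter = adapter  # stored only; not called during construction

    def __iter__(self):
        for batch in self.batches:
            yield self.adapter(batch)  # strict=True by default


class ToyApproximator:
    """Hypothetical approximator holding its own adapter reference."""

    def __init__(self, adapter):
        self.adapter = adapter

    def fit(self, dataset):
        for _ in dataset:  # only dataset.adapter is called during training
            pass

    def sample(self, conditions):
        # Mirrors the strict=False call made on the observed data when sampling.
        return self.adapter(conditions, strict=False)


batches = [{"x": 1.0, "theta": 0.5}]
observed = {"x": 1.0}

# "Loading" the data: the dataset's adapter is a different object from `adapter`.
adapter = ToyAdapter()
data = ToyDataset(batches, copy.deepcopy(adapter))
approximator = ToyApproximator(adapter)
approximator.fit(data)
try:
    approximator.sample(observed)
except ValueError as err:
    print("separate adapters:", err)

# Workaround 1: call the approximator's adapter once with strict=True.
adapter(batches[0])
print(approximator.sample(observed))  # {'x': 1.0}

# Workaround 2: give the approximator the dataset's own adapter object.
shared = ToyApproximator(data.adapter)
shared.fit(data)
print(shared.sample(observed))  # {'x': 1.0}
```

Under these assumptions, both workarounds succeed for the same reason. The object that sample() uses has already seen a strict pass, which is exactly what the loaded-data case does not guarantee.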

paul-buerkner commented 1 day ago

It will appear to users as a bug because it should just work. In any case, we should fix it before the 2.0 release.

LarsKue commented 1 day ago

Could be faulty serialization in the Adapter. I will investigate next week.