arrjon opened 6 days ago
It looks like the data adapter's `deconfigure()` is being called in an unintended way: with tensors instead of NumPy arrays.
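The fix would be to find this call and convert the data to NumPy first, e.g. with:

```python
import keras

# Convert every tensor leaf of the (possibly nested) structure to a NumPy array
data = keras.tree.map_structure(keras.ops.convert_to_numpy, data)
```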
Can you please share the code that led you to this error?
Unfortunately, I can't share a full working example right now, but the error was triggered by calling the `sample` function after training the model:
```python
test_data = simulator.sample((10,))
conditions = {
    "obs": test_data["obs"],
}
samples = approximator.sample(conditions=conditions, num_samples=100)
```
However, your proposed fix solves the issue for me.
I just noticed that I cannot pass the simulator output `test_data` directly to the `sample` function; I have to drop the entries related to the parameters first. Otherwise, the call fails with:
```
File ~site-packages/bayesflow/approximators/continuous_approximator.py:145, in ContinuousApproximator.sample(self, conditions, num_samples, numpy, batch_shape)
    142 conditions = self.data_adapter.configure(conditions)
    143 conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)
    144 conditions = {
--> 145     "inference_variables": self._sample(num_samples=num_samples, batch_shape=batch_shape, **conditions)
    146 }
    147 conditions = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions)
    148 conditions = self.data_adapter.deconfigure(conditions)

TypeError: ContinuousApproximator._sample() got an unexpected keyword argument 'inference_variables'
```
This is somewhat unintuitive and probably also an easy fix.
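The traceback suggests the mechanism: after `configure`, the conditions dict still contains an `inference_variables` entry (presumably built from the parameter entries of `test_data`), and unpacking it with `**conditions` passes a keyword that `_sample` does not accept. The sketch below reproduces this with a hypothetical stand-in for `_sample` (its parameter names are assumptions, not BayesFlow's actual signature) and shows two possible workarounds: dropping that entry inside `sample` before the call, or, on the caller side, dropping the parameter keys (here a hypothetical `"theta"`) from the simulator output.

```python
import numpy as np


def _sample(num_samples, batch_shape, inference_conditions=None, summary_variables=None):
    # Hypothetical stand-in for ContinuousApproximator._sample. The real
    # signature may differ; the point is only that it has no
    # `inference_variables` parameter.
    return {
        "num_samples": num_samples,
        "batch_shape": batch_shape,
        "inference_conditions": inference_conditions,
        "summary_variables": summary_variables,
    }


# What the configured conditions might look like when the full simulator
# output (including parameters) is passed in. Key names are assumptions.
configured = {
    "inference_conditions": np.zeros((10, 5)),
    "inference_variables": np.zeros((10, 2)),
}

# Reproduces the TypeError from the traceback.
error_message = None
try:
    _sample(num_samples=100, batch_shape=(10,), **configured)
except TypeError as err:
    error_message = str(err)


def sample_without_targets(num_samples, batch_shape, **conditions):
    # Library-side workaround: `conditions` is a fresh dict built from the
    # keyword arguments, so popping from it does not mutate the caller's dict.
    conditions.pop("inference_variables", None)
    return _sample(num_samples=num_samples, batch_shape=batch_shape, **conditions)


result = sample_without_targets(num_samples=100, batch_shape=(10,), **configured)

# Caller-side workaround described above: drop the parameter entries from
# the simulator output before calling sample(). "theta" is hypothetical.
test_data = {"theta": np.zeros((10, 2)), "obs": np.zeros((10, 5))}
conditions = {key: value for key, value in test_data.items() if key != "theta"}
```

Which of the two is preferable is a design question: silently dropping `inference_variables` inside `sample` would let users pass the full simulator output, at the cost of ignoring data they supplied.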
Bug regarding the new `dev` branch: when using the Torch backend on macOS with MPS devices, a `TypeError` is raised during execution of the `deconfigure` function in `data_adapters/concatenate_keys_data_adapter.py`.
Adding the following code to the `deconfigure` method resolves the issue for me:
While this fix works for the Torch backend on MPS devices, I am not sure whether it is backend-agnostic enough for the other backends (e.g., JAX, TensorFlow). Any ideas on how to fix this?
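One option is the conversion suggested earlier in the thread: in Keras 3, `keras.ops.convert_to_numpy` dispatches to the active backend, so mapping it over the data before (or at the start of) `deconfigure` should cover all backends. For illustration, the sketch below shows, without a Keras dependency, what such a conversion has to handle for each backend. The helper names `to_numpy` and `map_to_numpy` are hypothetical and not part of BayesFlow or Keras.

```python
import numpy as np


def to_numpy(x):
    """Duck-typed conversion of a single array-like leaf to a NumPy array."""
    # PyTorch tensors (including ones on MPS or CUDA devices) must be
    # detached from the autograd graph and copied to the CPU first.
    if hasattr(x, "detach") and hasattr(x, "cpu"):
        return x.detach().cpu().numpy()
    # TensorFlow eager tensors expose a .numpy() method.
    if hasattr(x, "numpy") and callable(x.numpy):
        return x.numpy()
    # JAX arrays, NumPy arrays and Python scalars support np.asarray.
    return np.asarray(x)


def map_to_numpy(data):
    """Apply to_numpy to every leaf of nested dicts, lists and tuples."""
    if isinstance(data, dict):
        return {key: map_to_numpy(value) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        # Preserve the container type (list stays list, tuple stays tuple).
        return type(data)(map_to_numpy(value) for value in data)
    return to_numpy(data)
```

Keeping the conversion outside the adapter, as in the earlier suggestion, lets `deconfigure` itself assume NumPy inputs regardless of the backend.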