dstl / Stone-Soup

A software project to provide the target tracking community with a framework for the development and testing of tracking algorithms.
https://stonesoup.rtfd.io
MIT License
403 stars 134 forks source link

Ensemble Predictor does not call vectorized models correctly #668

Closed 0sm1um closed 2 years ago

0sm1um commented 2 years ago

The code in Pull Request #644 introduced a bug in the EnsemblePredictor that causes the Ensemble Kalman Filter to quickly diverge or fail.

When the transition model is called, it populates the prediction ensemble with a StateVectors instance of N identical vectors — in other words, an n×N StateVectors matrix whose columns are all the same.

In Pull Request #644, the EnsembleState and EnsemblePredictor were changed to utilize the vectorized models which were introduced to speed up the Particle Filter in PR #365.

0sm1um commented 2 years ago

The fix for this is already implemented in PR #669

sdhiscocks commented 2 years ago

Thanks @0sm1um. This looks like a more fundamental issue: calls to the `rvs` method do not set the sample size. Luckily, this seems to affect only the particle filter (which already applies the same `num_samples` fix you used) and, of course, the Ensemble filter.

Here's a proposed alternative fix, which deals with the issue more generally.

diff --git a/stonesoup/models/base.py b/stonesoup/models/base.py
index b4b1bd8c..27bc5471 100644
--- a/stonesoup/models/base.py
+++ b/stonesoup/models/base.py
@@ -134,7 +134,7 @@ class LinearModel(Model):
         """
         if isinstance(noise, bool) or noise is None:
             if noise:
-                noise = self.rvs(**kwargs)
+                noise = self.rvs(num_samples=state.state_vector.shape[1], **kwargs)
             else:
                 noise = 0

diff --git a/stonesoup/models/measurement/linear.py b/stonesoup/models/measurement/linear.py
index d03e195a..3fe32d8c 100644
--- a/stonesoup/models/measurement/linear.py
+++ b/stonesoup/models/measurement/linear.py
@@ -75,7 +75,7 @@ class LinearGaussian(MeasurementModel, LinearModel, GaussianModel):

         if isinstance(noise, bool) or noise is None:
             if noise:
-                noise = self.rvs()
+                noise = self.rvs(num_samples=state.state_vector.shape[1], **kwargs)
             else:
                 noise = 0

diff --git a/stonesoup/models/measurement/nonlinear.py b/stonesoup/models/measurement/nonlinear.py
index 64ccb5ad..054a33e4 100644
--- a/stonesoup/models/measurement/nonlinear.py
+++ b/stonesoup/models/measurement/nonlinear.py
@@ -255,7 +255,7 @@ class CartesianToElevationBearingRange(NonLinearGaussianMeasurement, ReversibleM

         if isinstance(noise, bool) or noise is None:
             if noise:
-                noise = self.rvs()
+                noise = self.rvs(num_samples=state.state_vector.shape[1], **kwargs)
             else:
                 noise = 0

@@ -412,7 +412,7 @@ class CartesianToBearingRange(NonLinearGaussianMeasurement, ReversibleModel):

         if isinstance(noise, bool) or noise is None:
             if noise:
-                noise = self.rvs()
+                noise = self.rvs(num_samples=state.state_vector.shape[1], **kwargs)
             else:
                 noise = 0

@@ -538,7 +538,7 @@ class CartesianToElevationBearing(NonLinearGaussianMeasurement):

         if isinstance(noise, bool) or noise is None:
             if noise:
-                noise = self.rvs()
+                noise = self.rvs(num_samples=state.state_vector.shape[1], **kwargs)
             else:
                 noise = 0

@@ -633,7 +633,7 @@ class Cartesian2DToBearing(NonLinearGaussianMeasurement):

         if isinstance(noise, bool) or noise is None:
             if noise:
-                noise = self.rvs()
+                noise = self.rvs(num_samples=state.state_vector.shape[1], **kwargs)
             else:
                 noise = 0

@@ -774,7 +774,7 @@ class CartesianToBearingRangeRate(NonLinearGaussianMeasurement):

         if isinstance(noise, bool) or noise is None:
             if noise:
-                noise = self.rvs()
+                noise = self.rvs(num_samples=state.state_vector.shape[1], **kwargs)
             else:
                 noise = 0

@@ -922,7 +922,7 @@ class CartesianToElevationBearingRangeRate(NonLinearGaussianMeasurement, Reversi

         if isinstance(noise, bool) or noise is None:
             if noise:
-                noise = self.rvs()
+                noise = self.rvs(num_samples=state.state_vector.shape[1], **kwargs)
             else:
                 noise = 0

diff --git a/stonesoup/models/transition/nonlinear.py b/stonesoup/models/transition/nonlinear.py
index b294238f..37de2979 100644
--- a/stonesoup/models/transition/nonlinear.py
+++ b/stonesoup/models/transition/nonlinear.py
@@ -114,7 +114,7 @@ class ConstantTurn(GaussianTransitionModel, TimeVariantModel):
              turn_rate])
         if isinstance(noise, bool) or noise is None:
             if noise:
-                noise = self.rvs(**kwargs)
+                noise = self.rvs(num_samples=state.state_vector.shape[1], **kwargs)
             else:
                 noise = 0
         return sv2 + noise
@@ -186,7 +186,7 @@ class ConstantTurnSandwich(ConstantTurn):
         sv_out = StateVectors(np.concatenate(sv_list))
         if isinstance(noise, bool) or noise is None:
             if noise:
-                noise = self.rvs(**kwargs)
+                noise = self.rvs(num_samples=state.state_vector.shape[1], **kwargs)
             else:
                 noise = 0
         return sv_out + noise
diff --git a/stonesoup/predictor/particle.py b/stonesoup/predictor/particle.py
index 8f65b40d..7370aaa3 100644
--- a/stonesoup/predictor/particle.py
+++ b/stonesoup/predictor/particle.py
@@ -42,7 +42,6 @@ class ParticlePredictor(Predictor):
             prior,
             noise=True,
             time_interval=time_interval,
-            num_samples=len(prior),
             **kwargs)

         return Prediction.from_state(prior, state_vector=new_state_vector, weight=prior.weight,
diff --git a/stonesoup/updater/particle.py b/stonesoup/updater/particle.py
index b8f3f814..89fc2991 100644
--- a/stonesoup/updater/particle.py
+++ b/stonesoup/updater/particle.py
@@ -51,8 +51,7 @@ class ParticleUpdater(Updater):
             measurement_model = hypothesis.measurement.measurement_model

         predicted_state.weight = predicted_state.weight * measurement_model.pdf(
-            hypothesis.measurement, predicted_state, num_samples=len(predicted_state),
-            **kwargs)
+            hypothesis.measurement, predicted_state, **kwargs)

         # Normalise the weights
         sum_w = np.array(Probability.sum(predicted_state.weight))