brian-team / brian2modelfitting

Model fitting toolbox for the Brian 2 simulator
https://brian2modelfitting.readthedocs.io

[MRG] Add a refine function that uses lmfit.minimize #28

Closed mstimberg closed 4 years ago

mstimberg commented 4 years ago

[Not quite ready to merge yet because it is missing error handling, documentation, and tests, but please try it out in the refine_fits branch]

First version of a refine function in TraceFitter that can be run after fitting with the usual methods. It needs the lmfit package and by default runs a least-squares optimization with the Levenberg-Marquardt algorithm. In its simplest form, you can use:

fitter.fit(...)
params, result = fitter.refine()

This will re-use the bounds set previously, and start from the best parameters found at the end of the fit. Alternatively, you can set the parameters you want to start from in a dictionary:

params, result = fitter.refine({'g_L': 10*nS, ...})

and/or change the bounds as keyword arguments (like in the fit function):

params, result = fitter.refine(g_L=[1*nS, 100*nS], ...)

Any additional keyword arguments will be passed on to lmfit.minimize, which you can use to set the method or its parameters.
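For instance, a sketch of switching the optimization method (the method name is lmfit's; 'nelder' selects Nelder-Mead, the exact choice here is just an example):

params, result = fitter.refine(method='nelder')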

The refine function returns two values: the first is a dictionary with the best parameters (without units at the moment, for consistency with fit), and the second is an lmfit.MinimizerResult, which stores all kinds of additional information. You can use print(lmfit.fit_report(result)) to print this information.

mstimberg commented 4 years ago

Note that this does not work with standalone mode yet.

mstimberg commented 4 years ago

I think this is "good enough" for now. Its main downside is that it only works for TraceFitter and hardcodes a least-squares fit of the traces without any weighting (i.e. the metric is ignored). This is due to the way lmfit works. It would be nice to work around this, but I think it is already useful in its current form.

romainbrette commented 4 years ago

One issue I noticed: the values of the params dictionary are arrays instead of floats (as given by the fit method). I think it would also be valuable to have some form of progress reporting.

romainbrette commented 4 years ago

It runs on my Paramecium script; I still have to check the results in more detail. I did have to use the maxfev keyword, otherwise it would not end (or, I suppose, only after a very long time).
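Something like this (just a sketch, the cutoff value is arbitrary; as mentioned above, extra keyword arguments are forwarded to lmfit.minimize and from there to the underlying solver):

params, result = fitter.refine(maxfev=10000)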

romainbrette commented 4 years ago

I'm assuming result.chisqr corresponds to the mean squared error, does that sound right?

romainbrette commented 4 years ago

I seem to get something a bit different though.

romainbrette commented 4 years ago

How can I get the error from the fitter if I have parameter values, to check? (calc_errors?)

mstimberg commented 4 years ago

One issue I noticed: the values of the params dictionary are arrays instead of floats (as given by the fit method).

These are scalar arrays, so you can use them exactly like floats. In general I prefer them because you can e.g. check their shape as for any other array. But if you don't like the output where it says "array(...)", we can always change it.

romainbrette commented 4 years ago

I got an error because I serialized with JSON, which apparently doesn't like arrays.
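The workaround is simple enough (a sketch, assuming params is the dictionary returned by refine; the filename is hypothetical): converting the values to plain floats before dumping makes them JSON-serializable.

import json
serializable = {name: float(value) for name, value in params.items()}
with open('params.json', 'w') as f:
    json.dump(serializable, f)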

romainbrette commented 4 years ago

By the way does it take into account t_start?

romainbrette commented 4 years ago

It seems not; that could explain the discrepancy.

mstimberg commented 4 years ago

I got an error because I serialized with JSON and it doesn't like arrays apparently.

Ugh, ok. I see you made the change already, I'll take a note of it for the future.

mstimberg commented 4 years ago

By the way does it take into account t_start?

Indeed it does not!

romainbrette commented 4 years ago

Can I let you do this change?

mstimberg commented 4 years ago

By the way does it take into account t_start?

Indeed it does not!

Conceptually this is actually a bit tricky, since the refine method is independent of the chosen metric, and t_start is a property of the metric. But I think we should go with "practicality beats purity" and reuse t_start from the metric.

mstimberg commented 4 years ago

Can I let you do this change?

Sure.

mstimberg commented 4 years ago

I added the t_start option to TraceFitter.refine. If not specified, it will reuse what the previous fit call used via its metric.
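For example (a sketch; the value here is arbitrary):

from brian2 import ms
params, result = fitter.refine(t_start=5*ms)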

mstimberg commented 4 years ago

How can I get the error from the fitter if I have parameter values, to check? (calc_errors?)

Forgot to answer this one. Something like the following should work:

traces = fitter.generate_traces(params={...})
error = metric.calc(np.array(traces[None, :, :]), fitter.output, dt)

The [None, :, :] is necessary because in general the metric works in parallel on multiple parameter combinations, but in this case you only have a single one. The np.array is just to get rid of the units.

mstimberg commented 4 years ago

I'm assuming result.chisqr corresponds to the mean squared error, sounds right?

result.chisqr is the summed, not the mean, squared error (see the documentation), so you'll have to divide it by the number of time steps to get the MSE. There's another catch: lmfit does not work with more than one trace, so all the traces are concatenated into a single flat array. There's therefore also no averaging over the traces; if you have more than one, you have to divide by that number as well.

I did this on a simple example and the numbers match.
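Put differently, a sketch of the conversion (n_traces and n_steps stand for the number of traces and the number of time steps per trace, whatever they are in your setup):

mse = result.chisqr / (n_traces * n_steps)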

romainbrette commented 4 years ago

Yes, it seems to work. I use this instead for the error: mean(result.residual**2)**.5

romainbrette commented 4 years ago

There's a normalization argument to MSEMetric that is ignored (no big deal).

romainbrette commented 4 years ago

I'm running it to see whether it makes improvements (I think it does).

mstimberg commented 4 years ago

There's a normalization argument to MSEMetric that is ignored (no big deal).

This is again the conceptual problem we had earlier. We do not use the metric in refine, so it is unclear whether we should use previously set arguments to metrics. For t_start I think we won't have any problems, but if we just check whether the previously used metric has a normalization attribute, we might get into issues when another metric defines it with a different meaning. We could maybe special-case MSEMetric as it is the equivalent to what we are doing here.

mstimberg commented 4 years ago

but if we just check whether the previously used metric has a normalization attribute, we might get into issues when another metric defines it with a different meaning. We could maybe special-case MSEMetric as it is the equivalent to what we are doing here.

We could also make it "official", and give a normalization attribute to all metrics (at least all TraceMetric) instead of just to MSEMetric.
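Roughly something like this (just a sketch of the idea, not the actual implementation):

from brian2 import second

class Metric:
    def __init__(self, t_start=0*second, normalization=1.0):
        # every metric would store the normalization factor; each concrete
        # metric (and refine) could then apply it consistently
        self.t_start = t_start
        self.normalization = normalization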

romainbrette commented 4 years ago

Ah yes maybe.

mstimberg commented 4 years ago

Yes, it seems to work. I use this instead for the error: mean(result.residual**2)**.5

I'm sure you noticed, but just in case: this is the square root of what MSEMetric calculates (i.e. the RMSE instead of the MSE).

romainbrette commented 4 years ago

Yes sure.

romainbrette commented 4 years ago

So far it seems to work better than differential evolution. I'll try a mix.

romainbrette commented 4 years ago

Strange. I do 1000 rounds of DE, then refine, but the error starts from well above the error obtained after DE. However, if I do just 10 rounds of DE, then I don't get that problem. I don't know what's going on!

romainbrette commented 4 years ago

Could be my mistake, I'm doing it again...

romainbrette commented 4 years ago

Ok it was my mistake! It works and it does seem to improve.

mstimberg commented 4 years ago

I moved the normalization argument into the general Metric class, and it is now respected by the refine method as well. I also added a callback option, which defaults to 'text' (as for fit) and reuses the same callback methods.
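For example (a sketch; I'm assuming the same values that fit accepts, i.e. 'text', 'progressbar', a custom callable, or None):

params, result = fitter.refine(callback='progressbar')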

I guess this is enough for this PR to be merged?