RatInABox-Lab / RatInABox

A python package for modelling locomotion in complex environments and spatially/velocity selective cell activity.
MIT License

BVC firingrates go over max_fr #110

Closed colleenjg closed 2 months ago

colleenjg commented 3 months ago

As mentioned in #108, BVC firing rates sometimes go over max_fr. I believe this is due to how self.cell_fr_norm is computed from locations uniformly sampled across the environment during initialization.

I've given it a lot of thought, and I don't think there's a straightforward perfect solution. However, I think I can propose a slight improvement. Increasing the resolution by making dx smaller is not ideal, as memory usage climbs quickly. So, I propose using the uniformly sampled locations, but also adding jittered locations.

For clarity, I've moved the normalization setup out of the init into its own function, but everything here is the same as before, except the lines in if add_jittered:. Basically, in addition to using the uniformly sampled locs to estimate max firing, I also jitter each location by a value between -dx/2 and dx/2 in x and y. (Using dx/2 should prevent any points going out of bounds.) These jittered locations are appended to locs and all of the values are used together to estimate the max firing rate for each neuron. Importantly, this cannot make the estimates worse, as locs still includes the original uniformly sampled locs. It can only improve them in cases where the jittered locs find higher firing rates near the uniform locs. This does double the memory use, but my tests indicate this is more than compensated for by the improvement in the estimates.

import warnings
import numpy as np

def _set_cell_fr_norm(self, dx=0.04, add_jittered=True):
    # uniformly sampled locations across the environment, flattened to (n_locs, n_dims)
    locs = self.Agent.Environment.discretise_environment(dx=dx)
    locs = locs.reshape(-1, locs.shape[-1])

    if add_jittered:
        # offset each location by up to +/- dx/2 per axis, keeping the originals too
        jitter = np.random.uniform(-dx/2, dx/2, locs.shape)
        locs = np.append(locs, locs + jitter, axis=0)

    self.cell_fr_norm = np.ones(self.n) # value for initialization

    # ignores the warning raised during initialisation of the BVCs
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # per-neuron maximum over all sampled locations
        _cell_fr_norm = np.max(self.get_state(evaluate_at=None, pos=locs), axis=1)
        self.cell_fr_norm = (_cell_fr_norm - self.min_fr) / (self.max_fr - self.min_fr)

From the examples below, you can see that, in this random run, no jitter and dx=0.04 leads to 1.5% of firingrates being over the max firing rate of 1. Cutting dx in two greatly improves this (down to 0.6% of firingrates being above 1), but requires 4x more datapoints and, in my tests, uses way too much memory. Using jitter and a higher dx (0.05) requires only a few more points and performs quite well (0.63% of firingrates above max_fr), though the max firingrate is higher than with dx=0.02. I've also included an example with dx=0.04 and jitter for comparison purposes. Of course, this uses 2x more points than no jitter.

[Image: jitter_effect]

I've run it several times to make sure the improvement isn't a fluke. It might not be worth implementing this, as it's not an actual fix, but I figured I'd propose it just in case.
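To make the "cannot make the estimates worse" argument concrete, here is a minimal toy sketch. It does not use RatInABox: the Gaussian-bump tuning curves (`toy_rates`), the unit-square grid helper (`grid_centres`), and the widths and seeds are all assumptions chosen for illustration, with a known true peak of 1.0.

```python
import numpy as np

def grid_centres(dx, extent=1.0):
    """Cell-centre sample points on the square [0, extent]^2 (toy stand-in for a discretised environment)."""
    n = int(round(extent / dx))
    axis = (np.arange(n) + 0.5) * dx
    xx, yy = np.meshgrid(axis, axis)
    return np.stack([xx.ravel(), yy.ravel()], axis=1)

def toy_rates(pos, centres, width):
    """Hypothetical tuning curves: Gaussian bumps whose true peak is 1.0. Returns (n_cells, n_pos)."""
    d2 = ((pos[None, :, :] - centres[:, None, :]) ** 2).sum(axis=-1)
    return np.exp(-d2 / (2 * width ** 2))

def estimate_peak(centres, width, dx, add_jittered, rng):
    """Estimate each cell's peak rate from grid samples, optionally adding jittered copies."""
    locs = grid_centres(dx)
    if add_jittered:
        jitter = rng.uniform(-dx / 2, dx / 2, locs.shape)
        locs = np.concatenate([locs, locs + jitter], axis=0)  # originals are kept
    return toy_rates(locs, centres, width).max(axis=1)

rng = np.random.default_rng(0)
centres = rng.uniform(0.2, 0.8, size=(50, 2))  # 50 toy cells
plain = estimate_peak(centres, 0.03, 0.04, False, np.random.default_rng(1))
jittered = estimate_peak(centres, 0.03, 0.04, True, np.random.default_rng(1))

# Both estimates sit at or below the true peak of 1.0; jitter never lowers the estimate.
print(f"mean underestimate without jitter: {np.mean(1.0 - plain):.4f}")
print(f"mean underestimate with jitter:    {np.mean(1.0 - jittered):.4f}")
```

Because the jittered set contains the original grid, the jittered estimate is at least the plain estimate for every cell. How much it helps depends on how narrow the tuning is relative to dx.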

A different solution would be a two-step estimate: Use locs to identify the top 6 locations with the highest firingrates for each neuron, and then sample uniformly a certain number of points around those.
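A rough sketch of that two-step idea, again on a toy problem rather than the real BVC code. The Gaussian tuning curves, the `n_local` count (the "certain number of points"), and the choice to sample within dx/2 of each top location are assumptions, not anything already in RatInABox.

```python
import numpy as np

def toy_rates(pos, centres, width=0.03):
    """Hypothetical tuning curves: Gaussian bumps with true peak 1.0. Returns (n_cells, n_pos)."""
    d2 = ((pos[None, :, :] - centres[:, None, :]) ** 2).sum(axis=-1)
    return np.exp(-d2 / (2 * width ** 2))

def two_step_peak(rate_fn, centres, dx=0.04, top_k=6, n_local=20, seed=0):
    """Coarse grid pass, then local uniform resampling around each cell's top_k grid points."""
    rng = np.random.default_rng(seed)
    axis = (np.arange(int(round(1.0 / dx))) + 0.5) * dx
    xx, yy = np.meshgrid(axis, axis)
    locs = np.stack([xx.ravel(), yy.ravel()], axis=1)

    coarse = rate_fn(locs, centres)          # (n_cells, n_locs)
    coarse_peak = coarse.max(axis=1)
    refined_peak = coarse_peak.copy()

    # indices of the top_k highest-rate grid points for each cell
    top_idx = np.argpartition(coarse, -top_k, axis=1)[:, -top_k:]
    for i in range(len(centres)):
        seeds = locs[top_idx[i]]                                  # (top_k, 2)
        offsets = rng.uniform(-dx / 2, dx / 2, size=(top_k, n_local, 2))
        local = (seeds[:, None, :] + offsets).reshape(-1, 2)      # (top_k * n_local, 2)
        local_peak = rate_fn(local, centres[i:i + 1]).max()
        refined_peak[i] = max(refined_peak[i], local_peak)
    return coarse_peak, refined_peak

centres = np.random.default_rng(0).uniform(0.2, 0.8, size=(50, 2))
coarse_peak, refined_peak = two_step_peak(toy_rates, centres)
print(f"mean underestimate, coarse only: {np.mean(1.0 - coarse_peak):.4f}")
print(f"mean underestimate, two-step:    {np.mean(1.0 - refined_peak):.4f}")
```

The extra points scale with n_cells * top_k * n_local but are evaluated one cell at a time, so the peak memory of the second pass stays small compared with refining the whole grid.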

colleenjg commented 3 months ago

Here's another run:

[Image: jitter_effect2]

TomGeorge1234 commented 3 months ago

It's an interesting point. I agree there's no neat solution because there isn't a theoretical/analytic answer to what the maximum firing rate of a BVC is. It depends on the exact environment.

I'm kind of happy with your solution. It depends how much you really mind that the firing rate overshoots the user-desired max_fr. I see a few options.

Option 1: Do nothing and just note in the comments that this is only an estimate.

Option 1.1: Your solution, add slightly jittered test points. Simple and easy but doesn't really fix it.

Option 2: Redefine. If you really minded, a better way might be to keep it as it is and add a hard threshold: some kind of tanh or threshold-linear function on the output of each neuron that just clips any overshoot back to max_fr (this works nicely because we can be sure it's always an overshoot, not an undershoot). I don't particularly like this because the clipping could be quite bad for ego-BVCs (explained below).

Option 3: We could redefine what max_fr means so that it can be calculated analytically. For example, we could change it to mean "if the environment consisted of a single infinite straight wall, what would the max firing rate of the cell be in this environment?". This could probably be computed analytically and then used instead of the empirical estimate. Or what about this: if a cell had a preferred distance of X cm, wouldn't it necessarily fire maximally at the centre of a circular environment of radius X cm? This could be calculated and used as the scaling. The cell should be guaranteed to fire at or below this rate in its actual environment, so it is a form of upper bound. As long as we tell the user this is what max_fr means, they will get what they want.
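For reference, the output clipping described in Option 2 could look something like the sketch below. The function names `hard_clip` and `soft_clip` are hypothetical, and this is applied to a plain array of firing rates rather than wired into any RatInABox class.

```python
import numpy as np

def hard_clip(fr, min_fr=0.0, max_fr=1.0):
    """Threshold-linear version: rates above max_fr are set to max_fr."""
    return np.clip(fr, min_fr, max_fr)

def soft_clip(fr, min_fr=0.0, max_fr=1.0):
    """tanh version: smooth and strictly below max_fr, but it also compresses rates under max_fr."""
    span = max_fr - min_fr
    return min_fr + span * np.tanh((fr - min_fr) / span)

fr = np.array([0.0, 0.2, 0.9, 1.0, 1.6])  # 1.6 is an overshoot past max_fr=1.0
hard = hard_clip(fr)
soft = soft_clip(fr)
print(hard)
print(soft)
```

The trade-off is visible in the two outputs: the hard clip leaves everything at or below max_fr untouched, while the tanh variant never reaches max_fr and lowers in-range rates as well (a rate equal to max_fr maps to about 0.76 of the span).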

Note that it's going to be much worse for egocentric BVCs, where the (true) max_fr requires a grid search over positions and head directions, which I think we're not currently implementing, so the overshoot could be quite severe for ego-BVCs if I had to guess. Because of this, whatever we choose, I'd advocate better explaining in the comments what max_fr means in the case of each vector cell, as it's not such a trivial notion.
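As a toy illustration of why the search would need to cover head directions too, consider a made-up egocentric cell (`toy_ego_rates`, an assumption for this sketch and not the BVC model) whose rate is a spatial bump multiplied by cosine tuning to head direction. Maximising over positions at a single head direction misses the true peak.

```python
import numpy as np

def toy_ego_rates(pos, head_dirs, centre, pref_dir, width=0.1):
    """Hypothetical egocentric cell: Gaussian in position times cosine head-direction tuning.
    Returns an array of shape (n_pos, n_head_dirs); the true peak is 1.0."""
    spatial = np.exp(-((pos - centre) ** 2).sum(axis=1) / (2 * width ** 2))
    angular = 0.5 * (1.0 + np.cos(head_dirs - pref_dir))
    return spatial[:, None] * angular[None, :]

# positions on a dx=0.04 grid of cell centres over the unit square
axis = (np.arange(25) + 0.5) * 0.04
xx, yy = np.meshgrid(axis, axis)
pos = np.stack([xx.ravel(), yy.ravel()], axis=1)

centre = np.array([0.5, 0.5])
pref_dir = np.pi / 3

# peak estimate using positions only, at one fixed head direction
fixed_hd_peak = toy_ego_rates(pos, np.array([0.0]), centre, pref_dir).max()

# peak estimate using a grid over positions and head directions (10 degree steps)
head_dirs = np.linspace(0.0, 2 * np.pi, 36, endpoint=False)
grid_peak = toy_ego_rates(pos, head_dirs, centre, pref_dir).max()

print(f"fixed head direction: {fixed_hd_peak:.3f}, position x head-direction grid: {grid_peak:.3f}")
```

Normalising by the fixed-direction estimate in this toy case would let rates reach about 1.33 times the intended maximum. Memory grows with n_pos * n_head_dirs, so a real implementation would probably need to evaluate head directions in batches.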

Fwiw I think I'd lean option 1 or 1.1.

colleenjg commented 3 months ago

Thank you for the detailed response. I'm not sure what the best option is, but I realize that you are quite right! The head direction issue is probably contributing, perhaps more than the sampling. This makes a lot of sense. I'm going to give this more thought to see whether I can come up with a different small change that could substantially improve the estimate!

TomGeorge1234 commented 3 months ago

Yeah, but the head direction thing would actually only be an issue if you're using egocentric BVCs (params = {'reference_frame':'egocentric'}) or you're using FieldOfViewBVCs etc. Also, I just tested it and the overshoot isn't noticeably worse for ego-BVCs.

colleenjg commented 3 months ago

Ah, that makes sense, as that's where my student and I discovered the largest overshoot (FieldOfViewBVCs). Values up to 1.6 with max_fr=1.0.

TomGeorge1234 commented 3 months ago

Ah ok, in which case your proposed fix wouldn't really solve their problem. We'll have to think of another way...

TomGeorge1234 commented 3 months ago

Hi Colleen, where did we leave this? What's the current workaround your student is using, and do you think we should push ahead and implement something more proper?

colleenjg commented 3 months ago

Hey Tom, we are just setting a max on the firingrates for the particular downstream calculation we were performing. Sorry, I haven't had time to look into this further, though it is on the to-do list.

TomGeorge1234 commented 2 months ago

Closing for now as the current situation doesn't constitute a "bug" per se, but will happily reopen whenever needed...