RatInABox-Lab / RatInABox

A Python package for modelling locomotion in complex environments and spatially/velocity selective cell activity.
MIT License

Avoid infinite recursion in plot_rate_map() #95

Closed: colleenjg closed this 7 months ago

colleenjg commented 7 months ago

Enables recurrent inputs to be flagged when adding inputs to FeedForwardLayer neurons. Flagged inputs are ignored when a ground-truth rate map is plotted, which avoids infinite recursion.

TomGeorge1234 commented 7 months ago

Oh...I love this!!!!

Do you think it would be cleaner if, instead, get_state() received a recursion_depth argument defaulting to 1? Then the internal call would do something like:

if inputlayer["recurrent"] and recursion_depth <= 0:
    continue  # depth exhausted: skip this recurrent input
w = inputlayer["w"]
if evaluate_at == "last":
    I = inputlayer["layer"].firingrate
else:  # kick the can down the road, except for expired recurrent loops
    I = inputlayer["layer"].get_state(evaluate_at, recursion_depth=recursion_depth - 1, **kwargs)
V += np.matmul(w, I)  # accumulate this input's contribution in both branches

Then you wouldn't even need a new plot_rate_map() function, and users could control the recursion depth if they wanted to, e.g. with FFL.plot_rate_map(recursion_depth=42).

This has the benefit that recursion can never be infinite. It is slightly different from yours, though, because here the recurrent loop is still evaluated exactly once for rate maps. We could avoid this by (i) setting recursion_depth=0 as the default (then we'd need to move the evaluate_at="last" branch outside the if test), (ii) writing a (now even simpler) plot_rate_map() wrapper which just calls super().plot_rate_map(recursion_depth=0), or (iii) writing a new update() wrapper which forces recursion_depth > 0 for online updates. This last one feels a bit weird imo but would work.
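To make the proposal concrete, here is a minimal toy sketch. It assumes a hypothetical Layer class with scalar rates and scalar weights standing in for RatInABox's FeedForwardLayer and its weight matrices; only the inputlayer keys ("layer", "w", "recurrent") mirror the snippet above. Every recursive call consumes one unit of recursion_depth, and recurrent inputs are skipped once the depth reaches zero, so recursion always terminates.

```python
class Layer:
    """Hypothetical toy layer; scalar rates and weights stand in for arrays."""

    def __init__(self, bias=0.0):
        self.bias = bias  # stands in for a fixed, non-recursive drive
        self.inputs = []  # list of {"layer", "w", "recurrent"} dicts

    def add_input(self, layer, w=1.0, recurrent=False):
        self.inputs.append({"layer": layer, "w": w, "recurrent": recurrent})

    def get_state(self, recursion_depth=1):
        V = self.bias
        for inputlayer in self.inputs:
            # Skip recurrent inputs once the recursion budget is exhausted
            if inputlayer["recurrent"] and recursion_depth <= 0:
                continue
            # Every recursive call consumes one unit of depth, recurrent or not
            I = inputlayer["layer"].get_state(recursion_depth=recursion_depth - 1)
            V += inputlayer["w"] * I
        return V


layer = Layer(bias=1.0)
layer.add_input(layer, w=0.5, recurrent=True)  # self-loop flagged as recurrent
for depth in (0, 1, 3):
    print(depth, layer.get_state(recursion_depth=depth))
```

With this self-recurrent layer, depth 0 ignores the loop (1.0), the default depth of 1 uses it exactly once (1.5), and depth 3 unrolls it further (1.875). The same self-loop without the recurrent flag recurses until Python raises RecursionError.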

What do you think - you've probably thought about this a lot more than I have.

colleenjg commented 7 months ago

That's even better, because my version completely prevents recursive inputs from contributing to the rate maps.

I think what you're proposing looks like a clean solution! The only oddity I noticed is that, in my view, recursion_depth doesn't quite describe what the variable does, since it is decremented even when there is no true recursion (i.e., a layer calling itself). What I mean is that if you have a loop with three nodes, each node in the loop decrements the depth by 1, instead of each full pass through the loop decreasing it by 1.
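The three-node point can be checked with a toy count. This sketch assumes a hypothetical Node class that only records how often it is evaluated; it compares decrementing the depth on every hop (the recursion_depth proposal) with decrementing it only on the one edge flagged as recurrent, on a loop in which A reads from C, C from B, and B from A.

```python
class Node:
    """Hypothetical toy layer that only counts how often it is evaluated."""

    def __init__(self):
        self.inputs = []
        self.calls = 0

    def add_input(self, layer, recurrent=False):
        self.inputs.append({"layer": layer, "recurrent": recurrent})

    def get_state(self, depth, decrement_every_hop):
        self.calls += 1
        for inputlayer in self.inputs:
            if inputlayer["recurrent"] and depth <= 0:
                continue  # depth exhausted: skip the recurrent input
            if decrement_every_hop:
                next_depth = depth - 1  # every recursive call costs 1
            else:
                # only the flagged (recurrent) edge costs 1
                next_depth = depth - 1 if inputlayer["recurrent"] else depth
            inputlayer["layer"].get_state(next_depth, decrement_every_hop)


def loop_passes(depth, decrement_every_hop):
    """Build the loop A <- C <- B <- A and count complete passes from A back to A."""
    a, b, c = Node(), Node(), Node()
    a.add_input(c, recurrent=True)  # flag exactly one edge of the loop
    c.add_input(b)
    b.add_input(a)
    a.get_state(depth, decrement_every_hop)
    return a.calls - 1


for depth in (1, 3):
    print(depth, loop_passes(depth, True), loop_passes(depth, False))
```

With a depth of 3, per-hop decrementing completes only one pass around the loop, while decrementing only on the flagged edge completes three, so in the second scheme the depth counts full passes through the loop.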

So perhaps we can call it max_depth_if_recursion. Or is that an annoyingly long name? I would also suggest only decrementing this variable if the input layer is flagged as recurrent.

pass_max_depth_if_recursion = max_depth_if_recursion
if inputlayer['recurrent']:
    if max_depth_if_recursion == 0:
        continue  # budget used up: skip this recurrent input
    else:
        pass_max_depth_if_recursion = max_depth_if_recursion - 1  # only recurrent inputs decrement the depth

w = inputlayer["w"]
if evaluate_at == "last":
    I = inputlayer["layer"].firingrate
else:  # kick the can down the road, except for expired recurrent loops
    I = inputlayer["layer"].get_state(evaluate_at, max_depth_if_recursion=pass_max_depth_if_recursion, **kwargs)
V += np.matmul(w, I)  # accumulate this input's contribution in both branches

Do you think this would work?

colleenjg commented 7 months ago

I'm realizing I didn't follow the end of your comment. What if we set the default to max_depth_if_recursion=None, so that it's ignored unless it's set, and add a max_depth_if_recursion=None argument to plot_rate_map(), as you suggest? I'll push a new version to clarify.

TomGeorge1234 commented 7 months ago

My mistake, definitely should only be decremented once per loop, good spot.

You made me realise there's an important distinction between recursion (get_state() calling another get_state(), which strictly applies to all inputs that are FeedForwardLayers) and recurrence (inputs which eventually circle back on themselves). The correct name is something like max_recursion_depth_for_recurrent_inputs, but that's ridiculous. What about max_recurrence (which makes clear that it applies to recurrent loops in the graph structure)?

`max_recurrence`: 1 # The maximum number of times get_state() recursively calls recurrent inputs (prevents a recursion error when plotting rate maps).

Perhaps more readable is:

# Skip this input if you're past the recurrence limit
if inputlayer['recurrent'] and max_depth_if_recursion>0:
    continue

# Get layer input, either from its current firingrate or from recursively calling Input.get_state(). 
w = inputlayer["w"]
if evaluate_at == "last":
    I = inputlayer["layer"].firingrate
else:  # kick the can down the road
    # bool is a subclass of int, so this decreases the depth by 1 iff the input is flagged as recurrent
    I = inputlayer["layer"].get_state(evaluate_at, max_recurrence=max_recurrence - inputlayer['recurrent'], **kwargs)
V += np.matmul(w, I)  # accumulate this input's contribution in both branches

Thoughts?

We should also add a comment to add_input() clarifying that only one node in each recurrent loop needs to be flagged as recurrent.
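A hedged sketch of what such an add_input() comment could say, and why it matters. It assumes a hypothetical toy Node class rather than RatInABox's real FeedForwardLayer, and combines the corrected skip condition (max_recurrence <= 0) with the max_recurrence - inputlayer["recurrent"] expression, which works because Python's bool is a subclass of int. On a two-node loop, flagging both edges makes each pass consume two units of budget instead of one.

```python
class Node:
    """Hypothetical toy layer that only counts how often it is evaluated."""

    def __init__(self):
        self.inputs = []
        self.calls = 0

    def add_input(self, layer, recurrent=False):
        """Add an input layer.

        recurrent: set True on only one input per recurrent loop. Every flagged
        input on a loop consumes one unit of max_recurrence per pass around it.
        """
        self.inputs.append({"layer": layer, "recurrent": recurrent})

    def get_state(self, max_recurrence=1):
        self.calls += 1
        for inputlayer in self.inputs:
            # Skip a recurrent input once the recurrence budget is used up
            if inputlayer["recurrent"] and max_recurrence <= 0:
                continue
            # Subtracting a bool removes 1 only for inputs flagged as recurrent
            inputlayer["layer"].get_state(max_recurrence=max_recurrence - inputlayer["recurrent"])


def loop_passes(flag_both_edges, max_recurrence=2):
    """Build the loop A <-> B and count complete passes from A back to A."""
    a, b = Node(), Node()
    a.add_input(b, recurrent=True)
    b.add_input(a, recurrent=flag_both_edges)
    a.get_state(max_recurrence=max_recurrence)
    return a.calls - 1


print(loop_passes(flag_both_edges=False), loop_passes(flag_both_edges=True))
```

With max_recurrence=2, flagging one edge gives two passes around the loop; flagging both gives only one.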

colleenjg commented 7 months ago

Yeah, that makes sense!

if inputlayer['recurrent'] and max_depth_if_recursion > 0:
    continue

should be

if inputlayer['recurrent'] and max_depth_if_recursion <= 0:
    continue

Are you ok with the default being None, in which case this is ignored? That would kind of force users to choose a depth themselves or else hit a recursion error.

TomGeorge1234 commented 7 months ago

Sure (we can always change the default later if we change our mind), but how will it be ignored? None - 1 will throw a TypeError.

colleenjg commented 7 months ago

Yeah, my version has a few more lines, where None is checked explicitly.
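One way the extra None check could look, as a sketch only: the code that was actually pushed is not shown in this thread. It assumes the same kind of hypothetical scalar toy Layer. When max_recurrence is None the budget is neither checked nor decremented, so None - 1 is never evaluated, and an unbounded recurrent loop still ends in a RecursionError, as discussed above.

```python
class Layer:
    """Hypothetical toy layer: scalar rates stand in for firing-rate arrays."""

    def __init__(self, bias=0.0):
        self.bias = bias
        self.inputs = []

    def add_input(self, layer, w=1.0, recurrent=False):
        self.inputs.append({"layer": layer, "w": w, "recurrent": recurrent})

    def get_state(self, max_recurrence=None):
        V = self.bias
        for inputlayer in self.inputs:
            next_budget = max_recurrence  # None (no limit) passes through unchanged
            if inputlayer["recurrent"] and max_recurrence is not None:
                if max_recurrence <= 0:
                    continue  # recurrence budget used up: skip this input
                next_budget = max_recurrence - 1  # only flagged inputs consume budget
            V += inputlayer["w"] * inputlayer["layer"].get_state(max_recurrence=next_budget)
        return V


layer = Layer(bias=1.0)
layer.add_input(layer, w=0.5, recurrent=True)
print(layer.get_state(max_recurrence=1))  # 1.5: the loop is used once
try:
    layer.get_state()  # default None: no limit, so the self-loop never terminates
except RecursionError:
    print("RecursionError: pass max_recurrence to bound the loop")
```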

colleenjg commented 7 months ago

Ok, I just pushed a new version that seems to be working on my end!

Edit: Sorry for the numerous force pushes, I kept finding typos. This should work now.

TomGeorge1234 commented 7 months ago

v1.11.2 fixes this