RatInABox-Lab / RatInABox

A python package for modelling locomotion in complex environments and spatially/velocity selective cell activity.
MIT License

RatInABox2.0 - Opening the discussion #60

TomGeorge1234 closed this 8 months ago

TomGeorge1234 commented 1 year ago

I've begun to think about 2.0. The reason is that there are certainly a couple of choices I made early on in development which weren't optimal. Now could be a good time to fix these, as the community is growing but is still small enough that it won't be too disruptive. Fixing them will also make RiaB easier to maintain in the long run.

I'm opening this issue to get the community's thoughts on this. @SynapticSage @colleenjg @jquinnlee @mehulrastogi, you're some of the most active users I know fairly well, so I'm tagging you to get your input (if you have any), but anyone can chip in here. Here are my thoughts:

Essential and backwards-incompatible changes (do first):

Other essential changes

Things to consider

I'm not a software guy, so @SynapticSage @mehulrastogi, feel free to give high-level comments about the best way forward.

SynapticSage commented 1 year ago

args, not dicts 👍

A helpful case study in support of args ...

Most of you (I'm sure) have seen Grant Sanderson's beautiful 3blue1brown YouTube channel. Grant impressively homebrewed the manim package that creates his stunning math videos.

As here, manim started out using a CONFIG dict. On the plus side, the CONFIG dict cut down on lines in object init and encouraged people to spell out settings in one place. On the downside, dicts meant re-implementing a lot of Python features that kwargs and attribute setting already handle. Ultimately the community fork decided to kill CONFIG dicts in favour of args; decision convo here: https://github.com/ManimCommunity/manim/pull/763
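To make the trade-off concrete, here is a minimal sketch contrasting the two styles. `DictNeurons` and `ArgsNeurons` are hypothetical classes, not RatInABox's actual API; the dict version only mimics the general pattern of a class-level default-params dict merged with user overrides.

```python
from typing import Optional


class DictNeurons:
    """Hypothetical dict-configured class: defaults live in a class-level dict."""

    default_params = {"n": 10, "max_fr": 1.0, "name": "Neurons"}

    def __init__(self, params: Optional[dict] = None):
        # Merge user overrides into a copy of the defaults. A misspelt key is
        # silently accepted unless we re-implement validation ourselves.
        self.params = {**self.default_params, **(params or {})}
        for key, value in self.params.items():
            setattr(self, key, value)


class ArgsNeurons:
    """Hypothetical args-configured class: Python itself checks the arguments."""

    def __init__(self, n: int = 10, max_fr: float = 1.0, name: str = "Neurons"):
        self.n = n
        self.max_fr = max_fr
        self.name = name


# The typo "max_frr" slips through the dict version unnoticed:
dict_neurons = DictNeurons({"n": 5, "max_frr": 2.0})
print(dict_neurons.max_fr)  # still the default, 1.0

# ...whereas the args version rejects it immediately with a TypeError.
try:
    ArgsNeurons(n=5, max_frr=2.0)
except TypeError as err:
    print("caught:", err)
```

With explicit args, signature introspection, IDE completion and type hints also come for free, which is much of what the manim discussion was about.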

Grant Sanderson's fork is also trying to remove them: https://github.com/3b1b/manim/pull/1932

plotting 👍

💯 replotting = slow.

If RatInABox caches plot objects, I'd strongly recommend the scheme we chatted about: https://github.com/TomGeorge1234/RatInABox/issues/30#issuecomment-1486449726

The TaskEnvironment has a weak version of this feature: it doesn't replot everything and thus renders quickly. But in my view it's pretty hacky that the environment caches things about its agents and goals. In the long run, it will be more maintainable to have each class in charge of caching its own plot objects, rather than having to change the master supervisor class's plot method every time the child classes change.
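A minimal sketch of "each class caches its own plot objects", assuming matplotlib. `Agent` here is a hypothetical stand-in rather than the RatInABox class, and the `_artists` cache keyed by Axes is just one possible design, not the scheme from the linked comment. The artists are created on the first `plot()` call; later calls only push new data into them instead of redrawing from scratch.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


class Agent:
    """Hypothetical agent that owns and caches its own plot artists."""

    def __init__(self):
        self.position = np.array([0.5, 0.5])
        self.history = [self.position.copy()]
        self._artists = {}  # cache: Axes -> (trajectory line, position marker)

    def update(self, dt: float = 0.1):
        # Toy random-walk step, just to have something to plot.
        self.position = self.position + dt * np.random.randn(2)
        self.history.append(self.position.copy())

    def plot(self, ax):
        traj = np.array(self.history)
        if ax not in self._artists:
            # First call on this Axes: create the artists once and remember them.
            (line,) = ax.plot(traj[:, 0], traj[:, 1], lw=1)
            (dot,) = ax.plot([self.position[0]], [self.position[1]], "o")
            self._artists[ax] = (line, dot)
        else:
            # Later calls: only update the data of the existing artists.
            line, dot = self._artists[ax]
            line.set_data(traj[:, 0], traj[:, 1])
            dot.set_data([self.position[0]], [self.position[1]])
        return self._artists[ax]


fig, ax = plt.subplots()
agent = Agent()
for _ in range(5):
    agent.update()
    agent.plot(ax)  # reuses the same two artists on every call
```

Because the cache lives on the agent, an environment-level plot call can simply delegate to each child's `plot()` without knowing which artists the child uses.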

type hinting 👍

Especially for easy-to-type variables.

Tools like jedi and language-server-protocol offer better code completion for type-hinted variables.
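As an illustration of the kind of hints that help completion tools, here is a short sketch. `AgentParams` and `step_position` are hypothetical names used only for this example; it assumes NumPy 1.21 or later for `numpy.typing.NDArray`.

```python
from __future__ import annotations

from dataclasses import dataclass

import numpy as np
from numpy.typing import NDArray


@dataclass
class AgentParams:
    """Hypothetical typed parameter container; editors can autocomplete the fields."""

    dt: float = 0.01
    speed: float = 0.1
    dimensionality: str = "2D"


def step_position(
    position: NDArray[np.float64], velocity: NDArray[np.float64], dt: float
) -> NDArray[np.float64]:
    """Euler step: new position = position + velocity * dt."""
    return position + velocity * dt


params = AgentParams(dt=0.1)
new_pos = step_position(np.array([0.0, 0.0]), np.array([1.0, 2.0]), params.dt)
```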

unit testing 👍

global environment 👍

Possible suggestion: each RIB class could have a list of children (environment.children = [agent, ...]; agent.children = [neuron, ...]) to unify the way .update() and .plot()/.render() calls cascade down the hierarchy. This may be more uniform than each object having a different attribute name for its children.
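A minimal sketch of the uniform `children` idea. `Node` is a hypothetical base class (not existing RatInABox code); each object updates itself and then calls `update()` on everything in its `children` list, so a single call at the top reaches the whole tree. The same recursion would work for `plot()`/`render()`.

```python
from typing import List


class Node:
    """Hypothetical base class: every object keeps a uniform `children` list."""

    def __init__(self, name: str):
        self.name = name
        self.children: List["Node"] = []
        self.n_updates = 0

    def add_child(self, child: "Node") -> "Node":
        self.children.append(child)
        return child

    def update(self, dt: float) -> None:
        self.n_updates += 1          # update this object first...
        for child in self.children:  # ...then cascade down the hierarchy
            child.update(dt)


environment = Node("environment")
agent = environment.add_child(Node("agent"))
neurons = agent.add_child(Node("place_cells"))

environment.update(dt=0.1)  # one call updates environment, agent and neurons
```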

Jax 🤷‍♂️

No strong opinions. Leaning towards partial Jax if the penalty for binary ops/shuttling numpy arrays to a CPU jax.device is low.

colleenjg commented 1 year ago

Sounds like a great idea overall for the longevity of the package! I definitely agree with args instead of dicts, type hinting and unit testing. For the global environment, if the cascading update is implemented, I would suggest a kwarg like cascade=True so users can opt out when needed. No strong views on the other sections.
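One way the suggested opt-out could look, as a sketch only: `Node` is hypothetical and `cascade` is the kwarg name proposed above, not an existing RatInABox parameter. Cascading stays the default, and `cascade=False` updates just the object it is called on.

```python
from typing import List


class Node:
    """Hypothetical hierarchy node whose update() can optionally cascade."""

    def __init__(self):
        self.children: List["Node"] = []
        self.n_updates = 0

    def update(self, dt: float, cascade: bool = True) -> None:
        self.n_updates += 1
        if cascade:
            # Pass the flag on so the whole subtree follows the same choice.
            for child in self.children:
                child.update(dt, cascade=True)


env, agent = Node(), Node()
env.children.append(agent)

env.update(dt=0.1)                 # default: env and agent both update
env.update(dt=0.1, cascade=False)  # opt out: only env updates
```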

I would suggest an additional section, modularity. Many of the classes have very long methods that chain a lot of complex, separate computations together. When I've created new classes for my own use (e.g., new Agent classes), I've had to copy long sections of methods that I needed to override only partially (for example, the part that computes an agent's velocity). This can create a lot of code duplication (I think there may already be some in the plotting methods). So I strongly recommend adding modularization to the list, i.e., extracting meaningful subparts of class methods into their own functions, perhaps aggregated into agent_util.py, env_util.py and neuron_util.py, or something like that.
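A minimal sketch of the modular pattern being described, assuming a toy agent in a unit square. `Agent`, `compute_velocity`, `apply_boundaries` and `DampedAgent` are hypothetical names, not RatInABox code. `update()` just chains small steps, so a subclass can override one step (here, the velocity) without copying the rest; the same steps could equally be module-level functions in something like agent_util.py.

```python
import numpy as np


class Agent:
    """Hypothetical agent whose update() is split into small, overridable steps."""

    def __init__(self, dt: float = 0.1):
        self.dt = dt
        self.position = np.zeros(2)
        self.velocity = np.array([1.0, 0.0])

    # --- small, separately overridable steps ---
    def compute_velocity(self) -> np.ndarray:
        return self.velocity  # default: constant velocity

    def apply_boundaries(self, position: np.ndarray) -> np.ndarray:
        return np.clip(position, 0.0, 1.0)  # assumes a unit-square environment

    # --- the public method only chains the steps together ---
    def update(self) -> None:
        self.velocity = self.compute_velocity()
        self.position = self.apply_boundaries(self.position + self.velocity * self.dt)


class DampedAgent(Agent):
    """Overrides only the velocity step (halving it each update); update() is reused as-is."""

    def compute_velocity(self) -> np.ndarray:
        return 0.5 * super().compute_velocity()


base, damped = Agent(), DampedAgent()
base.update()    # position -> [0.1, 0.0]
damped.update()  # velocity -> [0.5, 0.0], position -> [0.05, 0.0]
```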

TomGeorge1234 commented 1 year ago

Great comments, thanks guys. @SynapticSage, 3B1B advice heeded! @colleenjg, you're right that this could be more modular; for example, Agent.update() is pretty enormous. Breaking these methods down makes sense, so I'll look to do that. Don't expect this any time soon, by the way, so keep posting any new ideas here.

jquinnlee commented 1 year ago

These all sound like great changes for RAIB 2.0, and I agree with all of the comments from @SynapticSage and @colleenjg :)

I'm a particularly big fan of the global environment updating, as it seems much more concise. My only concern is whether it would slow down updates for really long simulations, like the ones I have been running (e.g., 30 Hz × 31 sessions × 40 min/session, roughly 2.2 million update steps). Could some sort of argument to update() allow more selective updates, skipping objects that are going to stay static?
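A sketch of what such an argument might look like, assuming a hypothetical `skip` parameter that takes a set of object names (nothing like this exists in RatInABox as far as this thread says). Skipped objects don't recompute their own state but still pass the call on, so a static environment can sit at the top of the tree while its agent keeps moving.

```python
from typing import FrozenSet, List


class Node:
    """Hypothetical hierarchy node; update() can skip named objects."""

    def __init__(self, name: str):
        self.name = name
        self.children: List["Node"] = []
        self.n_updates = 0

    def update(self, dt: float, skip: FrozenSet[str] = frozenset()) -> None:
        if self.name not in skip:
            self.n_updates += 1  # the (potentially expensive) per-object work
        for child in self.children:
            child.update(dt, skip=skip)  # children are still visited


env = Node("environment")
agent = Node("agent")
env.children.append(agent)

for _ in range(3):
    env.update(dt=1 / 30, skip=frozenset({"environment"}))  # static env, moving agent
```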

As for Jax compatibility, I'd be very much for it if it can actually speed up the heavier computations and long simulations, but as you point out, it might not save compute time if large arrays are converted often. I think it would be worth testing on a couple of large simulations before ruling it out.

TomGeorge1234 commented 11 months ago

Thanks for the feedback, closing for now.

colleenjg commented 10 months ago

One thing that just occurred to me, which could be considered:

Only passing ax to the plotting functions, not fig.

In typical use cases, to my knowledge, passing both should be redundant, as you can access the figure with ax.figure (or ax.ravel()[0].figure in cases where ax is an array).
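A sketch of an ax-only plotting signature, assuming matplotlib. `plot_trajectory` is a hypothetical function, not RatInABox's plotting API. It recovers the figure from `ax.figure` and handles both a single Axes and an array of Axes (as returned by `plt.subplots(1, 2)`).

```python
import matplotlib

matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes


def plot_trajectory(trajectory: np.ndarray, ax=None):
    """Hypothetical plotting function that takes only `ax` and derives `fig` from it."""
    if ax is None:
        _, ax = plt.subplots()
    # `ax` may be a single Axes or an array of Axes; draw on the first one.
    first_ax = ax if isinstance(ax, Axes) else np.ravel(ax)[0]
    fig = first_ax.figure  # no need for the caller to pass `fig` separately
    first_ax.plot(trajectory[:, 0], trajectory[:, 1])
    return fig, ax


traj = np.array([[0.0, 0.0], [1.0, 1.0]])
fig1, ax1 = plt.subplots()
fig_out, _ = plot_trajectory(traj, ax=ax1)  # fig_out is fig1

fig2, axes2 = plt.subplots(1, 2)
fig_out2, _ = plot_trajectory(traj, ax=axes2)  # fig_out2 is fig2
```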

TomGeorge1234 commented 10 months ago

Agreed and added to the list. It's essentially redundant and only adds bloat.

musicinmybrain commented 10 months ago

If you add jax support, would it be possible to do it through an optional extra for opportunistic speed-ups rather than as a mandatory dependency?
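The usual pattern for an optional speed-up is a guarded import with a NumPy fallback, so that the package never requires jax at import time. A minimal sketch follows; `get_backend` and `vector_norm` are hypothetical helpers, not RatInABox code. On the packaging side, this would typically be paired with an extra declared under `[project.optional-dependencies]` in `pyproject.toml`; the extra's name would be the maintainers' choice.

```python
import numpy as np

try:
    import jax.numpy as jnp  # optional accelerator; used only if installed
    HAS_JAX = True
except ImportError:
    jnp = None
    HAS_JAX = False


def get_backend(use_jax: bool = False):
    """Return jax.numpy if requested *and* available, otherwise fall back to NumPy."""
    if use_jax and HAS_JAX:
        return jnp
    return np


def vector_norm(x, use_jax: bool = False):
    """Euclidean norm along the last axis, on whichever backend is active."""
    xp = get_backend(use_jax)
    x = xp.asarray(x)
    return xp.sqrt(xp.sum(x**2, axis=-1))


print(float(vector_norm(np.array([3.0, 4.0]))))  # NumPy path by default
```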

As the primary maintainer of the Fedora Linux package for this project, I’m not sure if packaging https://github.com/google/jax would be feasible for us or not. While it does look like jax can be built without support for the proprietary CUDA SDK, it’s still a pretty gnarly stack when taken together with https://github.com/openxla/xla, and I’m not sure whether or not an attempt to package it would end up hitting a hard requirement on something nonfree.

TomGeorge1234 commented 10 months ago

@musicinmybrain thanks for your feedback. That's OK, I doubt we'd go full jax; in fact, I'm leaning towards no jax at all. After some preliminary testing, it seems that getting significant speed-ups would be difficult, as most of the heavy computations are already vectorised.
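For readers wondering what "already vectorised" means in practice, here is a sketch using hypothetical Gaussian tuning curves (not RatInABox's own code). The vectorised version replaces the Python loops with a single broadcast NumPy expression over all (position, cell) pairs. That per-element Python overhead is a large part of what a jax rewrite would typically remove, so once it is gone there is less left to gain.

```python
import numpy as np


def rates_loop(positions, centres, width):
    """Reference version: explicit Python loops over positions and cells."""
    out = np.zeros((len(positions), len(centres)))
    for i, p in enumerate(positions):
        for j, c in enumerate(centres):
            d2 = np.sum((p - c) ** 2)
            out[i, j] = np.exp(-d2 / (2 * width**2))
    return out


def rates_vectorised(positions, centres, width):
    """Vectorised version: broadcasting handles every (position, cell) pair at once."""
    diff = positions[:, None, :] - centres[None, :, :]  # shape (n_pos, n_cells, 2)
    d2 = np.sum(diff**2, axis=-1)                        # shape (n_pos, n_cells)
    return np.exp(-d2 / (2 * width**2))


rng = np.random.default_rng(0)
positions = rng.uniform(0, 1, size=(200, 2))
centres = rng.uniform(0, 1, size=(50, 2))
loop_rates = rates_loop(positions, centres, 0.1)
vec_rates = rates_vectorised(positions, centres, 0.1)
```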