A helpful case study in support of args ...
Most of you (I'm sure) have seen Grant Sanderson's beautiful 3blue1brown YouTube channel. Grant impressively homebrewed the `manim` package that creates his stunning math videos.
Similar to here, `manim` started out using a `CONFIG` dict. On the positive side, the `CONFIG` dict cut down on lines in object `__init__` and encouraged people to spell out settings in one place. On the downside, the dicts required nearly re-coding a lot of Python features already handled by kwargs and attribute setting. Ultimately the community fork decided to kill `CONFIG` dicts in favour of args; the decision conversation is here: https://github.com/ManimCommunity/manim/pull/763
Grant Sanderson's own repo is also trying to remove them: https://github.com/3b1b/manim/pull/1932
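To make the trade-off concrete, here is a minimal sketch of what the move looks like (class names hypothetical, not manim's or RatInABox's actual API):

```python
# Hypothetical "before": settings live in a class-level CONFIG dict that is
# merged onto the instance by custom machinery at init time.
class PlaceCells:
    CONFIG = {"n": 10, "sigma": 0.2}

    def __init__(self, **kwargs):
        for key, value in {**self.CONFIG, **kwargs}.items():
            setattr(self, key, value)  # re-implements what kwargs already do


# "After": plain keyword arguments with defaults. The signature documents
# itself, and editors and type checkers understand it natively.
class PlaceCellsV2:
    def __init__(self, n: int = 10, sigma: float = 0.2):
        self.n = n
        self.sigma = sigma
```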
Replotting = slow.
... if RatInABox caches plot objects, I super recommend the scheme we chatted about: https://github.com/TomGeorge1234/RatInABox/issues/30#issuecomment-1486449726
The TaskEnvironment has a weak version of this feature: it doesn't replot everything and thus renders quickly. But it's pretty hacky, in my view, that the environment caches things about its agents and goals. In the long run it will be more maintainable to have each class in charge of caching its own plot objects, rather than having to change the master supervisor class's plotting every time the child classes change.
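For what it's worth, a minimal sketch of "each class caches its own plot objects" could look like this (names and structure hypothetical):

```python
# Hypothetical sketch: each class owns and reuses its matplotlib Artists,
# updating them in place instead of replotting from scratch every frame.
class Agent:
    def __init__(self, pos):
        self.pos = pos        # (x, y)
        self._scatter = None  # cached matplotlib Artist

    def plot(self, ax):
        if self._scatter is None or self._scatter.axes is not ax:
            self._scatter = ax.scatter(*self.pos)  # first draw: create + cache
        else:
            self._scatter.set_offsets([self.pos])  # later draws: cheap in-place update
        return self._scatter
```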
Especially worth doing for variables that are easy to annotate. Tools like `jedi` and the language-server-protocol offer better code completion for type-hinted variables.
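For example (a toy sketch, not RatInABox's real signatures):

```python
import numpy as np

class Agent:
    # With hints like these, jedi / LSP-based editors can complete
    # e.g. `agent.velocity.` with ndarray methods and catch type errors.
    pos: np.ndarray
    velocity: np.ndarray

    def update(self, dt: float = 0.01) -> None:
        ...
```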
Possible suggestion: each RiaB class could have a list of children (`environment.children = [agent, ...]`; `agent.children = [neuron, ...]`) to unify the way `.update()` and `.plot()`/`.render()` calls cascade down the hierarchy. It would be more uniform than each object having a different attribute name for its children.
No strong opinions on Jax. Leaning towards partial Jax if the penalty for binary ops / shuttling numpy arrays to a CPU `jax.device` is low.
Sounds like a great idea overall for the longevity of the package! I definitely agree with args instead of dicts, type hinting and unit testing. For the global environment update, if the cascading update is implemented, I would suggest having a kwarg like `cascade=True` to allow users to opt out when needed (a minimal sketch of this is below). No strong views on the other sections.
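Combining both suggestions, something like this minimal sketch (all names hypothetical):

```python
# Every RiaB object keeps a `children` list; update() cascades down the
# hierarchy unless the caller opts out with cascade=False.
class RiaBObject:
    def __init__(self):
        self.children = []  # e.g. environment.children = [agent, ...]

    def update(self, cascade=True):
        self._update_self()
        if cascade:
            for child in self.children:
                child.update(cascade=True)

    def _update_self(self):
        pass  # each subclass implements its own per-step state change
```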
I would suggest an additional section:
modularity: Many of the classes have very long methods that chain a lot of complex, separate computations together. When I've created new classes for my own use, e.g. new `Agent` classes, I've had to copy long sections of certain methods that I only needed to partially overwrite (for example, for computing an agent's velocity). This can create a lot of code duplication (I think there may already be some in the plotting methods). So I strongly recommend adding modularization to the list, i.e. extracting meaningful subparts of class methods into their own functions, perhaps aggregated into `agent_util.py`, `env_util.py` and `neuron_util.py`, or something like that (see the sketch below).
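As a hypothetical illustration of the kind of extraction meant here (the function, module name and update rule are all made up for the example):

```python
# agent_util.py (hypothetical): one meaningful subpart of Agent.update(),
# pulled out so subclasses can override just this step instead of copying
# the whole method.
import numpy as np

def compute_velocity_step(velocity, drift, dt, coherence_time):
    """Illustrative smooth-random-motion velocity update (not RiaB's exact rule)."""
    noise = np.random.normal(0.0, 1.0, size=np.shape(velocity))
    return velocity + (drift - velocity) * dt / coherence_time + noise * np.sqrt(dt)
```

A subclass wanting a different motion model could then override only this function rather than duplicating all of `update()`.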
Great comments, thanks guys. @SynapticSage 3B1B advice heeded! @colleenjg you're right, this could be more modular; for example, `Agent.update()` is pretty enormous. Breaking these down would make sense, so I'll look to do that. Don't expect this anytime soon, btw, so keep posting any new ideas here.
These all sound like great changes for RiaB 2.0, and I agree with all of the comments from @SynapticSage and @colleenjg :)
I'm a particularly big fan of the global environment updating, as this seems much more concise. My only concern is whether this would slow down updates for really long simulations (like the ones I have been running, e.g. 30 Hz × 31 sessions × 40 min/session). Might it be ideal to perform selective updates, skipping those that are going to be static, using some sort of argument in `update()`? One possible shape for this is sketched below.
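For instance, a rough sketch of how static components could be skipped (argument and flag names hypothetical):

```python
class RiaBObject:
    def __init__(self):
        self.children = []
        self.is_static = False  # e.g. True for a fixed Environment

    def update(self, skip_static=True):
        for child in self.children:
            if skip_static and child.is_static:
                continue  # this child never changes, so skip its per-step update
            child.update(skip_static=skip_static)
```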
As for Jax compatibility, I would very much be in favour if it can actually speed up the heavier computations and long simulations, but as you point out it might not save compute time if large arrays are being converted often. I believe it would be worth some case testing in a couple of large simulations before ruling this out.
Thanks for the feedback, closing for now.
One thing that just occurred to me, which could be considered: only passing `ax` to the plotting functions, not `fig`. In typical use cases, to my knowledge, passing both is redundant, as you can access the figure with `ax.figure` (or `ax.ravel()[0].figure` in cases where `ax` is an array).
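i.e., something like this sketch (not the current signature):

```python
import matplotlib.pyplot as plt

def plot_environment(ax=None):
    if ax is None:
        _, ax = plt.subplots()
    fig = ax.figure  # the figure is always reachable from the axes
    # ... draw onto ax ...
    return fig, ax
```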
Agreed and added to the list. It's essentially redundant and only adds bloat.
If you add `jax` support, would it be possible to do it through an optional extra for opportunistic speed-ups, rather than as a mandatory dependency?
As the primary maintainer of the Fedora Linux package for this project, I'm not sure whether packaging https://github.com/google/jax would be feasible for us. While it does look like `jax` can be built without support for the proprietary CUDA SDK, it's still a pretty gnarly stack when taken together with https://github.com/openxla/xla, and I'm not sure whether an attempt to package it would end up hitting a hard requirement on something nonfree.
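One common pattern for this kind of optional extra is a guarded import with a numpy fallback; a sketch follows (the `ratinabox[jax]` extra name is made up):

```python
# Use jax.numpy when available, plain numpy otherwise; callers are agnostic.
try:
    import jax.numpy as xp  # e.g. installed via `pip install ratinabox[jax]`
    HAS_JAX = True
except ImportError:
    import numpy as xp
    HAS_JAX = False

def gaussian(x, sigma):
    # Identical code path on both backends; speed-ups are opportunistic.
    return xp.exp(-(x ** 2) / (2 * sigma ** 2))
```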
@musicinmybrain thanks for your feedback. That's OK, I doubt we'd go full `jax`; in fact, I'm now leaning towards no `jax` at all. After some preliminary testing, it seems that getting significant speed-ups would be difficult, as most of the heavy computations are already vectorised.
I've begun to think about 2.0. The reason is that there are certainly a couple of choices I made early in development which weren't optimal. Now could be a good time to fix these, as the community is growing but is still small enough that it won't be super disruptive. Fixing them will also make it easier to maintain RiaB in the long run.
I'm opening this issue to get community thoughts. @SynapticSage @colleenjg @jquinnlee @mehulrastogi you're some of the most active users I know fairly well, so I'm tagging you to get your input (if you have any), but anyone can chip in here. Here are my thoughts:
Essential and backwards incompatible changes (do first):

- Stop keeping all `Neurons` classes in one `.py` file.
- `update()`: given that `Environments` now know about their `Agents` and `Agents` know about their `Neurons`, we could have just one update function in `Env` which cascades through everything else. Cleaner?
- `dev` --> `main`.
- `Environment` stores the global clock. This just makes sense imo.
- Replace the `drift_velocity` kwarg. Maybe instead `Agent`s can have a `policy()` method which returns a drift; this would default to the random motion policy, unifying that too. Just something to consider.

Other essential changes:

- Break down `update()`, perhaps moving subparts into new agent/neuron/env-specific utils scripts.
- `Env.history` dictionary. Then, when plotting / animating the environment, we can pass in a time argument and the correct state can be retrieved and plotted. The state of the environment only appends to history whenever it changes (e.g. a setter is called). See the sketch after this list.
- `plot_environment()`: it can be passed a `fig`, an `ax` and a new object which is a list/dict of plot objects, `R`, which are all `matplotlib.Artist`s already existing on the figure. The environment can store an equivalent list of plot objects, and whenever this changes (e.g. a wall is added or an object is moved) the change is logged. Plotting can then (i) get the list of plot objects corresponding to the correct time and (ii) compare it to the passed list; if they aren't equal then replot the env, otherwise don't bother. Something like that.
- `Environment`s have an `Env.history` dictionary storing the full "state" of the environment (all object locations, walls, boundaries, etc.). Then `Env.plot_environment()` takes a time argument, finds the state of the environment at that time and plots that.
- Pass only `ax`, not `fig`, to figure plotting functions. This may throw up some things but likely minor.
- Split `utils.py` into separate ones for the `Agent` package, `Neurons` package and `Env` package, and maybe also a `misc`.
- Move the repo: `RatInABox/RatInABox`, not `TomGeorge1234/RatInABox`.
- `RatInABox/RatInABox_RL` package containing all the RL stuff (`Actor`, `Critic`, `ValueNeuron`, `TDError`, `TaskEnv` etc.).
- `IntermediateNeurons` subclass for neurons which aren't "fundamental" but take other neurons as inputs. Current examples are `FeedForwardLayer` and `NeuralNetworkNeurons`.
- `DynamicNeurons` subclass for neurons which aren't static, i.e. you can't call `Neurons.plot_rate_map()` because they actually depend on past history. Examples include `TDErrorNeurons` (to be made) or anything with recurrency.
- `SmoothRandomFeatureNeurons`: just some spatially tuned but random neurons. Users provide only a length scale. Would be useful for a lot of feature-learning studies. Probably something like a Gaussian process underlying these neurons.

Things to consider:

- `Neurons` should follow the `torch.nn.Module` API. This would make more efficient the evaluation of complex feedforward graphs, which currently happens in a backwards manner. It might require renaming the `.get_state()` method to `.forward()`. Need to think more about this.
- `np` --> `jnp` everywhere.

I'm not a software guy, so @SynapticSage @mehulrastogi feel free to give high-level comments about the best way to go forward.
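(A sketch of the `Env.history` idea above; the implementation details here are hypothetical.)

```python
import bisect

class Environment:
    def __init__(self):
        self.t = 0.0
        self.history = {"t": [], "state": []}

    def _log_change(self, state):
        # Called from setters (add_wall(), move_object(), ...): history only
        # grows when the environment actually changes.
        self.history["t"].append(self.t)
        self.history["state"].append(state)

    def state_at(self, t):
        # Most recent logged state at or before time t.
        if not self.history["t"]:
            return None
        i = bisect.bisect_right(self.history["t"], t) - 1
        return self.history["state"][max(i, 0)]
```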