google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Autoreset behavior #174

Open DavidSlayback opened 2 years ago

DavidSlayback commented 2 years ago

I've been digging into Brax as a potential alternative to some modified dm_control environments I've been using, and I'm really loving the speedup! That said, I feel like I've run into a major issue using the environments in RL and was looking for some guidance.

Basically, my environments are all partially-observable domains built off of "ant". A lot of the conditions are randomized per-episode (e.g., ant/target starting positions). I've been using the "create_gym_env" feature to work with my PyTorch agents, but I noticed a big potential issue.

At first glance, the AutoResetWrapper seemed to do what standard gym VectorEnvs do, but in reality it's not really "resetting" the environments (with a new seed); it just sets them back to a cached first state. So the randomization of start conditions I do only applies across the whole batch of environments, and for the entire training process each individual environment keeps restarting from that same cached state.

Is there a way to actually reset individual environments within a batch?

cdfreeman-google commented 2 years ago

We haven't actually encountered environments that needed more than the initial randomness you can cache into, say, 2048 environments, but your situation seems like a good test case. (see discussion here: https://github.com/google/brax/issues/167)

One sensible way of doing this is wrapping the AutoResetWrapper with another wrapper that refreshes the first_qp state every X episodes. But you might first want to verify that you really need this extra randomness by just manually refreshing that state with a call to reset.

i.e., around line 105 here, you could instead do:

# (rng would need to be threaded into step here, e.g. carried in state.info)
maybe_reset = self.reset(rng)  # reset returns a fresh env.State
qp = jp.tree_map(where_done, maybe_reset.qp, state.qp)

This will be slower (probably by a lot, since every step will call reset), but it should reveal whether the extra randomness is what actually matters.
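
For the periodic-refresh version, a minimal sketch (hypothetical helper name, assuming the cached state lives in state.info['first_qp'] / state.info['first_obs'] and that this is called from the training loop outside any jitted rollout):

def refresh_first_state(wrapped_env, state, rng):
    """Re-draw the initial conditions that AutoResetWrapper snaps finished
    episodes back to; call this every N episodes from the training loop."""
    fresh = wrapped_env.reset(rng)
    # Overwrite the cached first state in place; subsequent auto-resets
    # will use the newly randomized qp/obs.
    state.info.update(first_qp=fresh.qp, first_obs=fresh.obs)
    return state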

DavidSlayback commented 2 years ago

Yeah, 2048 environments have a lot of randomness built in, but if I'm trying to solve a reasonably-sized generalized task (like a procedurally generated maze or foraging task) rather than the specific seed-determined subset of it, I feel like individual resets make more sense. Definitely understand the speed concerns, though.

Your solution seems like a reasonable approach. Another thing I was considering was building the "reset" call into the step function of the environment itself. I lose a bit of flexibility, but then I don't have to worry about external resets. With non-brax environments, I've often implemented a batched version with a "_reset_some(mask)" function.

I appreciate the advice! One last question I had was about best practices for Jax RNG. I'll use my "ant tag" environment as an example.

def reset(self, rng: jp.ndarray) -> env.State:
    rng, rng1, rng2 = jp.random_split(rng, 3)
    qpos = self.sys.default_angle() + jp.random_uniform(
        rng1, (self.sys.num_joint_dof,), -.1, .1)
    qvel = jp.random_uniform(rng2, (self.sys.num_joint_dof,), -.1, .1)
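    # NOTE: the next line reuses rng1, which was already consumed for qpos above (see question below)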
    ant_pos = jp.random_uniform(rng1, (2,), -self.cage_xy, self.cage_xy)
    qp = self.sys.default_qp(joint_angle=qpos, joint_velocity=qvel)
    pos = index_add(qp.pos, self.ant_mg, ant_pos[...,None])
    rng, tgt = self._random_target(rng, ant_pos)
    pos = jp.index_update(pos, self.target_idx, tgt)
    qp = qp.replace(pos=pos)
    info = self.sys.info(qp)
    obs = self._get_obs(qp, info)
    reward, done, zero = jp.zeros(3)
    metrics = {
        'hits': zero,
    }
    info = {'rng': rng}
    return env.State(qp, obs, reward, done, metrics, info)

Should I re-use rng1? I've also noticed anecdotally that the two numbers drawn by random_uniform for the ant position are typically quite close; they don't look like two independently sampled numbers.

Can I do this more efficiently?

cdfreeman-google commented 2 years ago

I probably wouldn't reuse rng1. It really isn't that expensive to just generate another rng seed:

rng, rng1, rng2, rng3 = jp.random_split(rng, 4)

and then use rng3 for your ant_pos randomness. The random number generator is deterministic and stateless, so two calls to jp.random_uniform(rng1, blah blah) will give you the same "random" results.
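
To make that concrete, a tiny example (plain jax.random here rather than jumpy):

import jax

key = jax.random.PRNGKey(0)
a = jax.random.uniform(key, (2,))
b = jax.random.uniform(key, (2,))     # same key -> identical "random" draws
_, subkey = jax.random.split(key)
c = jax.random.uniform(subkey, (2,))  # new key -> independent draws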

I'm not sure I understand your last question. Is it that you want to generate a random location that's guaranteed to be some distance away from the ant, but in some range from there? I'd probably just parameterize that location by distance and theta, i.e.:

# (rng3 and rng4 come from an extra jp.random_split, as above)
random_dist = jp.random_uniform(rng3, (), min_dist, max_dist)
random_angle = jp.random_uniform(rng4, (), 0., 2. * jp.pi)
ant_pos = jp.array([random_dist * jp.cos(random_angle), random_dist * jp.sin(random_angle)])

Is that what you had in mind?

DavidSlayback commented 2 years ago

Got it, thanks again!

For the first question, there's no weirdness with using one key to draw a vector of 2 random numbers from the same range (each is -4.5 to 4.5)? I could just be seeing weird things that would iron out if I were looking at more seeds.

For the second question, I just want to make sure that the target position is (a) within the arena boundaries and (b) at least "min_distance" from the ant. So the max distance would vary based on the ant's start position and the angle, and some angles could be entirely impossible. That's where I'm sort of stuck on resampling; ideally, I'd sample a random point once from the space of possible points after ant placement.

The actual random target sampling code is below (I added some more jumpy functions on my end, so the "while_loop" is a JIT-compatible Jax one):

def _random_target(self, rng: jp.ndarray, ant_xy: jp.ndarray) -> Tuple[jp.ndarray, jp.ndarray]:
    """Returns a target location at least min_spawn_location away from ant"""
    rng, rng1 = jp.random_split(rng, 2)
    xy = jp.random_uniform(rng1, (2,), -self.cage_xy, self.cage_xy)
    minus_ant = lambda xy: xy - ant_xy
    def resample(rngxy: Tuple[jp.ndarray, jp.ndarray]) -> Tuple[jp.ndarray, jp.ndarray]:
        rng, xy = rngxy
        _, rng1 = jp.random_split(rng, 2)
        xy = jp.random_uniform(rng1, (2,), -self.cage_xy, self.cage_xy)
        return rng1, xy

    _, xy = while_loop(lambda rngxy: jp.norm(minus_ant(rngxy[1])) <= self.min_spawn_distance,
                          resample,
                          (rng1, xy))
    target_z = 0.5
    target = jp.array([*xy, target_z]).transpose()
    return rng, target
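
For reference, the while_loop there is essentially just a thin pass-through to jax.lax.while_loop, roughly:

import jax

def while_loop(cond_fun, body_fun, init_val):
    # JIT-compatible loop; cond_fun/body_fun must keep the carry's shapes
    # and dtypes fixed across iterations.
    return jax.lax.while_loop(cond_fun, body_fun, init_val)
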
cdfreeman-google commented 2 years ago

Ahhhh, I see--totally misunderstood what you were doing.

First question: To be probably overly specific, there's no problem with this:

ant_pos = jp.random_uniform(rng1, (2,), -self.cage_xy, self.cage_xy)

Whereas

ant_pos_x = jp.random_uniform(rng1, (1,), -self.cage_xy, self.cage_xy)
ant_pos_y = jp.random_uniform(rng1, (1,), -self.cage_xy, self.cage_xy)

this will always have ant_pos_x=ant_pos_y.

For question 2: Yeah okay this is tricky. The resampling trick is probably the "simplest" in terms of "lines of code per unit how-hard-do-I-have-to-think-about-this".

Another option: sample the point (r, theta) randomly. One of (r, theta) and (r, theta+180) has a valid theta, and might just need to have its r projected to the boundary. This would slightly oversample boundary points (relative to a uniform sampling of "valid" points), but it's at least deterministic.

Another option: This does have an analytic solution, but it has a bunch of irritating edge cases (like which edges the minimum distance from the ant is in contact with). I'd probably just do the previous option, unless for some reason uniform sampling is super duper important.

Another option: Construct a grid of candidate points, compute a mask of valid points, and sample one of those. It has lower resolution than the other options, but it samples uniformly and converges to the right thing in the limit of many grid points.
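
A minimal sketch of that last grid option (plain jax instead of jumpy; function name hypothetical, and cage_xy assumed to be the length-2 array of half-extents used above):

import jax
import jax.numpy as jnp

def random_target_on_grid(rng, ant_xy, cage_xy, min_dist, n=64):
    # n x n grid of candidate (x, y) points covering the cage.
    xs = jnp.linspace(-cage_xy[0], cage_xy[0], n)
    ys = jnp.linspace(-cage_xy[1], cage_xy[1], n)
    gx, gy = jnp.meshgrid(xs, ys)
    candidates = jnp.stack([gx.ravel(), gy.ravel()], axis=-1)  # (n*n, 2)
    # Keep only points at least min_dist from the ant, then sample uniformly
    # among them (invalid points get probability zero).
    valid = jnp.linalg.norm(candidates - ant_xy, axis=-1) >= min_dist
    probs = valid / valid.sum()
    idx = jax.random.choice(rng, candidates.shape[0], p=probs)
    return candidates[idx]

All the shapes are static, so it should jit cleanly.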

DavidSlayback commented 2 years ago

Thank you again! I realize you've got plenty of other stuff to work on even just with Brax, appreciate the thorough answers. I'll probably go with your second option.

As to the original topic of the post, would it be useful for me to profile the different reset options so you get an idea of whether individual resets are prohibitive in the future? Also, any use for the extra jumpy functions? I saw that you're trying to add them to JKTerry's Farama repo, so I wasn't sure where the best place to contribute would be.

cdfreeman-google commented 2 years ago

Haha of course! Happy to help.

Yes, we'd love to have some numbers on reset efficiency!

Let me check with Erik about the fate of jumpy--I'll get back to you!

cdfreeman-google commented 2 years ago

Update: Feel free to open PRs adding jumpy functions here if you're using them in Brax!

DavidSlayback commented 2 years ago

Update on the reset numbers:

https://colab.research.google.com/gist/DavidSlayback/bf5038ec024bb6e47568af2e2ba99c16/autoreset.ipynb#scrollTo=gazgx0KXWJfw

So I implemented a couple of basic strategies using your built-in "fetch" environment just to keep the notebook simple:

1. Original AutoReset (same "first state" for all environments that are done)
2. "Naive" AutoReset (calls reset every timestep, replaces where done)
3. "On Terminal" AutoReset (calls reset only on timesteps where at least one environment is done)
4. "Cached" AutoReset (refreshes the "first state" every N steps, behaves like the original otherwise)

I feel like I may not be timing these properly, though? I'm not seeing much difference in times beyond the time it takes to JIT.
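
For reference, the usual pattern for timing jitted JAX code is to warm up the compile first and block on the result before reading the clock (env, state, and action below stand in for whichever wrapped variant is being measured):

import time
import jax

step_fn = jax.jit(env.step)
state = step_fn(state, action)    # first call compiles; don't include it in the timing
state.obs.block_until_ready()

t0 = time.perf_counter()
for _ in range(1000):
    state = step_fn(state, action)
state.obs.block_until_ready()     # dispatch is async; wait before stopping the clock
print((time.perf_counter() - t0) / 1000, 'seconds per step')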

btnorman commented 2 years ago

I've started using Brax, and I am really enjoying it!

I just wanted to note that I've been using Brax on tasks that involve significant domain randomization and/or curriculum development, and the auto-reset behavior tripped me up too.

erwincoumans commented 2 years ago

So when using the AutoResetWrapper, there is no randomization except at the start? Doesn't the training overfit to those initial random seeds?

ZaberKo commented 6 months ago

Are there any new ideas? Following the design of the gym wrappers (e.g., AtariPreprocessing.noop_max), maybe we could manually apply several steps of random actions after the fake reset in AutoResetWrapper?

erikfrey commented 6 months ago

Brax PPO now supports a param num_resets_per_eval if you want to randomize your init states multiple times during training:

https://github.com/google/brax/blob/main/brax/training/agents/ppo/train.py#L90

We generally don't use this as there's no overfitting when num_envs is large - but perhaps you'll find that helpful.
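
A quick usage sketch (illustrative environment and hyperparameter values only):

from brax import envs
from brax.training.agents.ppo import train as ppo

env = envs.get_environment('ant')
make_policy, params, metrics = ppo.train(
    environment=env,
    num_timesteps=10_000_000,
    episode_length=1000,
    num_envs=2048,
    num_evals=10,
    num_resets_per_eval=1,  # re-draw the init states once per eval window
)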

ZaberKo commented 6 months ago

As pointed out by @erwincoumans, it's essential to have some initial randomness in the envs during the training stage. It seems that num_resets_per_eval only controls the randomness at the evaluation stage.