instadeepai / flashbax

⚡ Flashbax: Accelerated Replay Buffers in JAX
https://instadeepai.github.io/flashbax/
Apache License 2.0

Speed up vault tests #42

Open: mickvangelderen opened this issue 1 week ago

mickvangelderen commented 1 week ago

The tests in vault_test.py take a considerable amount of time to run. Perhaps the test parameters could be tuned to make them run faster without losing confidence that the code works.

sash-a commented 1 week ago

Totally agreed @mickvangelderen. Would you be able to make a PR for this? 😄

I think the place to start is these fixtures:

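import jax.numpy as jnp
import pytest

# FbxTransition and CustomObservation are defined or imported elsewhere in vault_test.py (not shown).

@pytest.fixture()
def max_length() -> int:
    return 256

@pytest.fixture()
def fake_transition() -> FbxTransition:
    # Current shapes: 6 + 120 + 56 = 182 float32 elements per transition.
    return FbxTransition(
        obs=CustomObservation(
            x=jnp.ones(shape=(1, 2, 3), dtype=jnp.float32),
            y=jnp.ones(shape=(4, 5, 6), dtype=jnp.float32),
        ),
        act=jnp.ones(shape=(7, 8), dtype=jnp.float32),
    )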

We could drop max_length to 8 and make the shapes much smaller; maybe all of them could be (1, 1, 2)? I don't think it will save a huge amount of time, but maybe a minute?
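For illustration, here is a rough sketch of what the reduced fixtures might look like, assuming every leaf uses the suggested (1, 1, 2) shape and max_length becomes 8. The CustomObservation and FbxTransition classes below are hypothetical NamedTuple stand-ins for the real ones in vault_test.py, and make_fake_transition / num_elements are hypothetical helpers added only to make the size reduction easy to check.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
import pytest


# Hypothetical stand-ins: the real CustomObservation / FbxTransition live in
# vault_test.py and may be defined differently.
class CustomObservation(NamedTuple):
    x: jax.Array
    y: jax.Array


class FbxTransition(NamedTuple):
    obs: CustomObservation
    act: jax.Array


# Reduced sizes suggested in the thread (previously 256 and mixed shapes).
SMALL_MAX_LENGTH = 8
SMALL_SHAPE = (1, 1, 2)


def make_fake_transition(shape: tuple = SMALL_SHAPE) -> FbxTransition:
    """Build a transition whose leaves all share one small shape (hypothetical helper)."""
    return FbxTransition(
        obs=CustomObservation(
            x=jnp.ones(shape=shape, dtype=jnp.float32),
            y=jnp.ones(shape=shape, dtype=jnp.float32),
        ),
        act=jnp.ones(shape=shape, dtype=jnp.float32),
    )


def num_elements(tree) -> int:
    """Count scalar elements across all array leaves of a pytree."""
    return sum(leaf.size for leaf in jax.tree_util.tree_leaves(tree))


@pytest.fixture()
def max_length() -> int:
    return SMALL_MAX_LENGTH


@pytest.fixture()
def fake_transition() -> FbxTransition:
    return make_fake_transition()
```

With these numbers, the original transition holds 6 + 120 + 56 = 182 floats, so a full 256-step buffer is 46,592 elements, while the reduced transition holds 6 floats, or 48 elements at max_length 8. Whether that turns into a noticeable wall-clock saving depends on where the test time actually goes (compilation or file writes may dominate), so timing the suite before and after is the safest check.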

mickvangelderen commented 1 week ago

Thanks for the suggestion! It's similar to what I implemented locally, and it works well. I'm happy to open a PR for it, but I currently have three other PRs that I'd like to discuss first, if possible:

Thank you!