instadeepai / flashbax

⚡ Flashbax: Accelerated Replay Buffers in JAX
https://instadeepai.github.io/flashbax/
Apache License 2.0

Replace functools.partial with jax.tree_util.Partial #39

Open eadadi opened 1 month ago

eadadi commented 1 month ago

Updated various buffer files to replace functools.partial with jax.tree_util.Partial for consistency and improved functionality.

The motivation is to be able to use jax transformations over buffers.

For example, before this patch, this wasn't working:

    import jax
    import flashbax as fbx

    buffer = fbx.make_trajectory_buffer(**cfg)
    buffer = jax.device_put(buffer, jax.devices("cpu")[0])
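The difference the patch relies on can be sketched in isolation (a minimal example, not Flashbax code): unlike `functools.partial`, `jax.tree_util.Partial` is registered as a pytree node, so tree-based transformations such as `jax.device_put` can traverse it.

```python
import jax
from jax.tree_util import Partial

def add(x, y):
    return x + y

# Bind an argument with the pytree-compatible Partial.
add_one = Partial(add, 1.0)

# The bound arguments are visible as pytree leaves, which is
# what a functools.partial object does not provide.
leaves = jax.tree_util.tree_leaves(add_one)

# Because it is a pytree, device_put can move its leaves.
moved = jax.device_put(add_one, jax.devices("cpu")[0])
result = moved(2.0)
```

A structure holding `functools.partial` objects, by contrast, is treated as an opaque leaf and cannot be placed on a device this way.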
CLAassistant commented 1 month ago

CLA assistant check
All committers have signed the CLA.

eadadi commented 1 month ago

The following linter check is incorrect: the class name is uppercase.

flashbax/buffers/mixer.py:22:2: N813 camelcase 'Partial' imported as lowercase 'partial'
flashbax/buffers/prioritised_trajectory_buffer.py:27:2: N813 camelcase 'Partial' imported as lowercase 'partial'
flashbax/buffers/trajectory_buffer.py:25:2: N813 camelcase 'Partial' imported as lowercase 'partial'
flashbax/buffers/trajectory_queue.py:19:2: N813 camelcase 'Partial' imported as lowercase 'partial'
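For reference, N813 (from pep8-naming) flags importing a CamelCase name under a lowercase alias. The usual ways to satisfy it are to keep the original name, or to suppress the check on that line (a sketch of the options, not the actual patch):

```python
# Keep the CamelCase name as-is; nothing for N813 to flag:
from jax.tree_util import Partial

# Or keep the lowercase alias and silence the check explicitly:
from jax.tree_util import Partial as partial  # noqa: N813
```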
garymm commented 1 month ago

I'm not a maintainer, but that linter seems correct to me. The original function name is CamelCase, and you're importing it as lowercase.

sash-a commented 6 days ago

I'm confused by this one. Is there a reason you'd want to put the buffer itself on an accelerator? Do you see speed-ups over just putting the buffer state on the accelerator? The buffer should just be a collection of functions.
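The distinction sash-a is drawing can be sketched generically (all names here are hypothetical, for illustration only, not Flashbax's API): the buffer is a container of pure functions, while the arrays live in a separate state pytree, so the state is the natural thing to move between devices.

```python
import jax
import jax.numpy as jnp
from typing import Callable, NamedTuple

class BufferState(NamedTuple):
    # The state pytree: this is where the actual arrays live.
    storage: jax.Array
    index: jax.Array

class Buffer(NamedTuple):
    # The "buffer" is just a bundle of pure functions.
    init: Callable
    add: Callable

def init(capacity: int) -> BufferState:
    return BufferState(storage=jnp.zeros((capacity,)), index=jnp.array(0))

def add(state: BufferState, item: float) -> BufferState:
    return BufferState(
        storage=state.storage.at[state.index].set(item),
        index=state.index + 1,
    )

buffer = Buffer(init=init, add=add)

# Only the state carries data, so device placement targets it:
state = buffer.init(8)
state = jax.device_put(state, jax.devices("cpu")[0])
state = buffer.add(state, 3.0)
```

Under this view, `device_put` on the buffer object itself only matters if the buffer's closures hold pytree-registered values, which is what the `jax.tree_util.Partial` change enables.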