eadadi opened this 1 month ago
The following linter check is incorrect: the class name is uppercase.
flashbax/buffers/mixer.py:22:2: N813 camelcase 'Partial' imported as lowercase 'partial'
flashbax/buffers/prioritised_trajectory_buffer.py:27:2: N813 camelcase 'Partial' imported as lowercase 'partial'
flashbax/buffers/trajectory_buffer.py:25:2: N813 camelcase 'Partial' imported as lowercase 'partial'
flashbax/buffers/trajectory_queue.py:19:2: N813 camelcase 'Partial' imported as lowercase 'partial'
I'm not a maintainer, but that linter check seems correct to me. The original name is CamelCase, and you're importing it as lowercase.
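For reference, here is a minimal sketch of import styles that avoid N813 (the pep8-naming rule against binding a CamelCase name to a lowercase alias). The helper scale and the names double and also_double are hypothetical, used only for illustration.

```python
# This alias is what triggers N813:
#   from jax.tree_util import Partial as partial

# Option 1: keep the CamelCase class name as-is.
from jax.tree_util import Partial

# Option 2: reference the class through its module.
import jax


def scale(x, factor):
    # Hypothetical helper used only to demonstrate partial application.
    return x * factor


double = Partial(scale, factor=2)
also_double = jax.tree_util.Partial(scale, factor=2)
```

Both forms build the same kind of object; the only difference is how the name appears at the call site.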
I'm confused by this one. Is there a reason you'd want to put the buffer on an accelerator? Do you see speedups over just putting the buffer state on the accelerator? The buffer should just be a collection of functions.
Updated various buffer files to replace functools.partial with jax.tree_util.Partial. Unlike functools.partial, jax.tree_util.Partial is a pytree, so JAX transformations accept it as an argument.
The motivation is to be able to use jax transformations over buffers.
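A minimal sketch of the difference, assuming a hypothetical ToyBuffer NamedTuple that stands in for a buffer made of functions (it is not flashbax's actual buffer class). When the buffer is passed into a jax.jit-compiled function, every field must be a valid pytree. A jax.tree_util.Partial field works, while a functools.partial field is rejected with a TypeError.

```python
import functools
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def _add(state, x, scale):
    # Hypothetical stand-in for a buffer's `add` function.
    return state + scale * x


class ToyBuffer(NamedTuple):
    # NamedTuples are pytrees, so JAX flattens each field.
    add: Callable


@jax.jit
def step(buffer, state, x):
    # The buffer is itself a jit argument, so its fields must be pytrees.
    return buffer.add(state, x)


# Partial is a pytree: the wrapped function is static, bound args are leaves.
pytree_buffer = ToyBuffer(add=Partial(_add, scale=2.0))
result = step(pytree_buffer, jnp.zeros(3), jnp.ones(3))

# functools.partial is an opaque leaf that jit cannot convert to a JAX type.
plain_buffer = ToyBuffer(add=functools.partial(_add, scale=2.0))
rejected = False
try:
    step(plain_buffer, jnp.zeros(3), jnp.ones(3))
except TypeError:
    rejected = True
```

The same reasoning applies to other transformations such as jax.vmap, which also flatten their arguments as pytrees.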
For example, before this patch, this wasn't working: