jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Deduplicate all constants when forming branch jaxprs for lax.cond #16455

Open sharadmv opened 1 year ago

sharadmv commented 1 year ago

Context: PR #16445 added logic for deduplicating Refs that are shared across the branches of a cond jaxpr. Constant values, by contrast, are still allowed to appear more than once, since duplicating a value is safe. Even so, there is no reason to pass the same value in separate operand slots when we know they are actually the same const.
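A minimal sketch of the idea, in plain Python. The helper `dedupe_branch_consts` is hypothetical and is not a JAX API; JAX's real handling lives inside the `lax.cond` internals and may differ. It assumes two consts count as "the same" only when they are the same Python object. Each branch's const list is merged into one shared operand list, and each branch records which shared slot each of its consts reads from.

```python
from typing import Any, Dict, List, Sequence, Tuple


def dedupe_branch_consts(
    branch_consts: Sequence[Sequence[Any]],
) -> Tuple[List[Any], List[List[int]]]:
    """Hypothetical helper: merge per-branch const lists into one operand list.

    Consts are matched by object identity (id()), so two distinct arrays
    with equal contents are still passed in separate slots.
    """
    shared: List[Any] = []        # deduplicated operands passed to the cond
    slot_of: Dict[int, int] = {}  # id(const) -> index into `shared`
    branch_slots: List[List[int]] = []
    for consts in branch_consts:
        slots = []
        for c in consts:
            key = id(c)
            if key not in slot_of:
                slot_of[key] = len(shared)
                shared.append(c)
            slots.append(slot_of[key])
        branch_slots.append(slots)
    return shared, branch_slots


# Usage: `w` stands in for an array closed over by both branches.
w = [1.0, 2.0]
b = [0.5]
w_copy = [1.0, 2.0]  # equal contents, but a different object
shared, slots = dedupe_branch_consts([[w, b], [w, w_copy]])

# Each branch rebuilds its own const list from the shared operands.
branch0_consts = [shared[i] for i in slots[0]]
branch1_consts = [shared[i] for i in slots[1]]
```

Matching by identity rather than by value keeps the check cheap and conservative: it never merges two consts that merely happen to compare equal, so `w` occupies one slot while `w_copy` still gets its own.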

davisyoshida commented 1 year ago

Would this also be relevant to the cond and body jaxprs of the while primitive? Currently the consts for the cond and body are kept separate, but I don't know whether some upstream deduplication might move a shared const into the invars group.

sharadmv commented 1 year ago

Yes, we would need to do this for correctness when handling state in a while loop. It could also be an optimization for plain values. Cc: @mattjj