Open sharadmv opened 1 year ago
Would this also be relevant to the cond and body jaxprs of the while primitive? Currently the consts for the cond and body are kept separate, but I don't know whether some upstream deduplication might move a shared const into the invars group.
Yes, we would need to do this for correctness when handling state in a while loop. It could also be an optimization for values. Cc: @mattjj
Context: PR #16445 added logic for deduplicating Refs that are shared across branches of a cond jaxpr. Constant values, on the other hand, can be duplicated, since duplicating a value is safe. However, there's no inherent reason to pass the same value in different operand slots if we know they're actually the same const.
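As a rough illustration of the idea (not the logic from PR #16445), here is a minimal sketch of deduplicating consts by object identity. The helper name `dedupe_consts` and the `cond_consts` / `body_consts` lists are hypothetical, and plain Python objects stand in for real constants; it only shows how a const shared between two jaxprs could collapse into a single operand slot plus an index map.

```python
def dedupe_consts(consts):
    """Deduplicate a flat list of consts by object identity (hypothetical helper).

    Returns (unique_consts, index_map) such that
    consts[i] is unique_consts[index_map[i]] for every i.
    """
    unique = []
    seen = {}  # id(const) -> position in `unique`
    index_map = []
    for c in consts:
        key = id(c)  # identity, not equality: equal-but-distinct values stay separate
        if key not in seen:
            seen[key] = len(unique)
            unique.append(c)
        index_map.append(seen[key])
    return unique, index_map


# Stand-ins for consts closed over by the cond and body of a while loop.
shared = object()
only_in_cond = object()
only_in_body = object()

cond_consts = [shared, only_in_cond]
body_consts = [only_in_body, shared]

# `shared` appears in both lists but ends up in one slot.
unique, index_map = dedupe_consts(cond_consts + body_consts)
```

Keying on identity rather than equality is a deliberate assumption here: for Refs, two distinct Refs must never be merged even if their contents happen to match, while for plain values merging is only an optimization.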