This refactoring PR replaces the bespoke `functools.singledispatch`-based flattening/unflattening logic in `chirho.robust.internals.utils.make_flatten_unflatten` with equivalent logic using PyTorch's PyTree `flatten`/`unflatten` implementations, which are already used elsewhere in `chirho.robust`.

This should not change the behavior of `make_flatten_unflatten`, and the change should be exercised by the existing unit tests.
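For reviewers unfamiliar with PyTree utilities, here is a minimal sketch of the general pattern, not the code in this PR. It uses `torch.utils._pytree.tree_flatten` and `tree_unflatten`, which live in a private PyTorch module. It also assumes that `flatten` maps a nested structure of tensors to a single 1-D tensor and that `unflatten` inverts it. The real helper's shape conventions (for example, any preserved batch dimension) may differ.

```python
from typing import Callable, Tuple, TypeVar

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

T = TypeVar("T")


def make_flatten_unflatten(
    v: T,
) -> Tuple[Callable[[T], torch.Tensor], Callable[[torch.Tensor], T]]:
    # Illustrative sketch only: record the tree structure and per-leaf shapes
    # of the example value v, then build closures that convert between that
    # structure and one flat 1-D tensor.
    leaves, spec = tree_flatten(v)
    shapes = [leaf.shape for leaf in leaves]
    sizes = [leaf.numel() for leaf in leaves]

    def flatten(x: T) -> torch.Tensor:
        x_leaves, x_spec = tree_flatten(x)
        if x_spec != spec:
            raise ValueError("input structure does not match the example value")
        # Concatenate every leaf, flattened, in PyTree leaf order.
        return torch.cat([leaf.reshape(-1) for leaf in x_leaves])

    def unflatten(flat: torch.Tensor) -> T:
        # Split the flat tensor back into per-leaf chunks and restore shapes.
        chunks = torch.split(flat, sizes)
        new_leaves = [c.reshape(s) for c, s in zip(chunks, shapes)]
        return tree_unflatten(new_leaves, spec)

    return flatten, unflatten


v = {"a": torch.zeros(2, 3), "b": (torch.ones(4),)}
flatten, unflatten = make_flatten_unflatten(v)
roundtrip = unflatten(flatten(v))
```

With this approach, any container type that PyTorch's PyTree module already understands (dicts, lists, tuples, and so on) is handled without a separate `singledispatch` registration.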
This refactoring PR replaces the bespoke
functools.singledispatch
-based flattening/unflattening logic inchirho.robust.internals.utils.make_flatten_unflatten
with equivalent logic using PyTorch's PyTreeflatten
/unflatten
implementations, which are already used elsewhere inchirho.robust
.This should not change the behavior of
make_flatten_unflatten
and should be exercised by existing unit tests.