BasisResearch / chirho

An experimental language for causal reasoning
https://basisresearch.github.io/chirho/getting_started.html
Apache License 2.0

Replace bespoke flattening logic with PyTree-based flattening in `chirho.robust` CG solver #551

Closed: eb8680 closed this 4 months ago

eb8680 commented 4 months ago

This refactoring PR replaces the bespoke `functools.singledispatch`-based flattening/unflattening logic in `chirho.robust.internals.utils.make_flatten_unflatten` with equivalent logic built on PyTorch's PyTree flatten/unflatten implementations, which are already used elsewhere in `chirho.robust`.

This should not change the behavior of `make_flatten_unflatten`, and the change should be exercised by the existing unit tests.
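For illustration, a PyTree-based flatten/unflatten pair for a CG solver might look like the sketch below. It is a hypothetical sketch, not the code from this PR: the name `make_flatten_unflatten` is reused only for readability, it flattens every leaf into one 1-D vector (chirho's version may handle batch dimensions differently), and it relies on `tree_flatten`/`tree_unflatten` from `torch.utils._pytree`, a private PyTorch module whose API may change between releases.

```python
import math
from typing import Any, Callable, Tuple

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten


def make_flatten_unflatten(
    example: Any,
) -> Tuple[Callable[[Any], torch.Tensor], Callable[[torch.Tensor], Any]]:
    # Hypothetical sketch (not chirho's implementation): map a pytree of
    # tensors (dicts, lists, tuples, ...) to a single 1-D vector and back.
    leaves, spec = tree_flatten(example)
    shapes = [leaf.shape for leaf in leaves]
    # Number of elements in each leaf; math.prod(()) == 1 covers 0-d tensors.
    sizes = [math.prod(shape) for shape in shapes]

    def flatten(tree: Any) -> torch.Tensor:
        # PyTree traverses containers generically, so no per-type
        # singledispatch registrations are needed.
        tree_leaves, _ = tree_flatten(tree)
        return torch.cat([leaf.reshape(-1) for leaf in tree_leaves])

    def unflatten(flat: torch.Tensor) -> Any:
        # Split the vector into per-leaf chunks, restore each leaf's shape,
        # then rebuild the original container structure from the TreeSpec.
        chunks = torch.split(flat, sizes)
        new_leaves = [chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)]
        return tree_unflatten(new_leaves, spec)

    return flatten, unflatten


params = {
    "loc": torch.zeros(2),
    "scale": torch.ones(2, 3),
    "nested": [torch.arange(4.0).reshape(2, 2)],
}
flatten, unflatten = make_flatten_unflatten(params)
flat = flatten(params)  # one vector a CG solver can operate on
restored = unflatten(flat)  # same dict/list structure as params
```

The design choice is that the container structure lives entirely in the PyTree `spec`, so supporting a new container type is a matter of PyTree registration rather than another bespoke flattening rule.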