BasisResearch / chirho

An experimental language for causal reasoning
https://basisresearch.github.io/chirho/getting_started.html
Apache License 2.0

simplify `AutoSoftConditioningReparam` to use `soft_eq` directly. #534

Closed: rfl-urbaniak closed this pull request 3 months ago.

rfl-urbaniak commented 3 months ago

This is a follow-up to PR #500.

The strategy:

  1. Given a constraint and a scale/alpha passed as arguments to `AutoSoftConditioning`, define a `_soft_eq` that uses these values and pass it on to `KernelSoftConditionReparam`, with proper constraint diagnosis preceding this call.

  2. Eliminate `TorchKernel`, `SoftEqKernel` and `RBFKernel` completely.

  3. Simplify `AutoSoftConditioning` accordingly.

Remark: I'm not sure whether the following is still needed, so it is commented out for now:

```python
scale = self.scale * functools.reduce(
    operator.mul, msg["fn"].event_shape, 1.0
)
```
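The expression in the remark multiplies the scale by the number of scalar entries in one event: `functools.reduce(operator.mul, event_shape, 1.0)` is the product of the event dimensions, and the initial value `1.0` leaves scalar events (empty shape) unchanged. A standalone sketch of just that arithmetic, with a plain tuple standing in for `msg["fn"].event_shape` (a `torch.Size`, which subclasses `tuple`) and a hypothetical helper name `event_scaled`; it says nothing about whether the rescaling is still needed:

```python
import functools
import operator
from typing import Sequence


def event_scaled(scale: float, event_shape: Sequence[int]) -> float:
    """Hypothetical helper: multiply a base scale by the event's element count.

    reduce(operator.mul, event_shape, 1.0) is the product of the event
    dimensions; the initial value 1.0 means an empty (scalar) event shape
    leaves the scale unchanged.
    """
    n_event_elements = functools.reduce(operator.mul, event_shape, 1.0)
    return scale * n_event_elements


# Scalar event: no rescaling.
print(event_scaled(0.5, ()))  # 0.5
# A 2x3 event has 6 elements, so the scale grows six-fold.
print(event_scaled(0.5, (2, 3)))  # 3.0
```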

  4. Modify `tests/observational/test_handlers.py` accordingly, so that it no longer uses these kernels.

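Steps 1 to 3 amount to replacing kernel objects with a plain function that closes over the user's arguments, after checking the constraint. A minimal pure-Python sketch of that pattern, assuming a Gaussian log-kernel for real-valued sites; `make_soft_eq`, `SoftConditionReparamSketch` and `auto_soft_conditioning_sketch` are hypothetical names, not chirho's API, and a string label stands in for a real constraint object:

```python
import math
from typing import Callable

# Hypothetical type alias: a soft equality maps two values to a log-weight.
SoftEq = Callable[[float, float], float]


def make_soft_eq(support: str, scale: float) -> SoftEq:
    """Diagnose the constraint first, then build a `_soft_eq` closure.

    Assumption: only real-valued support is handled, using the Gaussian
    log-kernel -(v1 - v2)**2 / (2 * scale**2).
    """
    if support != "real":
        raise NotImplementedError(f"sketch only handles real support, got {support!r}")
    if scale <= 0:
        raise ValueError("scale must be positive")

    def _soft_eq(v1: float, v2: float) -> float:
        # Closes over `scale`, so no kernel object is needed.
        return -((v1 - v2) ** 2) / (2.0 * scale**2)

    return _soft_eq


class SoftConditionReparamSketch:
    """Hypothetical reparam that receives a plain function, not a kernel class."""

    def __init__(self, soft_eq: SoftEq) -> None:
        self.soft_eq = soft_eq

    def log_weight(self, observed: float, value: float) -> float:
        # Log-factor that soft conditioning would add to the joint density.
        return self.soft_eq(observed, value)


def auto_soft_conditioning_sketch(support: str, scale: float) -> SoftConditionReparamSketch:
    # Constraint diagnosis runs inside make_soft_eq, before the reparam is built.
    return SoftConditionReparamSketch(make_soft_eq(support, scale))


reparam = auto_soft_conditioning_sketch("real", scale=2.0)
print(reparam.log_weight(0.0, 2.0))  # -0.5
```

In this sketch all configuration lives in the closure, so there is no kernel class hierarchy to maintain.
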
rfl-urbaniak commented 3 months ago

This is sufficient for `tests/observational/test_handlers.py` to pass, but `tests/dynamical/test_handler_composition.py` fails.

  1. Reinstated the following, but the failure is the same:

     ```python
     scale = scale * functools.reduce(
         operator.mul, msg["fn"].event_shape, 1.0
     )
     ```

  2. Considered reverting to the intermediate variant that preserves `TorchKernel` as a `torch.nn.Module`, with the same failure.

@SamWitty let's chat about this at some point.