liesel-devs / liesel

A probabilistic programming framework
https://liesel-project.org
MIT License

Implement the NamedTupleInterface #150

Closed: wiep closed this issue 8 months ago

wiep commented 8 months ago

I want to use NamedTuples in the workshop, but we do not have a NamedTupleInterface. The implementation is straightforward and could be included in Liesel's next release. It still needs to be documented and tested.

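```python
from typing import Any, Callable, NamedTuple, Sequence


class NamedTupleInterface:
    """Model interface for model states stored as NamedTuples."""

    def __init__(self, log_prob_fn: Callable[[NamedTuple], float]):
        self._log_prob_fn = log_prob_fn

    def extract_position(
        self, position_keys: Sequence[str], model_state: NamedTuple
    ) -> dict[str, Any]:
        # Read the requested fields by name; a misspelled key raises AttributeError.
        return {key: getattr(model_state, key) for key in position_keys}

    def log_prob(self, model_state: NamedTuple) -> float:
        return self._log_prob_fn(model_state)

    def update_state(
        self, position: dict[str, Any], model_state: NamedTuple
    ) -> NamedTuple:
        # NamedTuples are immutable: _replace returns a new state and
        # raises ValueError for field names the state does not have.
        return model_state._replace(**position)
```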
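A minimal usage sketch, assuming a toy model whose state has two hypothetical parameters, `mu` and `log_sigma`, each with a standard-normal log-density computed with the `math` module (in a real Liesel model, `log_prob_fn` would typically be a JAX function). The class is repeated so the snippet runs on its own.

```python
import math
from typing import Any, Callable, NamedTuple, Sequence


class NamedTupleInterface:
    def __init__(self, log_prob_fn: Callable[[NamedTuple], float]):
        self._log_prob_fn = log_prob_fn

    def extract_position(
        self, position_keys: Sequence[str], model_state: NamedTuple
    ) -> dict[str, Any]:
        return {key: getattr(model_state, key) for key in position_keys}

    def log_prob(self, model_state: NamedTuple) -> float:
        return self._log_prob_fn(model_state)

    def update_state(
        self, position: dict[str, Any], model_state: NamedTuple
    ) -> NamedTuple:
        return model_state._replace(**position)


# Hypothetical model state with two parameters.
class State(NamedTuple):
    mu: float
    log_sigma: float


def log_prob_fn(state: State) -> float:
    # Sum of two independent standard-normal log-densities.
    return -0.5 * (state.mu**2 + state.log_sigma**2) - math.log(2 * math.pi)


interface = NamedTupleInterface(log_prob_fn)
state = State(mu=1.0, log_sigma=0.0)

position = interface.extract_position(["mu"], state)
print(position)  # {'mu': 1.0}

new_state = interface.update_state({"mu": 0.0}, state)
print(new_state)  # State(mu=0.0, log_sigma=0.0)

print(interface.log_prob(new_state))  # -log(2*pi), approximately -1.8379
```

Because `update_state` returns a new tuple, the original `state` is left untouched, which fits JAX's functional style.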
jobrachem commented 8 months ago

Just out of curiosity: Why not use a dataclass or dict?

wiep commented 8 months ago

Dataclasses aren't pytrees, and I want to keep the code simple. True, I could use a dict. I didn't because a namedtuple protects you from misspelling keys, which a dict wouldn't. I guess it would be easy to change if we want to.
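The misspelling argument can be checked with plain Python. A small sketch, assuming a hypothetical state with fields `mu` and `sigma` and a deliberate typo `sgima`:

```python
from typing import NamedTuple


class State(NamedTuple):
    mu: float
    sigma: float


# With a dict, a misspelled key is silently added as a new entry.
dict_state = {"mu": 0.0, "sigma": 1.0}
dict_state.update({"sgima": 2.0})  # the typo goes unnoticed
print(dict_state)  # {'mu': 0.0, 'sigma': 1.0, 'sgima': 2.0}

# With a NamedTuple, _replace rejects unknown field names with a ValueError.
tuple_state = State(mu=0.0, sigma=1.0)
rejected = False
try:
    tuple_state._replace(sgima=2.0)
except ValueError as err:
    rejected = True
    print(f"rejected: {err}")
```

Reads are caught either way (`getattr` raises `AttributeError`, a plain `d[key]` lookup raises `KeyError`), so the protection matters mainly when a position is written back into the state.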