LouisDesdoigts / zodiax

Object-oriented Jax framework extending Equinox for scientific programming
https://louisdesdoigts.github.io/zodiax/
BSD 3-Clause "New" or "Revised" License

Setting `None` requires wrapping in a list #2

Closed: LouisDesdoigts closed this issue 1 year ago.

LouisDesdoigts commented 1 year ago

Setting a parameter to `None` by passing the bare value raises a `ValueError`, whereas wrapping both the path and the value in lists works.

Minimal Example:

import zodiax as zdx

class Foo(zdx.Base):
    param: float = 1.0

foo = Foo()

# Works: path and value are both wrapped in lists
bar = foo.set(['param'], [None])

# Raises ValueError: bare string path with a bare None value
bar = foo.set('param', None)

Stack Trace:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [10], in <cell line: 7>()
      5 foo = Foo()
      6 bar = foo.set(['param'], [None])
----> 7 bar = foo.set('param', None)

File ~/mambaforge/envs/dlux/lib/python3.10/site-packages/zodiax/base.py:296, in Base.set(self, paths, values, pmap)
    275 def set(self   : Pytree,
    276         paths  : Union[str, list],
    277         values : Union[Any, list],
    278         pmap   : dict = None) -> Pytree:
    279     """
    280     Set the leaves specified by paths with values.
    281 
   (...)
    294         The pytree with leaves specified by paths updated with values.
    295     """
--> 296     new_paths, new_values = self._format(paths, values, pmap)
    298     # Define 'where' function and update pytree
    299     get_leaves_fn = lambda pytree: pytree._get_leaves(new_paths)

ValueError: not enough values to unpack (expected 2, got 1)
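The last line of the trace is Python's generic tuple-unpacking error: `new_paths, new_values = self._format(paths, values, pmap)` expects exactly two items, but `_format` returned an iterable containing only one. The snippet below reproduces the same message with a hypothetical stand-in function (`returns_one_item` is illustrative only, not zodiax's `_format`):

```python
def returns_one_item():
    # Hypothetical stand-in: returns a 1-element tuple instead of
    # the (paths, values) pair the caller expects to unpack.
    return (["param"],)


message = ""
try:
    # Unpacking a 1-element iterable into two names raises ValueError.
    new_paths, new_values = returns_one_item()
except ValueError as err:
    message = str(err)

print(message)  # not enough values to unpack (expected 2, got 1)
```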

This should be fixable by checking whether `values` is `None` and, if so, automatically wrapping it in a list before the inputs are formatted.
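A minimal sketch of that normalisation step, assuming a standalone helper (`normalise_inputs` is hypothetical and not part of zodiax's API). Rather than special-casing `None` alone, it wraps the value whenever a single string path is given, so `None` is always treated as one value and never as "no values":

```python
from typing import Any, List, Tuple, Union


def normalise_inputs(paths: Union[str, list], values: Any) -> Tuple[List[str], List[Any]]:
    """Hypothetical helper: coerce (paths, values) into two equal-length lists."""
    if isinstance(paths, str):
        # Single path: wrap the path and its value, even when the value is None.
        return [paths], [values]
    if not isinstance(values, list):
        # Assumption: a list of paths must come with a list of values.
        raise TypeError("A list of paths requires a list of values.")
    if len(paths) != len(values):
        raise ValueError(f"Got {len(paths)} paths but {len(values)} values.")
    return list(paths), values


# Both call styles from the minimal example now normalise identically.
print(normalise_inputs('param', None))      # (['param'], [None])
print(normalise_inputs(['param'], [None]))  # (['param'], [None])
```

Dispatching on the type of `paths` rather than on `values is None` keeps the single-path case uniform for every value, including `None`.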