google-deepmind / chex

https://chex.readthedocs.io
Apache License 2.0

chex.Dimensions API enhancement #231

Open wbrenton opened 1 year ago

wbrenton commented 1 year ago

I would like to propose an API enhancement that allows chex.Dimensions to be used inside function annotations. If there is interest, I'd like to contribute. Example below:

```python
import chex

dims = chex.Dimensions(B=batch_size, T=sequence_len, E=embedding_dim)
...
def foo(arr: chex.Array):
    chex.assert_shape(arr, dims['BTE'])
    # fn logic

### turns into ###

def foo(arr: chex.Array(dims['BTE'])):  # behind the scenes assert on function call
    # fn logic
```
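The proposed `chex.Array(dims['BTE'])` syntax is not something chex provides. As a rough illustration of the "behind the scenes" check, here is a minimal sketch in plain Python, assuming `typing.Annotated` (Python 3.9+) is used to attach the expected shape to an ordinary type. `Shape` and `check_shapes` are hypothetical names introduced only for this sketch, NumPy arrays stand in for JAX arrays, and treating `None` as a wildcard is this sketch's own convention.

```python
import functools
import inspect
from typing import Annotated, get_type_hints

import numpy as np


class Shape:
    """Hypothetical annotation marker holding an expected shape tuple."""

    def __init__(self, dims):
        # `dims` is a tuple such as the one returned by dims['BTE'];
        # `None` entries are treated as "any size" in this sketch.
        self.dims = tuple(dims)

    def check(self, name, value):
        actual = tuple(np.shape(value))
        matches = len(actual) == len(self.dims) and all(
            d is None or d == a for d, a in zip(self.dims, actual))
        if not matches:
            raise AssertionError(
                f"{name}: expected shape {self.dims}, got {actual}")


def check_shapes(fn):
    """Hypothetical decorator: checks Annotated[..., Shape(...)] args on each call."""
    hints = get_type_hints(fn, include_extras=True)
    signature = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bound = signature.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            # Plain (non-Annotated) hints have no __metadata__ and are skipped.
            for meta in getattr(hints.get(name), "__metadata__", ()):
                if isinstance(meta, Shape):
                    meta.check(name, value)
        return fn(*args, **kwargs)

    return wrapper


B, T, E = 2, 3, 4  # Stand-ins for the sizes behind dims['BTE'].


@check_shapes
def foo(arr: Annotated[np.ndarray, Shape((B, T, E))]):
    return arr.sum()


foo(np.zeros((B, T, E)))  # passes
# foo(np.zeros((B, T)))   # would raise AssertionError
```

With chex installed, `Shape.check` could call `chex.assert_shape(value, self.dims)` instead of the manual comparison, so the annotation would carry a `dims['BTE']` tuple directly.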

This is particularly useful for dataclasses, e.g.:

```python
dims = chex.Dimensions(B=batch_size, T=rollout_len)

# asserts are run on instantiation
@chex.dataclass
class TimeStep:
    q_values: chex.Array(dims['BT'])
    discounts: chex.Array(dims['BT'])
    rewards: chex.Array(dims['BT'])
```
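Along the same lines, a minimal sketch of the dataclass case, assuming the shape check runs in `__post_init__` of a standard-library dataclass. `Shape` and `ShapeChecked` are hypothetical helpers, not part of chex, and the standard `dataclasses` module is used instead of `chex.dataclass` only to keep the sketch self-contained.

```python
import dataclasses
from typing import Annotated, get_type_hints

import numpy as np


class Shape:
    """Hypothetical annotation marker holding an expected shape tuple."""

    def __init__(self, dims):
        self.dims = tuple(dims)


class ShapeChecked:
    """Hypothetical mixin: validates Annotated[..., Shape(...)] fields on init."""

    def __post_init__(self):
        # The dataclass-generated __init__ calls this after assigning fields.
        hints = get_type_hints(type(self), include_extras=True)
        for field in dataclasses.fields(self):
            for meta in getattr(hints.get(field.name), "__metadata__", ()):
                if isinstance(meta, Shape):
                    actual = np.shape(getattr(self, field.name))
                    if tuple(actual) != meta.dims:
                        raise AssertionError(
                            f"{field.name}: expected {meta.dims}, got {actual}")


B, T = 2, 5  # Stand-ins for the sizes behind dims['BT'].


@dataclasses.dataclass
class TimeStep(ShapeChecked):
    q_values: Annotated[np.ndarray, Shape((B, T))]
    discounts: Annotated[np.ndarray, Shape((B, T))]
    rewards: Annotated[np.ndarray, Shape((B, T))]


step = TimeStep(np.zeros((B, T)), np.ones((B, T)), np.zeros((B, T)))  # passes
```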

KristianHolsheimer commented 1 year ago

Thanks for your interest in chex!

This suggestion is very interesting. Many of us who work with arrays in Python on a daily basis are eagerly awaiting PEP 646 (variadic generics), which was accepted for Python 3.11.

Once Python 3.11 becomes more mainstream, we will definitely consider incorporating shape annotations into chex. Perhaps we could augment or fork chex.Dimensions to return TypeVarTuples, along the lines of your suggestion.

For the time being, however, we will not implement such a change. In particular, mixing runtime checks with static type annotations is out of scope, at least for now.

P.S. If you're interested in checking type annotations at runtime, you might find the pydantic project useful: https://docs.pydantic.dev/