google / trajax

Apache License 2.0
186 stars 23 forks

Use of deprecated Jax APIs/behavior #8

Open pfrommerd opened 1 year ago

pfrommerd commented 1 year ago

Trajax uses deprecated JAX APIs/behavior, which causes warnings to be emitted at two locations.

  1. The first is at https://github.com/google/trajax/blob/a14d248294e9c8f74f10a10e787af24ac6be1ad7/trajax/optimizers.py#L755, where `static_argnums` includes index 9 even though the jitted function accepts only 8 positional arguments (indices 0 to 7). This results in the warning:
    jax/_src/api_util.py:165: SyntaxWarning: Jitted function has static_argnums=(0, 1, 9), but only accepts 8 positional arguments. 
    This warning will be replaced by an error after 2022-08-20 at the earliest.
  2. The second is at https://github.com/google/trajax/blob/a14d248294e9c8f74f10a10e787af24ac6be1ad7/trajax/tvlqr.py#L98, which emits the following warning:
    trajax/tvlqr.py:98: FutureWarning: The sym_pos argument to solve() is deprecated and will be removed in a future JAX release. Use assume_a='pos' instead.
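For the first warning, the fix is presumably to drop the out-of-range index from the decorator. The sketch below illustrates this with hypothetical stand-ins that are not trajax code: `rollout_cost` (an 8-argument function whose first two arguments are callables) and `out_of_range_argnums` (a helper that approximates the arity check behind the warning by counting positional parameters with `inspect.signature`). The jitted version keeps only the valid static indices 0 and 1.

```python
import inspect

import jax
import jax.numpy as jnp


def out_of_range_argnums(fun, static_argnums):
    """Return the non-negative entries of static_argnums beyond fun's positional arity."""
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    n_pos = sum(
        p.kind in positional_kinds
        for p in inspect.signature(fun).parameters.values()
    )
    return tuple(i for i in static_argnums if i >= n_pos)


# Hypothetical 8-argument function (positional indices 0..7). The first two
# arguments are Python callables, so they must be static under jit.
def rollout_cost(cost, dynamics, x0, U, scale, offset0, offset1, offset2):
    def step(x, u):
        # Carry the next state forward and emit the stage cost.
        return dynamics(x, u), cost(x, u)

    _, stage_costs = jax.lax.scan(step, x0, U)
    return scale * jnp.sum(stage_costs) + offset0 + offset1 + offset2


# (0, 1, 9) reproduces the tuple from the warning; index 9 does not exist.
bad = out_of_range_argnums(rollout_cost, (0, 1, 9))  # (9,)

# Fixed: keep only indices that refer to real positional arguments.
fast_rollout_cost = jax.jit(rollout_cost, static_argnums=(0, 1))

cost = lambda x, u: jnp.sum(x**2) + jnp.sum(u**2)
dynamics = lambda x, u: x + u
total = fast_rollout_cost(
    cost, dynamics, jnp.zeros(2), jnp.ones((3, 2)), 1.0, 0.0, 0.0, 0.0
)  # → 16.0
```

The issue does not say whether index 9 is stale or a typo for another argument. If one of the later arguments was meant to be static, its correct index should replace 9 in the tuple rather than being dropped.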
J4nn1K commented 1 year ago

The second warning should be fixed in #10.
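The contents of #10 are not shown here. As a sketch of the replacement the FutureWarning itself suggests, the example swaps `sym_pos=True` for `assume_a="pos"` in `jax.scipy.linalg.solve`. This flag declares the matrix symmetric positive definite, which lets JAX use a Cholesky-based solve. The 2×2 system is made up for illustration and is not taken from `tvlqr.py`.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve

# Illustrative symmetric positive-definite system (values made up).
A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])

# Deprecated call that triggers the FutureWarning (may be rejected by newer JAX):
#   x = solve(A, b, sym_pos=True)

# Replacement named in the warning message: declare A positive definite.
x = solve(A, b, assume_a="pos")  # ≈ [1/11, 7/11]
```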