yunlongxu-numagic closed 2 years ago
In general it seems to work OK coding-wise: a layer of array->dict is added before the new model I/O and a layer of dict->array after it. What is causing trouble is that the JAX jit compilation time increases significantly, especially for the Hessian. Execution time is also slightly increased, but not significantly.

Why does the jit compilation time increase so much for this thin wrapper layer?

This slowdown seems to be limited to using Docker on macOS.
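To separate compilation time from execution time when checking this, one option is JAX's ahead-of-time path (`jax.jit(f).lower(...).compile()`). The sketch below is only an illustration under assumptions: `dict_model`, `array_model`, `NAMES`, and `time_compile` are hypothetical stand-ins, not the actual model I/O code from this PR, and the toy model is far too small to reproduce the slowdown itself.

```python
import time

import jax
import jax.numpy as jnp

# Hypothetical input names; the real model I/O layer is not shown in this thread.
NAMES = ("a", "b", "c")


def dict_model(inputs):
    # Toy model that takes a dict of named inputs and returns a dict of outputs.
    return {"cost": jnp.sin(inputs["a"]) * inputs["b"] + inputs["c"] ** 2}


def array_model(x):
    # array -> dict before the model, dict -> scalar after it.
    inputs = {name: x[i] for i, name in enumerate(NAMES)}
    return dict_model(inputs)["cost"]


def time_compile(fn, x):
    """Return (seconds spent tracing and compiling, compiled callable)."""
    start = time.perf_counter()
    compiled = jax.jit(fn).lower(x).compile()
    return time.perf_counter() - start, compiled


x = jnp.array([0.1, 0.2, 0.3])

# Time tracing + compilation of the Hessian of the wrapped model.
compile_s, hess = time_compile(jax.hessian(array_model), x)

# Time a single execution of the already compiled function.
start = time.perf_counter()
h = hess(x).block_until_ready()
run_s = time.perf_counter() - start

print(f"compile: {compile_s:.3f}s, run: {run_s:.6f}s")
```

Running the same measurement on the array-only version of the model, and both inside and outside the Docker container, would show whether the extra time comes from the wrapper or from the environment.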
So, TODOs before merge:
- `get_group_names` and `get_cls_group_names`
- `apply_and_forward`, `apply_and_forward_with_arrays`, `forward`, `forward_with_arrays`
- `ModelReturn` and `StateSpaceModelReturn` are used with both array and dictionary values for the entries; need to fix that
- `self.get_input(inputs, "this_input")` would throw a meaningful error, but `inputs["this_input"]` would only throw a `KeyError` -> moved to a new issue
- `con_outputs` outside of model I/O? It is only needed for OCP, not really for models -> moved to a new issue
- `s_dot` naming convention for `states_dot`?
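
The difference between `get_input` and plain dictionary indexing mentioned in the TODO list can be sketched as follows. This is a hypothetical standalone version written as a free function; the actual `self.get_input` method is not shown in this thread, so its signature and error message are assumptions.

```python
from typing import Any, Mapping


def get_input(inputs: Mapping[str, Any], name: str) -> Any:
    """Look up a named input and raise a descriptive error if it is missing.

    Hypothetical sketch; the real method may behave differently.
    """
    try:
        return inputs[name]
    except KeyError:
        # List the names that do exist so the caller can spot a typo.
        available = ", ".join(sorted(inputs)) or "<none>"
        raise KeyError(
            f"Input '{name}' not found; available inputs: {available}"
        ) from None


inputs = {"a": 1.0, "b": 2.0}
value = get_input(inputs, "a")

try:
    get_input(inputs, "this_input")
except KeyError as err:
    message = str(err)
```

With plain `inputs["this_input"]` the error would only carry the missing key, with no hint about which inputs are available.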