davidzoltowski closed this PR 11 months ago.
This PR makes two small changes for compatibility with newer versions of JAX and Flax.
1. In ssm.py, np.DeviceArray is changed to np.array, since np.DeviceArray is deprecated.
2. The unfreeze() call in train_helper.py is updated for newer Flax.
Closing this PR to prevent new development on our branch from being merged into it. We can open a separate branch and PR to resubmit the changes from the first commit of this PR.
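For the ssm.py change, here is a minimal sketch of an annotation that avoids the deprecated DeviceArray name. It assumes ssm.py imports jax.numpy under the alias np, and `scale` is a hypothetical helper, not code from the PR. The PR swaps in np.array, which works at runtime because annotations are not enforced. This sketch instead uses jax.Array, the public array type in recent JAX releases, which static type checkers also understand.

```python
import jax
import jax.numpy as np  # assumption: ssm.py aliases jax.numpy as np


def scale(x: jax.Array, factor: float) -> jax.Array:
    """Hypothetical helper annotated with jax.Array instead of np.DeviceArray."""
    return x * factor


y = scale(np.arange(3.0), 2.0)
print(isinstance(y, jax.Array))  # True
```

Concrete arrays produced by jax.numpy are instances of jax.Array, so the annotation also works for isinstance checks.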